diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index 41957e30a2..b4c327ec09 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -85,6 +85,10 @@ TAGS, SESSION_DEFAULT_S3_BUCKET_PATH, SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH, + FEATURE_GROUP, + FEATURE_GROUP_ROLE_ARN_PATH, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, ) # Setting LOGGER for backward compatibility, in case users import it... @@ -1607,6 +1611,222 @@ def delete_endpoint_config(self, endpoint_config_name): logger.info("Deleting endpoint configuration with name: %s", endpoint_config_name) self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) + def delete_feature_group(self, feature_group_name): + """Delete an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Amazon SageMaker Feature Group to delete. + """ + logger.info("Deleting feature group with name: %s", feature_group_name) + self.sagemaker_client.delete_feature_group(FeatureGroupName=feature_group_name) + + def create_feature_group( + self, + feature_group_name, + record_identifier_name, + event_time_feature_name, + feature_definitions, + role_arn=None, + online_store_config=None, + offline_store_config=None, + throughput_config=None, + description=None, + tags=None, + ): + """Create an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Feature Group. + record_identifier_name (str): Name of the record identifier feature. + event_time_feature_name (str): Name of the event time feature. + feature_definitions (list): List of feature definitions. + role_arn (str): ARN of the role used to execute the API (default: None). + Resolved from SageMaker Config if not provided. + online_store_config (dict): Online store configuration (default: None). + offline_store_config (dict): Offline store configuration (default: None). + throughput_config (dict): Throughput configuration (default: None). + description (str): Description of the Feature Group (default: None). + tags (Optional[Tags]): Tags for labeling the Feature Group (default: None). + + Returns: + dict: Response from the CreateFeatureGroup API. 
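+
+        Example (illustrative values only; mirrors the minimal required arguments)::
+
+            session.create_feature_group(
+                feature_group_name="my-fg",
+                record_identifier_name="record_id",
+                event_time_feature_name="event_time",
+                feature_definitions=[{"FeatureName": "f1", "FeatureType": "String"}],
+                role_arn="arn:aws:iam::123456789012:role/MyRole",
+            )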
+ """ + tags = format_tags(tags) + tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, FEATURE_GROUP, TAGS) + ) + + role_arn = resolve_value_from_config( + role_arn, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=self + ) + + inferred_online_store_config = update_nested_dictionary_with_values_from_config( + online_store_config, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + sagemaker_session=self, + ) + if inferred_online_store_config is not None: + # OnlineStore should be handled differently because if you set KmsKeyId, then you + # need to set EnableOnlineStore key as well + inferred_online_store_config["EnableOnlineStore"] = True + + inferred_offline_store_config = update_nested_dictionary_with_values_from_config( + offline_store_config, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, + sagemaker_session=self, + ) + + kwargs = dict( + FeatureGroupName=feature_group_name, + RecordIdentifierFeatureName=record_identifier_name, + EventTimeFeatureName=event_time_feature_name, + FeatureDefinitions=feature_definitions, + RoleArn=role_arn, + ) + update_args( + kwargs, + OnlineStoreConfig=inferred_online_store_config, + OfflineStoreConfig=inferred_offline_store_config, + ThroughputConfig=throughput_config, + Description=description, + Tags=tags, + ) + + logger.info("Creating feature group with name: %s", feature_group_name) + return self.sagemaker_client.create_feature_group(**kwargs) + + def describe_feature_group(self, feature_group_name, next_token=None): + """Describe an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Amazon SageMaker Feature Group to describe. + next_token (str): A token for paginated results (default: None). + + Returns: + dict: Response from the DescribeFeatureGroup API. + """ + args = {"FeatureGroupName": feature_group_name} + update_args(args, NextToken=next_token) + return self.sagemaker_client.describe_feature_group(**args) + + def update_feature_group( + self, + feature_group_name, + feature_additions=None, + online_store_config=None, + throughput_config=None, + ): + """Update an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Amazon SageMaker Feature Group to update. + feature_additions (list): List of feature definitions to add (default: None). + online_store_config (dict): Online store configuration updates (default: None). + throughput_config (dict): Throughput configuration updates (default: None). + + Returns: + dict: Response from the UpdateFeatureGroup API. + """ + args = {"FeatureGroupName": feature_group_name} + update_args( + args, + FeatureAdditions=feature_additions, + OnlineStoreConfig=online_store_config, + ThroughputConfig=throughput_config, + ) + return self.sagemaker_client.update_feature_group(**args) + + def list_feature_groups( + self, + name_contains=None, + feature_group_status_equals=None, + offline_store_status_equals=None, + creation_time_after=None, + creation_time_before=None, + sort_order=None, + sort_by=None, + max_results=None, + next_token=None, + ): + """List Amazon SageMaker Feature Groups. + + Args: + name_contains (str): Filter by name substring (default: None). + feature_group_status_equals (str): Filter by status (default: None). + offline_store_status_equals (str): Filter by offline store status (default: None). + creation_time_after (datetime): Filter by creation time lower bound (default: None). + creation_time_before (datetime): Filter by creation time upper bound (default: None). 
+ sort_order (str): Sort order, 'Ascending' or 'Descending' (default: None). + sort_by (str): Sort by field (default: None). + max_results (int): Maximum number of results (default: None). + next_token (str): Pagination token (default: None). + + Returns: + dict: Response from the ListFeatureGroups API. + """ + args = {} + update_args( + args, + NameContains=name_contains, + FeatureGroupStatusEquals=feature_group_status_equals, + OfflineStoreStatusEquals=offline_store_status_equals, + CreationTimeAfter=creation_time_after, + CreationTimeBefore=creation_time_before, + SortOrder=sort_order, + SortBy=sort_by, + MaxResults=max_results, + NextToken=next_token, + ) + return self.sagemaker_client.list_feature_groups(**args) + + def update_feature_metadata( + self, + feature_group_name, + feature_name, + description=None, + parameter_additions=None, + parameter_removals=None, + ): + """Update metadata for a feature in an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Feature Group. + feature_name (str): Name of the feature to update metadata for. + description (str): Updated description for the feature (default: None). + parameter_additions (list): Parameters to add (default: None). + parameter_removals (list): Parameters to remove (default: None). + + Returns: + dict: Response from the UpdateFeatureMetadata API. + """ + args = { + "FeatureGroupName": feature_group_name, + "FeatureName": feature_name, + } + update_args( + args, + Description=description, + ParameterAdditions=parameter_additions, + ParameterRemovals=parameter_removals, + ) + return self.sagemaker_client.update_feature_metadata(**args) + + def describe_feature_metadata(self, feature_group_name, feature_name): + """Describe metadata for a feature in an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Feature Group. + feature_name (str): Name of the feature to describe metadata for. + + Returns: + dict: Response from the DescribeFeatureMetadata API. + """ + return self.sagemaker_client.describe_feature_metadata( + FeatureGroupName=feature_group_name, + FeatureName=feature_name, + ) + def wait_for_optimization_job(self, job, poll=5): """Wait for an Amazon SageMaker Optimization job to complete. 
diff --git a/sagemaker-core/tests/unit/session/test_session_helper.py b/sagemaker-core/tests/unit/session/test_session_helper.py index ca4fd81aa8..7e2004c1d0 100644 --- a/sagemaker-core/tests/unit/session/test_session_helper.py +++ b/sagemaker-core/tests/unit/session/test_session_helper.py @@ -29,6 +29,11 @@ update_args, NOTEBOOK_METADATA_FILE, ) +from sagemaker.core.config.config_schema import ( + FEATURE_GROUP_ROLE_ARN_PATH, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, +) class TestSession: @@ -451,3 +456,543 @@ def test_update_args_with_none_values(self): assert args["existing"] == "value" assert "new_key" not in args assert args["another_key"] == "another_value" + +class TestFeatureGroupSessionMethods: + """Test cases for Feature Group session methods""" + + @pytest.fixture + def session_with_mock_client(self): + """Create a Session with a mocked sagemaker_client.""" + mock_boto_session = Mock() + mock_boto_session.region_name = "us-west-2" + mock_boto_session.client.return_value = Mock() + mock_boto_session.resource.return_value = Mock() + session = Session(boto_session=mock_boto_session) + session.sagemaker_client = Mock() + return session + + # --- delete_feature_group --- + + def test_delete_feature_group(self, session_with_mock_client): + """Test delete_feature_group delegates to sagemaker_client.""" + session = session_with_mock_client + session.delete_feature_group("my-feature-group") + + session.sagemaker_client.delete_feature_group.assert_called_once_with( + FeatureGroupName="my-feature-group" + ) + + # --- describe_feature_group --- + + def test_describe_feature_group(self, session_with_mock_client): + """Test describe_feature_group delegates and returns response.""" + session = session_with_mock_client + expected = {"FeatureGroupName": "my-fg", "CreationTime": "2024-01-01"} + session.sagemaker_client.describe_feature_group.return_value = expected + + result = session.describe_feature_group("my-fg") + + session.sagemaker_client.describe_feature_group.assert_called_once_with( + FeatureGroupName="my-fg" + ) + assert result == expected + + def test_describe_feature_group_with_next_token(self, session_with_mock_client): + """Test describe_feature_group includes NextToken when provided.""" + session = session_with_mock_client + session.sagemaker_client.describe_feature_group.return_value = {} + + session.describe_feature_group("my-fg", next_token="abc123") + + session.sagemaker_client.describe_feature_group.assert_called_once_with( + FeatureGroupName="my-fg", NextToken="abc123" + ) + + def test_describe_feature_group_omits_none_next_token(self, session_with_mock_client): + """Test describe_feature_group omits NextToken when None.""" + session = session_with_mock_client + session.sagemaker_client.describe_feature_group.return_value = {} + + session.describe_feature_group("my-fg", next_token=None) + + call_kwargs = session.sagemaker_client.describe_feature_group.call_args[1] + assert "NextToken" not in call_kwargs + + # --- update_feature_group --- + + def test_update_feature_group_all_params(self, session_with_mock_client): + """Test update_feature_group with all optional params provided.""" + session = session_with_mock_client + expected = {"FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123:feature-group/my-fg"} + session.sagemaker_client.update_feature_group.return_value = expected + + additions = [{"FeatureName": "new_feat", "FeatureType": "String"}] + online_cfg = {"EnableOnlineStore": True} + throughput_cfg = {"ThroughputMode": 
"OnDemand"} + + result = session.update_feature_group( + "my-fg", + feature_additions=additions, + online_store_config=online_cfg, + throughput_config=throughput_cfg, + ) + + session.sagemaker_client.update_feature_group.assert_called_once_with( + FeatureGroupName="my-fg", + FeatureAdditions=additions, + OnlineStoreConfig=online_cfg, + ThroughputConfig=throughput_cfg, + ) + assert result == expected + + def test_update_feature_group_omits_none_params(self, session_with_mock_client): + """Test update_feature_group omits None optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_group.return_value = {} + + session.update_feature_group("my-fg") + + call_kwargs = session.sagemaker_client.update_feature_group.call_args[1] + assert call_kwargs == {"FeatureGroupName": "my-fg"} + + def test_update_feature_group_partial_params(self, session_with_mock_client): + """Test update_feature_group with only some optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_group.return_value = {} + + throughput_cfg = {"ThroughputMode": "Provisioned"} + session.update_feature_group("my-fg", throughput_config=throughput_cfg) + + call_kwargs = session.sagemaker_client.update_feature_group.call_args[1] + assert call_kwargs == { + "FeatureGroupName": "my-fg", + "ThroughputConfig": throughput_cfg, + } + + # --- list_feature_groups --- + + def test_list_feature_groups_no_params(self, session_with_mock_client): + """Test list_feature_groups with no filters delegates with empty args.""" + session = session_with_mock_client + expected = {"FeatureGroupSummaries": []} + session.sagemaker_client.list_feature_groups.return_value = expected + + result = session.list_feature_groups() + + session.sagemaker_client.list_feature_groups.assert_called_once_with() + assert result == expected + + def test_list_feature_groups_all_params(self, session_with_mock_client): + """Test list_feature_groups with all params provided.""" + session = session_with_mock_client + session.sagemaker_client.list_feature_groups.return_value = {} + + session.list_feature_groups( + name_contains="test", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after="2024-01-01", + creation_time_before="2024-12-31", + sort_order="Ascending", + sort_by="Name", + max_results=10, + next_token="token123", + ) + + session.sagemaker_client.list_feature_groups.assert_called_once_with( + NameContains="test", + FeatureGroupStatusEquals="Created", + OfflineStoreStatusEquals="Active", + CreationTimeAfter="2024-01-01", + CreationTimeBefore="2024-12-31", + SortOrder="Ascending", + SortBy="Name", + MaxResults=10, + NextToken="token123", + ) + + def test_list_feature_groups_omits_none_params(self, session_with_mock_client): + """Test list_feature_groups omits None params.""" + session = session_with_mock_client + session.sagemaker_client.list_feature_groups.return_value = {} + + session.list_feature_groups(name_contains="test", max_results=5) + + call_kwargs = session.sagemaker_client.list_feature_groups.call_args[1] + assert call_kwargs == {"NameContains": "test", "MaxResults": 5} + + # --- update_feature_metadata --- + + def test_update_feature_metadata_all_params(self, session_with_mock_client): + """Test update_feature_metadata with all optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_metadata.return_value = {} + + additions = [{"Key": "team", "Value": "ml"}] + removals = [{"Key": 
"deprecated"}] + + result = session.update_feature_metadata( + "my-fg", + "my-feature", + description="Updated desc", + parameter_additions=additions, + parameter_removals=removals, + ) + + session.sagemaker_client.update_feature_metadata.assert_called_once_with( + FeatureGroupName="my-fg", + FeatureName="my-feature", + Description="Updated desc", + ParameterAdditions=additions, + ParameterRemovals=removals, + ) + assert result == {} + + def test_update_feature_metadata_omits_none_params(self, session_with_mock_client): + """Test update_feature_metadata omits None optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_metadata.return_value = {} + + session.update_feature_metadata("my-fg", "my-feature") + + call_kwargs = session.sagemaker_client.update_feature_metadata.call_args[1] + assert call_kwargs == { + "FeatureGroupName": "my-fg", + "FeatureName": "my-feature", + } + + def test_update_feature_metadata_partial_params(self, session_with_mock_client): + """Test update_feature_metadata with only description.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_metadata.return_value = {} + + session.update_feature_metadata("my-fg", "my-feature", description="New desc") + + call_kwargs = session.sagemaker_client.update_feature_metadata.call_args[1] + assert call_kwargs == { + "FeatureGroupName": "my-fg", + "FeatureName": "my-feature", + "Description": "New desc", + } + + # --- describe_feature_metadata --- + + def test_describe_feature_metadata(self, session_with_mock_client): + """Test describe_feature_metadata delegates and returns response.""" + session = session_with_mock_client + expected = {"FeatureGroupName": "my-fg", "FeatureName": "my-feature"} + session.sagemaker_client.describe_feature_metadata.return_value = expected + + result = session.describe_feature_metadata("my-fg", "my-feature") + + session.sagemaker_client.describe_feature_metadata.assert_called_once_with( + FeatureGroupName="my-fg", FeatureName="my-feature" + ) + assert result == expected + +MODULE = "sagemaker.core.helper.session_helper" + + +class TestCreateFeatureGroup: + """Test cases for create_feature_group session method.""" + + @pytest.fixture + def session(self): + """Create a Session with a mocked sagemaker_client.""" + mock_boto_session = Mock() + mock_boto_session.region_name = "us-west-2" + mock_boto_session.client.return_value = Mock() + mock_boto_session.resource.return_value = Mock() + session = Session(boto_session=mock_boto_session) + session.sagemaker_client = Mock() + return session + + @pytest.fixture + def base_args(self): + """Minimal required arguments for create_feature_group.""" + return dict( + feature_group_name="my-fg", + record_identifier_name="record_id", + event_time_feature_name="event_time", + feature_definitions=[{"FeatureName": "f1", "FeatureType": "String"}], + ) + + # --- Full parameter pass-through --- + + def test_create_feature_group_all_params(self, session, base_args): + """Test that all parameters are passed through to sagemaker_client.""" + role = "arn:aws:iam::123456789012:role/Role" + online_cfg = {"SecurityConfig": {"KmsKeyId": "key-123"}} + offline_cfg = {"S3StorageConfig": {"S3Uri": "s3://bucket"}} + throughput_cfg = {"ThroughputMode": "ON_DEMAND"} + description = "My feature group" + tags = [{"Key": "team", "Value": "ml"}] + + expected_response = {"FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/my-fg"} + session.sagemaker_client.create_feature_group.return_value = 
expected_response + + with patch(f"{MODULE}.format_tags", return_value=tags) as mock_format, \ + patch(f"{MODULE}._append_project_tags", return_value=tags) as mock_proj, \ + patch.object(session, "_append_sagemaker_config_tags", return_value=tags), \ + patch(f"{MODULE}.resolve_value_from_config", return_value=role), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", side_effect=[online_cfg, offline_cfg]): + + result = session.create_feature_group( + **base_args, + role_arn=role, + online_store_config=online_cfg, + offline_store_config=offline_cfg, + throughput_config=throughput_cfg, + description=description, + tags=tags, + ) + + assert result == expected_response + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["FeatureGroupName"] == "my-fg" + assert call_kwargs["RecordIdentifierFeatureName"] == "record_id" + assert call_kwargs["EventTimeFeatureName"] == "event_time" + assert call_kwargs["FeatureDefinitions"] == base_args["feature_definitions"] + assert call_kwargs["RoleArn"] == role + # EnableOnlineStore is set to True when online config is inferred + assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True + assert call_kwargs["OfflineStoreConfig"] == offline_cfg + assert call_kwargs["ThroughputConfig"] == throughput_cfg + assert call_kwargs["Description"] == description + assert call_kwargs["Tags"] == tags + + # --- Tag processing pipeline --- + + def test_tag_processing_pipeline_order(self, session, base_args): + """Test that tags go through format_tags -> _append_project_tags -> _append_sagemaker_config_tags.""" + raw_tags = {"team": "ml"} + formatted = [{"Key": "team", "Value": "ml"}] + with_project = [{"Key": "team", "Value": "ml"}, {"Key": "project", "Value": "p1"}] + with_config = [{"Key": "team", "Value": "ml"}, {"Key": "project", "Value": "p1"}, {"Key": "cfg", "Value": "v"}] + + with patch(f"{MODULE}.format_tags", return_value=formatted) as mock_format, \ + patch(f"{MODULE}._append_project_tags", return_value=with_project) as mock_proj, \ + patch.object(session, "_append_sagemaker_config_tags", return_value=with_config) as mock_cfg, \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, tags=raw_tags) + + # format_tags is called with the raw input + mock_format.assert_called_once_with(raw_tags) + # _append_project_tags receives the formatted tags + mock_proj.assert_called_once_with(formatted) + # _append_sagemaker_config_tags receives the project-appended tags + mock_cfg.assert_called_once_with(with_project, "SageMaker.FeatureGroup.Tags") + + # Final tags in the API call should be the config-appended tags + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["Tags"] == with_config + + def test_tags_none_still_processed(self, session, base_args): + """Test that None tags still go through the pipeline (format_tags handles None).""" + with patch(f"{MODULE}.format_tags", return_value=None) as mock_format, \ + patch(f"{MODULE}._append_project_tags", return_value=None) as mock_proj, \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, tags=None) + + 
mock_format.assert_called_once_with(None) + mock_proj.assert_called_once_with(None) + # Tags=None should be omitted from the API call via update_args + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "Tags" not in call_kwargs + + # --- role_arn resolution from config --- + + def test_role_arn_resolved_from_config_when_none(self, session, base_args): + """Test that role_arn is resolved from SageMaker Config when not provided.""" + config_role = "arn:aws:iam::123456789012:role/ConfigRole" + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value=config_role) as mock_resolve, \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, role_arn=None) + + mock_resolve.assert_called_once_with( + None, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["RoleArn"] == config_role + + def test_role_arn_passed_through_when_provided(self, session, base_args): + """Test that an explicit role_arn is passed to resolve_value_from_config (which returns it).""" + explicit_role = "arn:aws:iam::123456789012:role/ExplicitRole" + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value=explicit_role) as mock_resolve, \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, role_arn=explicit_role) + + mock_resolve.assert_called_once_with( + explicit_role, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["RoleArn"] == explicit_role + + # --- online_store_config merging and EnableOnlineStore --- + + def test_online_store_config_merged_and_enable_set(self, session, base_args): + """Test that online_store_config is merged from config and EnableOnlineStore=True is set.""" + inferred_online = {"SecurityConfig": {"KmsKeyId": "config-key"}} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", + side_effect=[inferred_online, None]) as mock_update: + + session.create_feature_group(**base_args, online_store_config=None) + + # First call is for online store config + mock_update.assert_any_call( + None, FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True + assert call_kwargs["OnlineStoreConfig"]["SecurityConfig"]["KmsKeyId"] == "config-key" + + def test_online_store_config_none_when_no_config(self, session, base_args): + """Test that OnlineStoreConfig is omitted when config returns None.""" + with patch(f"{MODULE}.format_tags", return_value=None), \ + 
patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "OnlineStoreConfig" not in call_kwargs + + def test_online_store_config_explicit_gets_enable_set(self, session, base_args): + """Test that explicitly provided online_store_config also gets EnableOnlineStore=True.""" + explicit_online = {"SecurityConfig": {"KmsKeyId": "my-key"}} + # update_nested_dictionary returns the merged result + merged_online = {"SecurityConfig": {"KmsKeyId": "my-key"}} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", + side_effect=[merged_online, None]): + + session.create_feature_group(**base_args, online_store_config=explicit_online) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True + + # --- offline_store_config merging --- + + def test_offline_store_config_merged_from_config(self, session, base_args): + """Test that offline_store_config is merged from SageMaker Config.""" + inferred_offline = {"S3StorageConfig": {"S3Uri": "s3://config-bucket"}} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", + side_effect=[None, inferred_offline]) as mock_update: + + session.create_feature_group(**base_args, offline_store_config=None) + + # Second call is for offline store config + mock_update.assert_any_call( + None, FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["OfflineStoreConfig"] == inferred_offline + + def test_offline_store_config_none_when_no_config(self, session, base_args): + """Test that OfflineStoreConfig is omitted when config returns None.""" + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "OfflineStoreConfig" not in call_kwargs + + # --- None optional parameters omitted --- + + def test_none_optional_params_omitted(self, session, base_args): + """Test that None optional params (throughput, description, tags) are omitted from API call.""" + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, 
"_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "ThroughputConfig" not in call_kwargs + assert "Description" not in call_kwargs + assert "Tags" not in call_kwargs + assert "OnlineStoreConfig" not in call_kwargs + assert "OfflineStoreConfig" not in call_kwargs + # Required params should still be present + assert "FeatureGroupName" in call_kwargs + assert "RecordIdentifierFeatureName" in call_kwargs + assert "EventTimeFeatureName" in call_kwargs + assert "FeatureDefinitions" in call_kwargs + assert "RoleArn" in call_kwargs + + def test_partial_optional_params(self, session, base_args): + """Test that only provided optional params appear in the API call.""" + throughput = {"ThroughputMode": "ON_DEMAND"} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group( + **base_args, + throughput_config=throughput, + description="test desc", + ) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["ThroughputConfig"] == throughput + assert call_kwargs["Description"] == "test desc" + assert "Tags" not in call_kwargs + assert "OnlineStoreConfig" not in call_kwargs + assert "OfflineStoreConfig" not in call_kwargs + + # --- Return value --- + + def test_returns_api_response(self, session, base_args): + """Test that the method returns the sagemaker_client response.""" + expected = {"FeatureGroupArn": "arn:fg"} + session.sagemaker_client.create_feature_group.return_value = expected + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + result = session.create_feature_group(**base_args) + + assert result == expected diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md b/sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md new file mode 100644 index 0000000000..40942fa6f3 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md @@ -0,0 +1,513 @@ +# SageMaker FeatureStore V2 to V3 Migration Guide + +## Overview + +V3 uses **sagemaker-core** as the foundation, which provides: +- Pydantic-based shapes with automatic serialization +- Resource classes that manage boto clients internally +- No need for explicit Session management in most cases + +## File Mapping + +| V2 File | V3 File | Notes | +|---------|---------|-------| +| `feature_group.py` | Re-exported from `sagemaker_core.main.resources` | No wrapper class needed | +| `feature_store.py` | Re-exported from `sagemaker_core.main.resources` | `FeatureStore.search()` available | +| `feature_definition.py` | `feature_definition.py` | Helper factories retained | +| `feature_utils.py` | 
`feature_utils.py` | Standalone functions | +| `inputs.py` | `inputs.py` | Enums only (shapes from core) | +| `dataset_builder.py` | `dataset_builder.py` | Converted to dataclass | +| N/A | `athena_query.py` | Extracted from feature_group.py | +| N/A | `ingestion_manager_pandas.py` | Extracted from feature_group.py | + +--- + +## FeatureGroup Operations + +### Create FeatureGroup + +**V2:** +```python +from sagemaker.feature_store.feature_group import FeatureGroup +from sagemaker.session import Session + +session = Session() +fg = FeatureGroup(name="my-fg", sagemaker_session=session) +fg.load_feature_definitions(data_frame=df) +fg.create( + s3_uri="s3://bucket/prefix", + record_identifier_name="id", + event_time_feature_name="ts", + role_arn=role, + enable_online_store=True, +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import ( + FeatureGroup, + OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + load_feature_definitions_from_dataframe, +) + +feature_defs = load_feature_definitions_from_dataframe(df) + +FeatureGroup.create( + feature_group_name="my-fg", + feature_definitions=feature_defs, + record_identifier_feature_name="id", + event_time_feature_name="ts", + role_arn=role, + online_store_config=OnlineStoreConfig(enable_online_store=True), + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/prefix") + ), +) +``` + +### Get/Describe FeatureGroup + +**V2:** +```python +fg = FeatureGroup(name="my-fg", sagemaker_session=session) +response = fg.describe() +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup + +fg = FeatureGroup.get(feature_group_name="my-fg") +# fg is now a typed object with attributes: +# fg.feature_group_name, fg.feature_definitions, fg.offline_store_config, etc. 
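+
+# For example (a sketch; attribute names assumed from the typed shapes):
+print(fg.record_identifier_feature_name, fg.event_time_feature_name)
+for fd in fg.feature_definitions:
+    print(fd.feature_name, fd.feature_type)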
+``` + +### Delete FeatureGroup + +**V2:** +```python +fg.delete() +``` + +**V3:** +```python +FeatureGroup(feature_group_name="my-fg").delete() +# or +fg = FeatureGroup.get(feature_group_name="my-fg") +fg.delete() +``` + +### Update FeatureGroup + +**V2:** +```python +fg.update( + feature_additions=[FeatureDefinition("new_col", FeatureTypeEnum.STRING)], + throughput_config=ThroughputConfigUpdate(mode=ThroughputModeEnum.ON_DEMAND), +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup, ThroughputConfig + +fg = FeatureGroup.get(feature_group_name="my-fg") +fg.update( + feature_additions=[{"FeatureName": "new_col", "FeatureType": "String"}], + throughput_config=ThroughputConfig(throughput_mode="OnDemand"), +) +``` + +--- + +## Record Operations + +### Put Record + +**V2:** +```python +from sagemaker.feature_store.inputs import FeatureValue + +fg.put_record( + record=[ + FeatureValue(feature_name="id", value_as_string="123"), + FeatureValue(feature_name="name", value_as_string="John"), + ], + target_stores=[TargetStoreEnum.ONLINE_STORE], +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup, FeatureValue + +FeatureGroup(feature_group_name="my-fg").put_record( + record=[ + FeatureValue(feature_name="id", value_as_string="123"), + FeatureValue(feature_name="name", value_as_string="John"), + ], + target_stores=["OnlineStore"], # strings, not enums +) +``` + +### Get Record + +**V2:** +```python +response = fg.get_record(record_identifier_value_as_string="123") +``` + +**V3:** +```python +response = FeatureGroup(feature_group_name="my-fg").get_record( + record_identifier_value_as_string="123" +) +``` + +### Delete Record + +**V2:** +```python +fg.delete_record( + record_identifier_value_as_string="123", + event_time="2024-01-15T00:00:00Z", + deletion_mode=DeletionModeEnum.SOFT_DELETE, +) +``` + +**V3:** +```python +FeatureGroup(feature_group_name="my-fg").delete_record( + record_identifier_value_as_string="123", + event_time="2024-01-15T00:00:00Z", + deletion_mode="SoftDelete", # string, not enum +) +``` + +### Batch Get Record + +**V2:** +```python +from sagemaker.feature_store.feature_store import FeatureStore +from sagemaker.feature_store.inputs import Identifier + +fs = FeatureStore(sagemaker_session=session) +response = fs.batch_get_record( + identifiers=[ + Identifier(feature_group_name="my-fg", record_identifiers_value_as_string=["123", "456"]) + ] +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup + +response = FeatureGroup(feature_group_name="my-fg").batch_get_record( + identifiers=[ + {"FeatureGroupName": "my-fg", "RecordIdentifiersValueAsString": ["123", "456"]} + ] +) +``` + +--- + +## DataFrame Ingestion + +**V2:** +```python +fg.ingest(data_frame=df, max_workers=4, max_processes=2, wait=True) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import ingest_dataframe + +manager = ingest_dataframe( + feature_group_name="my-fg", + data_frame=df, + max_workers=4, + max_processes=2, + wait=True, +) +# Access failed rows: manager.failed_rows +``` + +--- + +## Athena Query + +**V2:** +```python +query = fg.athena_query() +query.run(query_string="SELECT * FROM ...", output_location="s3://...") +query.wait() +df = query.as_dataframe() +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import create_athena_query + +query = create_athena_query("my-fg", session) +query.run(query_string="SELECT * FROM ...", output_location="s3://...") +query.wait() +df = query.as_dataframe() 
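+
+# Optional sketch: inspect the raw Athena execution status before loading results
+state = query.get_query_execution()["QueryExecution"]["Status"]["State"]
+assert state == "SUCCEEDED"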
+``` + +--- + +## Hive DDL Generation + +**V2:** +```python +ddl = fg.as_hive_ddl(database="mydb", table_name="mytable") +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import as_hive_ddl + +ddl = as_hive_ddl("my-fg", database="mydb", table_name="mytable") +``` + +--- + +## Feature Definitions + +**V2:** +```python +fg.load_feature_definitions(data_frame=df) +# Modifies fg.feature_definitions in place +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import load_feature_definitions_from_dataframe + +defs = load_feature_definitions_from_dataframe(df) +# Returns list, doesn't modify any object +``` + +### Using Helper Factories + +**V2 & V3 (same):** +```python +from sagemaker.mlops.feature_store import ( + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + VectorCollectionType, +) + +defs = [ + IntegralFeatureDefinition("id"), + StringFeatureDefinition("name"), + FractionalFeatureDefinition("embedding", VectorCollectionType(128)), +] +``` + +--- + +## Search + +**V2:** +```python +from sagemaker.feature_store.feature_store import FeatureStore +from sagemaker.feature_store.inputs import Filter, ResourceEnum + +fs = FeatureStore(sagemaker_session=session) +response = fs.search( + resource=ResourceEnum.FEATURE_GROUP, + filters=[Filter(name="FeatureGroupName", value="my-prefix", operator=FilterOperatorEnum.CONTAINS)], +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureStore, Filter, SearchExpression + +response = FeatureStore.search( + resource="FeatureGroup", + search_expression=SearchExpression( + filters=[Filter(name="FeatureGroupName", value="my-prefix", operator="Contains")] + ), +) +``` + +--- + +## Feature Metadata + +**V2:** +```python +fg.describe_feature_metadata(feature_name="my-feature") +fg.update_feature_metadata(feature_name="my-feature", description="Updated desc") +fg.list_parameters_for_feature_metadata(feature_name="my-feature") +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureMetadata + +# Get metadata +metadata = FeatureMetadata.get(feature_group_name="my-fg", feature_name="my-feature") +print(metadata.description) +print(metadata.parameters) + +# Update metadata +metadata.update(description="Updated desc") +``` + +--- + +## Dataset Builder + +**V2:** +```python +from sagemaker.feature_store.feature_store import FeatureStore + +fs = FeatureStore(sagemaker_session=session) +builder = fs.create_dataset( + base=fg, + output_path="s3://bucket/output", +) +builder.with_feature_group(other_fg, target_feature_name_in_base="id") +builder.point_in_time_accurate_join() +df, query = builder.to_dataframe() +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import create_dataset, FeatureGroup + +fg = FeatureGroup.get(feature_group_name="my-fg") +other_fg = FeatureGroup.get(feature_group_name="other-fg") + +builder = create_dataset( + base=fg, + output_path="s3://bucket/output", + session=session, +) +builder.with_feature_group(other_fg, target_feature_name_in_base="id") +builder.point_in_time_accurate_join() +df, query = builder.to_dataframe() +``` + +--- + +## Config Objects (Shapes) + +**V2:** +```python +from sagemaker.feature_store.inputs import ( + OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + TtlDuration, +) + +config = OnlineStoreConfig(enable_online_store=True, ttl_duration=TtlDuration(unit="Hours", value=24)) +config.to_dict() # Manual serialization required +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import ( + 
OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + TtlDuration, +) + +config = OnlineStoreConfig(enable_online_store=True, ttl_duration=TtlDuration(unit="Hours", value=24)) +# No to_dict() needed - Pydantic handles serialization automatically +``` + +--- + +## Key Differences Summary + +| Aspect | V2 | V3 | +|--------|----|----| +| **Session** | Required for most operations | Optional - core manages clients | +| **FeatureGroup** | Wrapper class with session | Direct core resource class | +| **Shapes** | `@attr.s` with `to_dict()` | Pydantic with auto-serialization | +| **Enums** | `TargetStoreEnum.ONLINE_STORE.value` | Just use strings: `"OnlineStore"` | +| **Methods** | Instance methods on FeatureGroup | Standalone functions + core methods | +| **Ingestion** | `fg.ingest(df)` | `ingest_dataframe(name, df)` | +| **Athena** | `fg.athena_query()` | `create_athena_query(name, session)` | +| **DDL** | `fg.as_hive_ddl()` | `as_hive_ddl(name)` | +| **Feature Defs** | `fg.load_feature_definitions(df)` | `load_feature_definitions_from_dataframe(df)` | +| **Imports** | Multiple modules | Single `__init__.py` re-exports all | + +--- + +## Missing in V3 (Intentionally) + +These V2 features are **not wrapped** because core provides them directly: + +- `FeatureGroup.create()` - use `FeatureGroup.create()` from core +- `FeatureGroup.delete()` - use `FeatureGroup(...).delete()` from core +- `FeatureGroup.describe()` - use `FeatureGroup.get()` from core (returns typed object) +- `FeatureGroup.update()` - use `FeatureGroup(...).update()` from core +- `FeatureGroup.put_record()` - use `FeatureGroup(...).put_record()` from core +- `FeatureGroup.get_record()` - use `FeatureGroup(...).get_record()` from core +- `FeatureGroup.delete_record()` - use `FeatureGroup(...).delete_record()` from core +- `FeatureGroup.batch_get_record()` - use `FeatureGroup(...).batch_get_record()` from core +- `FeatureStore.search()` - use `FeatureStore.search()` from core +- `FeatureStore.list_feature_groups()` - use `FeatureGroup.get_all()` from core +- All config shapes (`OnlineStoreConfig`, etc.) - re-exported from core + +--- + +## Import Cheatsheet + +```python +# V3 - Everything from one place +from sagemaker.mlops.feature_store import ( + # Resources (from core) + FeatureGroup, + FeatureStore, + FeatureMetadata, + + # Shapes (from core) + OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + DataCatalogConfig, + TtlDuration, + FeatureValue, + FeatureParameter, + ThroughputConfig, + Filter, + SearchExpression, + + # Enums (local) + TargetStoreEnum, + OnlineStoreStorageTypeEnum, + TableFormatEnum, + DeletionModeEnum, + ThroughputModeEnum, + + # Feature Definition helpers (local) + FeatureDefinition, + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + VectorCollectionType, + + # Utility functions (local) + create_athena_query, + as_hive_ddl, + load_feature_definitions_from_dataframe, + ingest_dataframe, + create_dataset, + + # Classes (local) + DatasetBuilder, +) +``` diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py new file mode 100644 index 0000000000..f15d6d3845 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py @@ -0,0 +1,125 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 +"""SageMaker FeatureStore V3 - powered by sagemaker-core.""" + +# Resources from core +from sagemaker.core.resources import FeatureGroup, FeatureMetadata + +# Shapes from core (Pydantic - no to_dict() needed) +from sagemaker.core.shapes import ( + DataCatalogConfig, + FeatureParameter, + FeatureValue, + Filter, + OfflineStoreConfig, + OnlineStoreConfig, + OnlineStoreSecurityConfig, + S3StorageConfig, + SearchExpression, + ThroughputConfig, + TtlDuration, +) + +# Enums (local - core uses strings) +from sagemaker.mlops.feature_store.inputs import ( + DeletionModeEnum, + ExpirationTimeResponseEnum, + FilterOperatorEnum, + OnlineStoreStorageTypeEnum, + ResourceEnum, + SearchOperatorEnum, + SortOrderEnum, + TableFormatEnum, + TargetStoreEnum, + ThroughputModeEnum, +) + +# Feature Definition helpers (local) +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FeatureTypeEnum, + CollectionTypeEnum, + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + ListCollectionType, + SetCollectionType, + VectorCollectionType, +) + +# Utility functions (local) +from sagemaker.mlops.feature_store.feature_utils import ( + as_hive_ddl, + create_athena_query, + get_session_from_role, + ingest_dataframe, + load_feature_definitions_from_dataframe, +) + +# Classes (local) +from sagemaker.mlops.feature_store.athena_query import AthenaQuery +from sagemaker.mlops.feature_store.dataset_builder import ( + DatasetBuilder, + FeatureGroupToBeMerged, + JoinComparatorEnum, + JoinTypeEnum, + TableType, +) +from sagemaker.mlops.feature_store.ingestion_manager_pandas import ( + IngestionError, + IngestionManagerPandas, +) + +__all__ = [ + # Resources + "FeatureGroup", + "FeatureMetadata", + # Shapes + "DataCatalogConfig", + "FeatureParameter", + "FeatureValue", + "Filter", + "OfflineStoreConfig", + "OnlineStoreConfig", + "OnlineStoreSecurityConfig", + "S3StorageConfig", + "SearchExpression", + "ThroughputConfig", + "TtlDuration", + # Enums + "DeletionModeEnum", + "ExpirationTimeResponseEnum", + "FilterOperatorEnum", + "OnlineStoreStorageTypeEnum", + "ResourceEnum", + "SearchOperatorEnum", + "SortOrderEnum", + "TableFormatEnum", + "TargetStoreEnum", + "ThroughputModeEnum", + # Feature Definitions + "FeatureDefinition", + "FeatureTypeEnum", + "CollectionTypeEnum", + "FractionalFeatureDefinition", + "IntegralFeatureDefinition", + "StringFeatureDefinition", + "ListCollectionType", + "SetCollectionType", + "VectorCollectionType", + # Utility functions + "as_hive_ddl", + "create_athena_query", + "get_session_from_role", + "ingest_dataframe", + "load_feature_definitions_from_dataframe", + # Classes + "AthenaQuery", + "DatasetBuilder", + "FeatureGroupToBeMerged", + "IngestionError", + "IngestionManagerPandas", + "JoinComparatorEnum", + "JoinTypeEnum", + "TableType", +] diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py new file mode 100644 index 0000000000..123b3c4305 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py @@ -0,0 +1,112 @@ +import os +import tempfile +from dataclasses import dataclass, field +from typing import Any, Dict +from urllib.parse import urlparse +import pandas as pd +from pandas import DataFrame + +from sagemaker.mlops.feature_store.feature_utils import ( + start_query_execution, + get_query_execution, + wait_for_athena_query, + download_athena_query_result, +) + 
+from sagemaker.core.helper.session_helper import Session + +@dataclass +class AthenaQuery: + """Class to manage querying of feature store data with AWS Athena. + + This class instantiates a AthenaQuery object that is used to retrieve data from feature store + via standard SQL queries. + + Attributes: + catalog (str): name of the data catalog. + database (str): name of the database. + table_name (str): name of the table. + sagemaker_session (Session): instance of the Session class to perform boto calls. + """ + + catalog: str + database: str + table_name: str + sagemaker_session: Session + _current_query_execution_id: str = field(default=None, init=False) + _result_bucket: str = field(default=None, init=False) + _result_file_prefix: str = field(default=None, init=False) + + def run( + self, query_string: str, output_location: str, kms_key: str = None, workgroup: str = None + ) -> str: + """Execute a SQL query given a query string, output location and kms key. + + This method executes the SQL query using Athena and outputs the results to output_location + and returns the execution id of the query. + + Args: + query_string: SQL query string. + output_location: S3 URI of the query result. + kms_key: KMS key id. If set, will be used to encrypt the query result file. + workgroup (str): The name of the workgroup in which the query is being started. + + Returns: + Execution id of the query. + """ + response = start_query_execution( + session=self.sagemaker_session, + catalog=self.catalog, + database=self.database, + query_string=query_string, + output_location=output_location, + kms_key=kms_key, + workgroup=workgroup, + ) + + self._current_query_execution_id = response["QueryExecutionId"] + parsed_result = urlparse(output_location, allow_fragments=False) + self._result_bucket = parsed_result.netloc + self._result_file_prefix = parsed_result.path.strip("/") + return self._current_query_execution_id + + def wait(self): + """Wait for the current query to finish.""" + wait_for_athena_query(self.sagemaker_session, self._current_query_execution_id) + + def get_query_execution(self) -> Dict[str, Any]: + """Get execution status of the current query. + + Returns: + Response dict from Athena. + """ + return get_query_execution(self.sagemaker_session, self._current_query_execution_id) + + def as_dataframe(self, **kwargs) -> DataFrame: + """Download the result of the current query and load it into a DataFrame. + + Args: + **kwargs (object): key arguments used for the method pandas.read_csv to be able to + have a better tuning on data. For more info read: + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html + + Returns: + A pandas DataFrame contains the query result. 
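+
+        Example (illustrative; extra keyword arguments are forwarded to pandas.read_csv)::
+
+            df = query.as_dataframe(dtype="string")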
+ """ + state = self.get_query_execution()["QueryExecution"]["Status"]["State"] + if state != "SUCCEEDED": + if state in ("QUEUED", "RUNNING"): + raise RuntimeError(f"Query {self._current_query_execution_id} still executing.") + raise RuntimeError(f"Query {self._current_query_execution_id} failed.") + + output_file = os.path.join(tempfile.gettempdir(), f"{self._current_query_execution_id}.csv") + download_athena_query_result( + session=self.sagemaker_session, + bucket=self._result_bucket, + prefix=self._result_file_prefix, + query_execution_id=self._current_query_execution_id, + filename=output_file, + ) + kwargs.pop("delimiter", None) + return pd.read_csv(output_file, delimiter=",", **kwargs) + diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py new file mode 100644 index 0000000000..72e9535320 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py @@ -0,0 +1,768 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Dataset Builder for FeatureStore.""" +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Union +import datetime + +import pandas as pd + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store import FeatureGroup +from sagemaker.mlops.feature_store.feature_definition import FeatureDefinition, FeatureTypeEnum +from sagemaker.mlops.feature_store.feature_utils import ( + upload_dataframe_to_s3, + download_csv_from_s3, + run_athena_query, +) + +_DEFAULT_CATALOG = "AwsDataCatalog" +_DEFAULT_DATABASE = "sagemaker_featurestore" + +_DTYPE_TO_FEATURE_TYPE = { + "object": "String", "string": "String", + "int64": "Integral", "int32": "Integral", + "float64": "Fractional", "float32": "Fractional", +} + +_DTYPE_TO_ATHENA_TYPE = { + "object": "STRING", "int64": "INT", "float64": "DOUBLE", + "bool": "BOOLEAN", "datetime64[ns]": "TIMESTAMP", +} + + +class TableType(Enum): + FEATURE_GROUP = "FeatureGroup" + DATA_FRAME = "DataFrame" + + +class JoinTypeEnum(Enum): + INNER_JOIN = "JOIN" + LEFT_JOIN = "LEFT JOIN" + RIGHT_JOIN = "RIGHT JOIN" + FULL_JOIN = "FULL JOIN" + CROSS_JOIN = "CROSS JOIN" + + +class JoinComparatorEnum(Enum): + EQUALS = "=" + GREATER_THAN = ">" + GREATER_THAN_OR_EQUAL_TO = ">=" + LESS_THAN = "<" + LESS_THAN_OR_EQUAL_TO = "<=" + NOT_EQUAL_TO = "<>" + + +@dataclass +class FeatureGroupToBeMerged: + """FeatureGroup metadata which will be used for SQL join. + + This class instantiates a FeatureGroupToBeMerged object that comprises a list of feature names, + a list of feature names which will be included in SQL query, a database, an Athena table name, + a feature name of record identifier, a feature name of event time identifier and a feature name + of base which is the target join key. + + Attributes: + features (List[str]): A list of strings representing feature names of this FeatureGroup. + included_feature_names (List[str]): A list of strings representing features to be + included in the SQL join. + projected_feature_names (List[str]): A list of strings representing features to be + included for final projection in output. + catalog (str): A string representing the catalog. + database (str): A string representing the database. + table_name (str): A string representing the Athena table name of this FeatureGroup. 
+ record_identifier_feature_name (str): A string representing the record identifier feature. + event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the + event time identifier feature. + target_feature_name_in_base (str): A string representing the feature name in base which will + be used as target join key (default: None). + table_type (TableType): A TableType representing the type of table if it is Feature Group or + Panda Data Frame (default: None). + feature_name_in_target (str): A string representing the feature name in the target feature + group that will be compared to the target feature in the base feature group. + If None is provided, the record identifier feature will be used in the + SQL join. (default: None). + join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator + used when joining the target feature in the base feature group and the feature + in the target feature group. (default: JoinComparatorEnum.EQUALS). + join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between + the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN). + """ + features: List[str] + included_feature_names: List[str] + projected_feature_names: List[str] + catalog: str + database: str + table_name: str + record_identifier_feature_name: str + event_time_identifier_feature: FeatureDefinition + target_feature_name_in_base: str = None + table_type: TableType = None + feature_name_in_target: str = None + join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS + join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN + + +def construct_feature_group_to_be_merged( + target_feature_group: FeatureGroup, + included_feature_names: List[str], + target_feature_name_in_base: str = None, + feature_name_in_target: str = None, + join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS, + join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN, +) -> FeatureGroupToBeMerged: + """Construct a FeatureGroupToBeMerged object by provided parameters. + + Args: + target_feature_group (FeatureGroup): A FeatureGroup object. + included_feature_names (List[str]): A list of strings representing features to be + included in the output. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as target join key (default: None). + feature_name_in_target (str): A string representing the feature name in the target feature + group that will be compared to the target feature in the base feature group. + If None is provided, the record identifier feature will be used in the + SQL join. (default: None). + join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator + used when joining the target feature in the base feature group and the feature + in the target feature group. (default: JoinComparatorEnum.EQUALS). + join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between + the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN). + + Returns: + A FeatureGroupToBeMerged object. + + Raises: + RuntimeError: No metastore is configured with the FeatureGroup. + ValueError: Invalid feature name(s) in included_feature_names. 
+ """ + fg = FeatureGroup.get(feature_group_name=target_feature_group.feature_group_name) + + if not fg.offline_store_config or not fg.offline_store_config.data_catalog_config: + raise RuntimeError(f"No metastore configured for FeatureGroup {fg.feature_group_name}.") + + catalog_config = fg.offline_store_config.data_catalog_config + disable_glue = catalog_config.disable_glue_table_creation or False + + features = [fd.feature_name for fd in fg.feature_definitions] + record_id = fg.record_identifier_feature_name + event_time_name = fg.event_time_feature_name + event_time_type = next( + (fd.feature_type for fd in fg.feature_definitions if fd.feature_name == event_time_name), + None + ) + + if feature_name_in_target and feature_name_in_target not in features: + raise ValueError(f"Feature {feature_name_in_target} not found in {fg.feature_group_name}") + + for feat in included_feature_names or []: + if feat not in features: + raise ValueError(f"Feature {feat} not found in {fg.feature_group_name}") + + if not included_feature_names: + included_feature_names = features.copy() + projected_feature_names = features.copy() + else: + projected_feature_names = included_feature_names.copy() + if record_id not in included_feature_names: + included_feature_names.append(record_id) + if event_time_name not in included_feature_names: + included_feature_names.append(event_time_name) + + return FeatureGroupToBeMerged( + features=features, + included_feature_names=included_feature_names, + projected_feature_names=projected_feature_names, + catalog=catalog_config.catalog if disable_glue else _DEFAULT_CATALOG, + database=catalog_config.database, + table_name=catalog_config.table_name, + record_identifier_feature_name=record_id, + event_time_identifier_feature=FeatureDefinition( + feature_name=event_time_name, feature_type=FeatureTypeEnum(event_time_type).value + ), + target_feature_name_in_base=target_feature_name_in_base, + table_type=TableType.FEATURE_GROUP, + feature_name_in_target=feature_name_in_target, + join_comparator=join_comparator, + join_type=join_type, + ) + + +@dataclass +class DatasetBuilder: + """DatasetBuilder definition. + + This class instantiates a DatasetBuilder object that comprises a base, a list of feature names, + an output path and a KMS key ID. + + Attributes: + _sagemaker_session (Session): Session instance to perform boto calls. + _base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a + pandas.DataFrame and will be used to merge other FeatureGroups and generate a Dataset. + _output_path (str): An S3 URI which stores the output .csv file. + _record_identifier_feature_name (str): A string representing the record identifier feature + if base is a DataFrame (default: None). + _event_time_identifier_feature_name (str): A string representing the event time identifier + feature if base is a DataFrame (default: None). + _included_feature_names (List[str]): A list of strings representing features to be + included in the output. If not set, all features will be included in the output. + (default: None). + _kms_key_id (str): A KMS key id. If set, will be used to encrypt the result file + (default: None). + _point_in_time_accurate_join (bool): A boolean representing if point-in-time join + is applied to the resulting dataframe when calling "to_dataframe". + When set to True, users can retrieve data using "row-level time travel" + according to the event times provided to the DatasetBuilder. 
This requires that the + entity dataframe with event times is submitted as the base in the constructor + (default: False). + _include_duplicated_records (bool): A boolean representing whether the resulting dataframe + when calling "to_dataframe" should include duplicated records (default: False). + _include_deleted_records (bool): A boolean representing whether the resulting + dataframe when calling "to_dataframe" should include deleted records (default: False). + _number_of_recent_records (int): An integer representing how many records will be + returned for each record identifier (default: 1). + _number_of_records (int): An integer representing the number of records that should be + returned in the resulting dataframe when calling "to_dataframe" (default: None). + _write_time_ending_timestamp (datetime.datetime): A datetime that represents the latest + write time for a record to be included in the resulting dataset. Records with a + newer write time will be omitted from the resulting dataset. (default: None). + _event_time_starting_timestamp (datetime.datetime): A datetime that represents the earliest + event time for a record to be included in the resulting dataset. Records + with an older event time will be omitted from the resulting dataset. (default: None). + _event_time_ending_timestamp (datetime.datetime): A datetime that represents the latest + event time for a record to be included in the resulting dataset. Records + with a newer event time will be omitted from the resulting dataset. (default: None). + _feature_groups_to_be_merged (List[FeatureGroupToBeMerged]): A list of + FeatureGroupToBeMerged which will be joined to base (default: []). + _event_time_identifier_feature_type (FeatureTypeEnum): A FeatureTypeEnum representing the + type of event time identifier feature (default: None). + """ + + _sagemaker_session: Session + _base: Union[FeatureGroup, pd.DataFrame] + _output_path: str + _record_identifier_feature_name: str = None + _event_time_identifier_feature_name: str = None + _included_feature_names: List[str] = None + _kms_key_id: str = None + _event_time_identifier_feature_type: FeatureTypeEnum = None + + _point_in_time_accurate_join: bool = field(default=False, init=False) + _include_duplicated_records: bool = field(default=False, init=False) + _include_deleted_records: bool = field(default=False, init=False) + _number_of_recent_records: int = field(default=None, init=False) + _number_of_records: int = field(default=None, init=False) + _write_time_ending_timestamp: datetime.datetime = field(default=None, init=False) + _event_time_starting_timestamp: datetime.datetime = field(default=None, init=False) + _event_time_ending_timestamp: datetime.datetime = field(default=None, init=False) + _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = field(default_factory=list, init=False) + + @classmethod + def create( + cls, + base: Union[FeatureGroup, pd.DataFrame], + output_path: str, + session: Session, + record_identifier_feature_name: str = None, + event_time_identifier_feature_name: str = None, + included_feature_names: List[str] = None, + kms_key_id: str = None, + ) -> "DatasetBuilder": + """Create a DatasetBuilder for generating a Dataset. + + Args: + base: A FeatureGroup or DataFrame to use as the base. + output_path: S3 URI for output. + session: SageMaker session. + record_identifier_feature_name: Required if base is DataFrame. + event_time_identifier_feature_name: Required if base is DataFrame. + included_feature_names: Features to include in output. 
+ kms_key_id: KMS key for encryption. + + Returns: + DatasetBuilder instance. + """ + if isinstance(base, pd.DataFrame): + if not record_identifier_feature_name or not event_time_identifier_feature_name: + raise ValueError( + "record_identifier_feature_name and event_time_identifier_feature_name " + "are required when base is a DataFrame." + ) + return cls( + _sagemaker_session=session, + _base=base, + _output_path=output_path, + _record_identifier_feature_name=record_identifier_feature_name, + _event_time_identifier_feature_name=event_time_identifier_feature_name, + _included_feature_names=included_feature_names, + _kms_key_id=kms_key_id, + ) + + def with_feature_group( + self, + feature_group: FeatureGroup, + target_feature_name_in_base: str = None, + included_feature_names: List[str] = None, + feature_name_in_target: str = None, + join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS, + join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN, + ) -> "DatasetBuilder": + """Join FeatureGroup with base. + + Args: + feature_group (FeatureGroup): A target FeatureGroup which will be joined to base. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as a join key (default: None). + included_feature_names (List[str]): A list of strings representing features to be + included in the output (default: None). + feature_name_in_target (str): A string representing the feature name in the target + feature group that will be compared to the target feature in the base feature group. + If None is provided, the record identifier feature will be used in the + SQL join. (default: None). + join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator + used when joining the target feature in the base feature group and the feature + in the target feature group. (default: JoinComparatorEnum.EQUALS). + join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between + the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN). + + Returns: + This DatasetBuilder object. + """ + self._feature_groups_to_be_merged.append( + construct_feature_group_to_be_merged( + feature_group, included_feature_names, target_feature_name_in_base, + feature_name_in_target, join_comparator, join_type, + ) + ) + return self + + def point_in_time_accurate_join(self) -> "DatasetBuilder": + """Enable point-in-time accurate join. + + Returns: + This DatasetBuilder object. + """ + self._point_in_time_accurate_join = True + return self + + def include_duplicated_records(self) -> "DatasetBuilder": + """Include duplicated records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_duplicated_records = True + return self + + def include_deleted_records(self) -> "DatasetBuilder": + """Include deleted records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_deleted_records = True + return self + + def with_number_of_recent_records_by_record_identifier(self, n: int) -> "DatasetBuilder": + """Set number_of_recent_records field with provided input. + + Args: + n (int): An int that how many recent records will be returned for + each record identifier. + + Returns: + This DatasetBuilder object. + """ + self._number_of_recent_records = n + return self + + def with_number_of_records_from_query_results(self, n: int) -> "DatasetBuilder": + """Set number_of_records field with provided input. + + Args: + n (int): An int that how many records will be returned. 
+ + Returns: + This DatasetBuilder object. + """ + self._number_of_records = n + return self + + def as_of(self, timestamp: datetime.datetime) -> "DatasetBuilder": + """Set write_time_ending_timestamp field with provided input. + + Args: + timestamp (datetime.datetime): A datetime that all records' write time in dataset will + be before it. + + Returns: + This DatasetBuilder object. + """ + self._write_time_ending_timestamp = timestamp + return self + + def with_event_time_range( + self, + starting_timestamp: datetime.datetime = None, + ending_timestamp: datetime.datetime = None, + ) -> "DatasetBuilder": + """Set event_time_starting_timestamp and event_time_ending_timestamp with provided inputs. + + Args: + starting_timestamp (datetime.datetime): A datetime that all records' event time in + dataset will be after it (default: None). + ending_timestamp (datetime.datetime): A datetime that all records' event time in dataset + will be before it (default: None). + + Returns: + This DatasetBuilder object. + """ + self._event_time_starting_timestamp = starting_timestamp + self._event_time_ending_timestamp = ending_timestamp + return self + + def to_csv_file(self) -> tuple[str, str]: + """Get query string and result in .csv format file. + + Returns: + The S3 path of the .csv file. + The query string executed. + """ + if isinstance(self._base, pd.DataFrame): + return self._to_csv_from_dataframe() + if isinstance(self._base, FeatureGroup): + return self._to_csv_from_feature_group() + raise ValueError("Base must be either a FeatureGroup or a DataFrame.") + + def to_dataframe(self) -> tuple[pd.DataFrame, str]: + """Get query string and result in pandas.DataFrame. + + Returns: + The pandas.DataFrame object. + The query string executed. + """ + csv_file, query_string = self.to_csv_file() + df = download_csv_from_s3(csv_file, self._sagemaker_session, self._kms_key_id) + if "row_recent" in df.columns: + df = df.drop("row_recent", axis="columns") + return df, query_string + + + def _to_csv_from_dataframe(self) -> tuple[str, str]: + s3_folder, temp_table_name = upload_dataframe_to_s3( + self._base, self._output_path, self._sagemaker_session, self._kms_key_id + ) + self._create_temp_table(temp_table_name, s3_folder) + + base_features = list(self._base.columns) + event_time_dtype = str(self._base[self._event_time_identifier_feature_name].dtypes) + self._event_time_identifier_feature_type = FeatureTypeEnum( + _DTYPE_TO_FEATURE_TYPE.get(event_time_dtype, "String") + ) + + included = self._included_feature_names or base_features + fg_to_merge = FeatureGroupToBeMerged( + features=base_features, + included_feature_names=included, + projected_feature_names=included, + catalog=_DEFAULT_CATALOG, + database=_DEFAULT_DATABASE, + table_name=temp_table_name, + record_identifier_feature_name=self._record_identifier_feature_name, + event_time_identifier_feature=FeatureDefinition( + self._event_time_identifier_feature_name, + self._event_time_identifier_feature_type, + ), + table_type=TableType.DATA_FRAME, + ) + + query_string = self._construct_query_string(fg_to_merge) + result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + return self._extract_result(result) + + def _to_csv_from_feature_group(self) -> tuple[str, str]: + base_fg = construct_feature_group_to_be_merged(self._base, self._included_feature_names) + self._record_identifier_feature_name = base_fg.record_identifier_feature_name + self._event_time_identifier_feature_name = base_fg.event_time_identifier_feature.feature_name + 
self._event_time_identifier_feature_type = base_fg.event_time_identifier_feature.feature_type + + query_string = self._construct_query_string(base_fg) + result = self._run_query(query_string, base_fg.catalog, base_fg.database) + return self._extract_result(result) + + def _extract_result(self, query_result: dict) -> tuple[str, str]: + execution = query_result.get("QueryExecution", {}) + return ( + execution.get("ResultConfiguration", {}).get("OutputLocation"), + execution.get("Query"), + ) + + def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]: + return run_athena_query( + session=self._sagemaker_session, + catalog=catalog, + database=database, + query_string=query_string, + output_location=self._output_path, + kms_key=self._kms_key_id, + ) + + def _create_temp_table(self, temp_table_name: str, s3_folder: str): + columns = ", ".join( + f"{col} {_DTYPE_TO_ATHENA_TYPE.get(str(self._base[col].dtypes), 'STRING')}" + for col in self._base.columns + ) + serde = '"separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\"' + query = ( + f"CREATE EXTERNAL TABLE {temp_table_name} ({columns}) " + f"ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' " + f"WITH SERDEPROPERTIES ({serde}) LOCATION '{s3_folder}';" + ) + self._run_query(query, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + + + def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str: + base_query = self._construct_table_query(base, "base") + query = f"WITH fg_base AS ({base_query})" + + for i, fg in enumerate(self._feature_groups_to_be_merged): + fg_query = self._construct_table_query(fg, str(i)) + query += f",\nfg_{i} AS ({fg_query})" + + selected = ", ".join(f"fg_base.{f}" for f in base.projected_feature_names) + selected_final = ", ".join(base.projected_feature_names) + + for i, fg in enumerate(self._feature_groups_to_be_merged): + selected += ", " + ", ".join( + f'fg_{i}."{f}" as "{f}.{i+1}"' for f in fg.projected_feature_names + ) + selected_final += ", " + ", ".join( + f'"{f}.{i+1}"' for f in fg.projected_feature_names + ) + + query += ( + f"\nSELECT {selected_final}\nFROM (\n" + f"SELECT {selected}, row_number() OVER (\n" + f'PARTITION BY fg_base."{base.record_identifier_feature_name}"\n' + f'ORDER BY fg_base."{base.event_time_identifier_feature.feature_name}" DESC' + ) + + join_strings = [] + for i, fg in enumerate(self._feature_groups_to_be_merged): + if not fg.target_feature_name_in_base: + fg.target_feature_name_in_base = self._record_identifier_feature_name + elif fg.target_feature_name_in_base not in base.features: + raise ValueError(f"Feature {fg.target_feature_name_in_base} not found in base") + query += f', fg_{i}."{fg.event_time_identifier_feature.feature_name}" DESC' + join_strings.append(self._construct_join_condition(fg, str(i))) + + recent_where = "" + if self._number_of_recent_records is not None and self._number_of_recent_records >= 0: + recent_where = f"WHERE row_recent <= {self._number_of_recent_records}" + + query += f"\n) AS row_recent\nFROM fg_base{''.join(join_strings)}\n)\n{recent_where}" + + if self._number_of_records is not None and self._number_of_records >= 0: + query += f"\nLIMIT {self._number_of_records}" + + return query + + def _construct_table_query(self, fg: FeatureGroupToBeMerged, suffix: str) -> str: + included = ", ".join(f'table_{suffix}."{f}"' for f in fg.included_feature_names) + included_with_write = included + if fg.table_type is TableType.FEATURE_GROUP: + included_with_write += f', table_{suffix}."write_time"' + + record_id = 
fg.record_identifier_feature_name + event_time = fg.event_time_identifier_feature.feature_name + + if self._include_duplicated_records and self._include_deleted_records: + return ( + f"SELECT {included}\n" + f'FROM "{fg.database}"."{fg.table_name}" table_{suffix}\n' + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, ["NOT is_deleted"]) + ) + + if fg.table_type is TableType.FEATURE_GROUP and self._include_deleted_records: + rank = f'ORDER BY origin_{suffix}."api_invocation_time" DESC, origin_{suffix}."write_time" DESC\n' + return ( + f"SELECT {included}\nFROM (\n" + f"SELECT *, row_number() OVER (\n" + f'PARTITION BY origin_{suffix}."{record_id}", origin_{suffix}."{event_time}"\n' + f"{rank}) AS row_{suffix}\n" + f'FROM "{fg.database}"."{fg.table_name}" origin_{suffix}\n' + f"WHERE NOT is_deleted) AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, [f"row_{suffix} = 1"]) + ) + + if fg.table_type is TableType.FEATURE_GROUP: + dedup = self._construct_dedup_query(fg, suffix) + deleted = self._construct_deleted_query(fg, suffix) + rank_cond = ( + f'OR (table_{suffix}."{event_time}" = deleted_{suffix}."{event_time}" ' + f'AND table_{suffix}."api_invocation_time" > deleted_{suffix}."api_invocation_time")\n' + f'OR (table_{suffix}."{event_time}" = deleted_{suffix}."{event_time}" ' + f'AND table_{suffix}."api_invocation_time" = deleted_{suffix}."api_invocation_time" ' + f'AND table_{suffix}."write_time" > deleted_{suffix}."write_time")\n' + ) + + if self._include_duplicated_records: + return ( + f"WITH {deleted}\n" + f"SELECT {included}\nFROM (\n" + f"SELECT {included_with_write}\n" + f'FROM "{fg.database}"."{fg.table_name}" table_{suffix}\n' + f"LEFT JOIN deleted_{suffix} ON table_{suffix}.\"{record_id}\" = deleted_{suffix}.\"{record_id}\"\n" + f'WHERE deleted_{suffix}."{record_id}" IS NULL\n' + f"UNION ALL\n" + f"SELECT {included_with_write}\nFROM deleted_{suffix}\n" + f'JOIN "{fg.database}"."{fg.table_name}" table_{suffix}\n' + f'ON table_{suffix}."{record_id}" = deleted_{suffix}."{record_id}"\n' + f'AND (table_{suffix}."{event_time}" > deleted_{suffix}."{event_time}"\n{rank_cond})\n' + f") AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, []) + ) + + return ( + f"WITH {dedup},\n{deleted}\n" + f"SELECT {included}\nFROM (\n" + f"SELECT {included_with_write}\nFROM table_{suffix}\n" + f"LEFT JOIN deleted_{suffix} ON table_{suffix}.\"{record_id}\" = deleted_{suffix}.\"{record_id}\"\n" + f'WHERE deleted_{suffix}."{record_id}" IS NULL\n' + f"UNION ALL\n" + f"SELECT {included_with_write}\nFROM deleted_{suffix}\n" + f"JOIN table_{suffix} ON table_{suffix}.\"{record_id}\" = deleted_{suffix}.\"{record_id}\"\n" + f'AND (table_{suffix}."{event_time}" > deleted_{suffix}."{event_time}"\n{rank_cond})\n' + f") AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, []) + ) + + dedup = self._construct_dedup_query(fg, suffix) + return ( + f"WITH {dedup}\n" + f"SELECT {included}\nFROM (\n" + f"SELECT {included_with_write}\nFROM table_{suffix}\n" + f") AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, []) + ) + + def _construct_dedup_query(self, fg: FeatureGroupToBeMerged, suffix: str) -> str: + record_id = fg.record_identifier_feature_name + event_time = fg.event_time_identifier_feature.feature_name + rank = "" + is_fg = fg.table_type is TableType.FEATURE_GROUP + + if is_fg: + rank = 
f'ORDER BY origin_{suffix}."api_invocation_time" DESC, origin_{suffix}."write_time" DESC\n' + + where_conds = [] + if is_fg and self._write_time_ending_timestamp: + where_conds.append(self._construct_write_time_condition(f"origin_{suffix}")) + where_conds.extend(self._construct_event_time_conditions(f"origin_{suffix}", fg.event_time_identifier_feature)) + where_str = f"WHERE {' AND '.join(where_conds)}\n" if where_conds else "" + + dedup_where = f"WHERE dedup_row_{suffix} = 1\n" if is_fg else "" + + return ( + f"table_{suffix} AS (\n" + f"SELECT *\nFROM (\n" + f"SELECT *, row_number() OVER (\n" + f'PARTITION BY origin_{suffix}."{record_id}", origin_{suffix}."{event_time}"\n' + f"{rank}) AS dedup_row_{suffix}\n" + f'FROM "{fg.database}"."{fg.table_name}" origin_{suffix}\n' + f"{where_str})\n{dedup_where})" + ) + + def _construct_deleted_query(self, fg: FeatureGroupToBeMerged, suffix: str) -> str: + record_id = fg.record_identifier_feature_name + event_time = fg.event_time_identifier_feature.feature_name + rank = f'ORDER BY origin_{suffix}."{event_time}" DESC' + + if fg.table_type is TableType.FEATURE_GROUP: + rank += f', origin_{suffix}."api_invocation_time" DESC, origin_{suffix}."write_time" DESC\n' + + write_cond = "" + if fg.table_type is TableType.FEATURE_GROUP and self._write_time_ending_timestamp: + write_cond = f" AND {self._construct_write_time_condition(f'origin_{suffix}')}\n" + + event_conds = "" + if self._event_time_starting_timestamp and self._event_time_ending_timestamp: + conds = self._construct_event_time_conditions(f"origin_{suffix}", fg.event_time_identifier_feature) + event_conds = "".join(f"AND {c}\n" for c in conds) + + return ( + f"deleted_{suffix} AS (\n" + f"SELECT *\nFROM (\n" + f"SELECT *, row_number() OVER (\n" + f'PARTITION BY origin_{suffix}."{record_id}"\n' + f"{rank}) AS deleted_row_{suffix}\n" + f'FROM "{fg.database}"."{fg.table_name}" origin_{suffix}\n' + f"WHERE is_deleted{write_cond}{event_conds})\n" + f"WHERE deleted_row_{suffix} = 1\n)" + ) + + def _construct_where_query_string( + self, suffix: str, event_time_feature: FeatureDefinition, conditions: List[str] + ) -> str: + self._validate_options() + + if isinstance(self._base, FeatureGroup) and self._write_time_ending_timestamp: + conditions.append(self._construct_write_time_condition(f"table_{suffix}")) + + conditions.extend(self._construct_event_time_conditions(f"table_{suffix}", event_time_feature)) + return f"WHERE {' AND '.join(conditions)}" if conditions else "" + + def _validate_options(self): + is_df_base = isinstance(self._base, pd.DataFrame) + no_joins = len(self._feature_groups_to_be_merged) == 0 + + if self._number_of_recent_records is not None and self._number_of_recent_records < 0: + raise ValueError("number_of_recent_records must be non-negative.") + if self._number_of_records is not None and self._number_of_records < 0: + raise ValueError("number_of_records must be non-negative.") + if is_df_base and no_joins: + if self._include_deleted_records: + raise ValueError("include_deleted_records() only works for FeatureGroup if no join.") + if self._include_duplicated_records: + raise ValueError("include_duplicated_records() only works for FeatureGroup if no join.") + if self._write_time_ending_timestamp: + raise ValueError("as_of() only works for FeatureGroup if no join.") + if self._point_in_time_accurate_join and no_joins: + raise ValueError("point_in_time_accurate_join() requires at least one join.") + + def _construct_event_time_conditions(self, table: str, event_time_feature: 
FeatureDefinition) -> List[str]:
+        # The event time feature type may arrive as a FeatureTypeEnum or as its string value,
+        # depending on how the definition was built (DataFrame base vs. FeatureGroup base),
+        # so normalize it before choosing the Athena cast function.
+        event_time_type = FeatureTypeEnum(event_time_feature.feature_type)
+        cast_fn = (
+            "from_iso8601_timestamp"
+            if event_time_type == FeatureTypeEnum.STRING
+            else "from_unixtime"
+        )
+        conditions = []
+        if self._event_time_starting_timestamp:
+            conditions.append(
+                f'{cast_fn}({table}."{event_time_feature.feature_name}") >= '
+                f"from_unixtime({self._event_time_starting_timestamp.timestamp()})"
+            )
+        if self._event_time_ending_timestamp:
+            conditions.append(
+                f'{cast_fn}({table}."{event_time_feature.feature_name}") <= '
+                f"from_unixtime({self._event_time_ending_timestamp.timestamp()})"
+            )
+        return conditions
+
+    def _construct_write_time_condition(self, table: str) -> str:
+        ts = self._write_time_ending_timestamp.replace(microsecond=0)
+        return f'{table}."write_time" <= to_timestamp(\'{ts}\', \'yyyy-mm-dd hh24:mi:ss\')'
+
+    def _construct_join_condition(self, fg: FeatureGroupToBeMerged, suffix: str) -> str:
+        target_feature = fg.feature_name_in_target or fg.record_identifier_feature_name
+        join = (
+            f"\n{fg.join_type.value} fg_{suffix}\n"
+            f'ON fg_base."{fg.target_feature_name_in_base}" {fg.join_comparator.value} fg_{suffix}."{target_feature}"'
+        )
+
+        if self._point_in_time_accurate_join:
+            # Normalize the stored event time types (enum or string value) before picking casts.
+            base_type = FeatureTypeEnum(self._event_time_identifier_feature_type)
+            fg_type = FeatureTypeEnum(fg.event_time_identifier_feature.feature_type)
+            base_cast = (
+                "from_iso8601_timestamp" if base_type == FeatureTypeEnum.STRING else "from_unixtime"
+            )
+            fg_cast = (
+                "from_iso8601_timestamp" if fg_type == FeatureTypeEnum.STRING else "from_unixtime"
+            )
+            join += (
+                f'\nAND {base_cast}(fg_base."{self._event_time_identifier_feature_name}") >= '
+                f'{fg_cast}(fg_{suffix}."{fg.event_time_identifier_feature.feature_name}")'
+            )
+
+        return join
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_definition.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_definition.py
new file mode 100644
index 0000000000..32408e5585
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_definition.py
@@ -0,0 +1,107 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Feature Definitions for FeatureStore.""" +from __future__ import absolute_import + +from enum import Enum +from typing import Optional, Union + +from sagemaker.core.shapes import ( + FeatureDefinition, + CollectionConfig, + VectorConfig, +) + +class FeatureTypeEnum(Enum): + """Feature data types: Fractional, Integral, or String.""" + + FRACTIONAL = "Fractional" + INTEGRAL = "Integral" + STRING = "String" + +class CollectionTypeEnum(Enum): + """Collection types: List, Set, or Vector.""" + + LIST = "List" + SET = "Set" + VECTOR = "Vector" + +class ListCollectionType: + """List collection type.""" + + collection_type = CollectionTypeEnum.LIST.value + collection_config = None + +class SetCollectionType: + """Set collection type.""" + + collection_type = CollectionTypeEnum.SET.value + collection_config = None + +class VectorCollectionType: + """Vector collection type with dimension.""" + + collection_type = CollectionTypeEnum.VECTOR.value + + def __init__(self, dimension: int): + self.collection_config = CollectionConfig( + vector_config=VectorConfig(dimension=dimension) + ) + +CollectionType = Union[ListCollectionType, SetCollectionType, VectorCollectionType] + +def _create_feature_definition( + feature_name: str, + feature_type: FeatureTypeEnum, + collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Internal helper to create FeatureDefinition from collection type.""" + return FeatureDefinition( + feature_name=feature_name, + feature_type=feature_type.value, + collection_type=collection_type.collection_type if collection_type else None, + collection_config=collection_type.collection_config if collection_type else None, + ) + +def FractionalFeatureDefinition( + feature_name: str, + collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Create a feature definition with Fractional type.""" + return _create_feature_definition(feature_name, FeatureTypeEnum.FRACTIONAL, collection_type) + +def IntegralFeatureDefinition( + feature_name: str, + collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Create a feature definition with Integral type.""" + return _create_feature_definition(feature_name, FeatureTypeEnum.INTEGRAL, collection_type) + +def StringFeatureDefinition( + feature_name: str, + collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Create a feature definition with String type.""" + return _create_feature_definition(feature_name, FeatureTypeEnum.STRING, collection_type) + +__all__ = [ + "FeatureDefinition", + "FeatureTypeEnum", + "CollectionTypeEnum", + "ListCollectionType", + "SetCollectionType", + "VectorCollectionType", + "FractionalFeatureDefinition", + "IntegralFeatureDefinition", + "StringFeatureDefinition", +] diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/__init__.py new file mode 100644 index 0000000000..1051096d0e --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/__init__.py @@ -0,0 +1,45 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Exported classes for the sagemaker.mlops.feature_store.feature_processor module.""" +from __future__ import absolute_import + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( # noqa: F401 + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, + PySparkDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._exceptions import ( # noqa: F401 + IngestionError, +) +from sagemaker.mlops.feature_store.feature_processor.feature_processor import ( # noqa: F401 + feature_processor, +) +from sagemaker.mlops.feature_store.feature_processor.feature_scheduler import ( # noqa: F401 + to_pipeline, + schedule, + describe, + put_trigger, + delete_trigger, + enable_trigger, + disable_trigger, + delete_schedule, + list_pipelines, + execute, + TransformationCode, + FeatureProcessorPipelineEvents, +) +from sagemaker.mlops.feature_store.feature_processor._enums import ( # noqa: F401 + FeatureProcessorPipelineExecutionStatus, +) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_config_uploader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_config_uploader.py new file mode 100644 index 0000000000..d181218fb5 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_config_uploader.py @@ -0,0 +1,209 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains classes for preparing and uploading configs for a scheduled feature processor.""" +from __future__ import absolute_import +from typing import Callable, Dict, Optional, Tuple, List, Union + +import attr + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._constants import ( + SPARK_JAR_FILES_PATH, + SPARK_PY_FILES_PATH, + SPARK_FILES_PATH, + S3_DATA_DISTRIBUTION_TYPE, +) +from sagemaker.core.inputs import TrainingInput +from sagemaker.core.shapes import Channel, DataSource, S3DataSource +from sagemaker.core.remote_function.core.stored_function import StoredFunction +from sagemaker.core.remote_function.job import ( + _prepare_and_upload_workspace, + _prepare_and_upload_runtime_scripts, + _JobSettings, + RUNTIME_SCRIPTS_CHANNEL_NAME, + REMOTE_FUNCTION_WORKSPACE, + SPARK_CONF_CHANNEL_NAME, + _prepare_and_upload_spark_dependent_files, +) +from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentManager, +) +from sagemaker.core.remote_function.spark_config import SparkConfig +from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter +from sagemaker.core.s3 import s3_path_join + + +@attr.s +class ConfigUploader: + """Prepares and uploads customer provided configs to S3""" + + remote_decorator_config: _JobSettings = attr.ib() + runtime_env_manager: RuntimeEnvironmentManager = attr.ib() + + def prepare_step_input_channel_for_spark_mode( + self, func: Callable, s3_base_uri: str, sagemaker_session: Session + ) -> Tuple[List[Channel], Dict]: + """Prepares input channels for SageMaker Pipeline Step. + + Returns: + Tuple of (List[Channel], spark_dependency_paths dict) + """ + self._prepare_and_upload_callable(func, s3_base_uri, sagemaker_session) + bootstrap_scripts_s3uri = self._prepare_and_upload_runtime_scripts( + self.remote_decorator_config.spark_config, + s3_base_uri, + self.remote_decorator_config.s3_kms_key, + sagemaker_session, + ) + dependencies_list_path = self.runtime_env_manager.snapshot( + self.remote_decorator_config.dependencies + ) + user_workspace_s3uri = self._prepare_and_upload_workspace( + dependencies_list_path, + self.remote_decorator_config.include_local_workdir, + self.remote_decorator_config.pre_execution_commands, + self.remote_decorator_config.pre_execution_script, + s3_base_uri, + self.remote_decorator_config.s3_kms_key, + sagemaker_session, + self.remote_decorator_config.custom_file_filter, + ) + + ( + submit_jars_s3_paths, + submit_py_files_s3_paths, + submit_files_s3_path, + config_file_s3_uri, + ) = self._prepare_and_upload_spark_dependent_files( + self.remote_decorator_config.spark_config, + s3_base_uri, + self.remote_decorator_config.s3_kms_key, + sagemaker_session, + ) + + channels = [ + Channel( + channel_name=RUNTIME_SCRIPTS_CHANNEL_NAME, + data_source=DataSource( + s3_data_source=S3DataSource( + s3_uri=bootstrap_scripts_s3uri, + s3_data_type="S3Prefix", + s3_data_distribution_type=S3_DATA_DISTRIBUTION_TYPE, + ) + ), + input_mode="File", + ) + ] + + if user_workspace_s3uri: + channels.append( + Channel( + channel_name=REMOTE_FUNCTION_WORKSPACE, + data_source=DataSource( + s3_data_source=S3DataSource( + s3_uri=s3_path_join(s3_base_uri, REMOTE_FUNCTION_WORKSPACE), + s3_data_type="S3Prefix", + s3_data_distribution_type=S3_DATA_DISTRIBUTION_TYPE, + ) + ), + input_mode="File", + ) + ) + + if config_file_s3_uri: + channels.append( + Channel( + channel_name=SPARK_CONF_CHANNEL_NAME, + data_source=DataSource( + 
s3_data_source=S3DataSource( + s3_uri=config_file_s3_uri, + s3_data_type="S3Prefix", + s3_data_distribution_type=S3_DATA_DISTRIBUTION_TYPE, + ) + ), + input_mode="File", + ) + ) + + return channels, { + SPARK_JAR_FILES_PATH: submit_jars_s3_paths, + SPARK_PY_FILES_PATH: submit_py_files_s3_paths, + SPARK_FILES_PATH: submit_files_s3_path, + } + + def _prepare_and_upload_callable( + self, func: Callable, s3_base_uri: str, sagemaker_session: Session + ) -> None: + """Prepares and uploads callable to S3""" + stored_function = StoredFunction( + sagemaker_session=sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=self.remote_decorator_config.s3_kms_key, + ) + stored_function.save(func) + + def _prepare_and_upload_workspace( + self, + local_dependencies_path: str, + include_local_workdir: bool, + pre_execution_commands: List[str], + pre_execution_script_local_path: str, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, + ) -> str: + """Upload the training step dependencies to S3 if present""" + return _prepare_and_upload_workspace( + local_dependencies_path=local_dependencies_path, + include_local_workdir=include_local_workdir, + pre_execution_commands=pre_execution_commands, + pre_execution_script_local_path=pre_execution_script_local_path, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + custom_file_filter=custom_file_filter, + ) + + def _prepare_and_upload_runtime_scripts( + self, + spark_config: SparkConfig, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + ) -> str: + """Copy runtime scripts to a folder and upload to S3""" + return _prepare_and_upload_runtime_scripts( + spark_config=spark_config, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + ) + + def _prepare_and_upload_spark_dependent_files( + self, + spark_config: SparkConfig, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + ) -> Tuple: + """Upload the spark dependencies to S3 if present""" + if not spark_config: + return None, None, None, None + + return _prepare_and_upload_spark_dependent_files( + spark_config=spark_config, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_constants.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_constants.py new file mode 100644 index 0000000000..e010446904 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_constants.py @@ -0,0 +1,54 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Module containing constants for feature_processor and feature_scheduler module.""" +from __future__ import absolute_import + +from sagemaker.core.workflow.parameters import Parameter, ParameterTypeEnum + +DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge" +DEFAULT_SCHEDULE_STATE = "ENABLED" +DEFAULT_TRIGGER_STATE = "ENABLED" +UNDERSCORE = "_" +RESOURCE_NOT_FOUND_EXCEPTION = "ResourceNotFoundException" +RESOURCE_NOT_FOUND = "ResourceNotFound" +EXECUTION_TIME_PIPELINE_PARAMETER = "scheduled_time" +VALIDATION_EXCEPTION = "ValidationException" +EVENT_BRIDGE_INVOCATION_TIME = "" +SCHEDULED_TIME_PIPELINE_PARAMETER = Parameter( + name=EXECUTION_TIME_PIPELINE_PARAMETER, parameter_type=ParameterTypeEnum.STRING +) +EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT = "%Y-%m-%dT%H:%M:%SZ" # 2023-01-01T07:00:00Z +NO_FLEXIBLE_TIME_WINDOW = dict(Mode="OFF") +PIPELINE_NAME_MAXIMUM_LENGTH = 80 +PIPELINE_CONTEXT_TYPE = "FeatureEngineeringPipeline" +SPARK_JAR_FILES_PATH = "submit_jars_s3_paths" +SPARK_PY_FILES_PATH = "submit_py_files_s3_paths" +SPARK_FILES_PATH = "submit_files_s3_path" +FEATURE_PROCESSOR_TAG_KEY = "sm-fs-fe:created-from" +FEATURE_PROCESSOR_TAG_VALUE = "fp-to-pipeline" +FEATURE_GROUP_ARN_REGEX_PATTERN = r"arn:(.*?):sagemaker:(.*?):(.*?):feature-group/(.*?)$" +PIPELINE_ARN_REGEX_PATTERN = r"arn:(.*?):sagemaker:(.*?):(.*?):pipeline/(.*?)$" +EVENTBRIDGE_RULE_ARN_REGEX_PATTERN = r"arn:(.*?):events:(.*?):(.*?):rule/(.*?)$" +SAGEMAKER_WHL_FILE_S3_PATH = "s3://ada-private-beta/sagemaker-2.151.1.dev0-py2.py3-none-any.whl" +S3_DATA_DISTRIBUTION_TYPE = "FullyReplicated" +PIPELINE_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-context-name" +PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-version-context-name" +TO_PIPELINE_RESERVED_TAG_KEYS = [ + FEATURE_PROCESSOR_TAG_KEY, + PIPELINE_CONTEXT_NAME_TAG_KEY, + PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY, +] +BASE_EVENT_PATTERN = { + "source": ["aws.sagemaker"], + "detail": {"currentPipelineExecutionStatus": [], "pipelineArn": []}, +} diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_data_source.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_data_source.py new file mode 100644 index 0000000000..a6c452267c --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_data_source.py @@ -0,0 +1,154 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes to define input data sources.""" +from __future__ import absolute_import + +from typing import Optional, Dict, Union, TypeVar, Generic +from abc import ABC, abstractmethod +from pyspark.sql import DataFrame, SparkSession + + +import attr + +T = TypeVar("T") + + +@attr.s +class BaseDataSource(Generic[T], ABC): + """Abstract base class for feature processor data sources. + + Provides a skeleton for customization requiring the overriding of the method to read data from + data source and return the specified type. 
+ """ + + @abstractmethod + def read_data(self, *args, **kwargs) -> T: + """Read data from data source and return the specified type. + + Args: + args: Arguments for reading the data. + kwargs: Keyword argument for reading the data. + Returns: + T: The specified abstraction of data source. + """ + + @property + @abstractmethod + def data_source_unique_id(self) -> str: + """The identifier for the customized feature processor data source. + + Returns: + str: The data source unique id. + """ + + @property + @abstractmethod + def data_source_name(self) -> str: + """The name for the customized feature processor data source. + + Returns: + str: The data source name. + """ + + +@attr.s +class PySparkDataSource(BaseDataSource[DataFrame], ABC): + """Abstract base class for feature processor data sources. + + Provides a skeleton for customization requiring the overriding of the method to read data from + data source and return the Spark DataFrame. + """ + + @abstractmethod + def read_data( + self, spark: SparkSession, params: Optional[Dict[str, Union[str, Dict]]] = None + ) -> DataFrame: + """Read data from data source and convert the data to Spark DataFrame. + + Args: + spark (SparkSession): The Spark session to read the data. + params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the + feature_processor decorator. + Returns: + DataFrame: The Spark DataFrame as an abstraction on the data source. + """ + + +@attr.s +class FeatureGroupDataSource: + """A Feature Group data source definition for a FeatureProcessor. + + Attributes: + name (str): The name or ARN of the Feature Group. + input_start_offset (Optional[str], optional): A duration specified as a string in the + format ' ' where 'no' is a number and 'unit' is a unit of time in ['hours', + 'days', 'weeks', 'months', 'years'] (plural and singular forms). Inputs contain data + with event times no earlier than input_start_offset in the past. Offsets are relative + to the function execution time. If the function is executed by a Schedule, then the + offset is relative to the scheduled start time. Defaults to None. + input_end_offset (Optional[str], optional): The 'end' (as opposed to start) counterpart for + the 'input_start_offset'. Inputs will contain records with event times no later than + 'input_end_offset' in the past. Defaults to None. + """ + + name: str = attr.ib() + input_start_offset: Optional[str] = attr.ib(default=None) + input_end_offset: Optional[str] = attr.ib(default=None) + + +@attr.s +class CSVDataSource: + """An CSV data source definition for a FeatureProcessor. + + Attributes: + s3_uri (str): S3 URI of the data source. + csv_header (bool): Whether to read the first line of the CSV file as column names. This + option is only valid when file_format is set to csv. By default the value of this + option is true, and all column types are assumed to be a string. + infer_schema (bool): Whether to infer the schema of the CSV data source. This option is only + valid when file_format is set to csv. If set to true, two passes of the data is required + to load and infer the schema. + """ + + s3_uri: str = attr.ib() + csv_header: bool = attr.ib(default=True) + csv_infer_schema: bool = attr.ib(default=False) + + +@attr.s +class ParquetDataSource: + """An parquet data source definition for a FeatureProcessor. + + Attributes: + s3_uri (str): S3 URI of the data source. 
+ """ + + s3_uri: str = attr.ib() + + +@attr.s +class IcebergTableDataSource: + """An iceberg table data source definition for FeatureProcessor + + Attributes: + warehouse_s3_uri (str): S3 URI of data warehouse. The value is usually + the URI where data is stored. + catalog (str): Name of the catalog. + database (str): Name of the database. + table (str): Name of the table. + """ + + warehouse_s3_uri: str = attr.ib() + catalog: str = attr.ib() + database: str = attr.ib() + table: str = attr.ib() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_enums.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_enums.py new file mode 100644 index 0000000000..b63ed3a65a --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_enums.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Module containing Enums for the feature_processor module.""" +from __future__ import absolute_import + +from enum import Enum + + +class FeatureProcessorMode(Enum): + """Enum of feature_processor modes.""" + + PYSPARK = "pyspark" # Execute a pyspark job. + PYTHON = "python" # Execute a regular python script. + + +class FeatureProcessorPipelineExecutionStatus(Enum): + """Enum of feature_processor pipeline execution status.""" + + EXECUTING = "Executing" + STOPPING = "Stopping" + STOPPED = "Stopped" + FAILED = "Failed" + SUCCEEDED = "Succeeded" diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_env.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_env.py new file mode 100644 index 0000000000..d4ccfb1197 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_env.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class that determines the current execution environment.""" +from __future__ import absolute_import + + +from typing import Dict, Optional +from datetime import datetime, timezone +import json +import logging +import os +import attr +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER, + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, +) + + +logger = logging.getLogger("sagemaker") + + +@attr.s +class EnvironmentHelper: + """Helper class to retrieve info from environment. + + Attributes: + current_time (datetime): The current datetime. 
+ """ + + current_time = attr.ib(default=datetime.now(timezone.utc)) + + def is_training_job(self) -> bool: + """Determine if the current execution environment is inside a SageMaker Training Job""" + return self.load_training_resource_config() is not None + + def get_instance_count(self) -> int: + """Determine the number of instances for the current execution environment.""" + resource_config = self.load_training_resource_config() + return len(resource_config["hosts"]) if resource_config else 1 + + def load_training_resource_config(self) -> Optional[Dict]: + """Load the contents of resourceconfig.json contents. + + Returns: + Optional[Dict]: None if not found. + """ + SM_TRAINING_CONFIG_FILE_PATH = "/opt/ml/input/config/resourceconfig.json" + try: + with open(SM_TRAINING_CONFIG_FILE_PATH, "r") as cfgfile: + resource_config = json.load(cfgfile) + logger.debug("Contents of %s: %s", SM_TRAINING_CONFIG_FILE_PATH, resource_config) + return resource_config + except FileNotFoundError: + return None + + def get_job_scheduled_time(self) -> str: + """Get the job scheduled time. + + Returns: + str: Timestamp when the job is scheduled. + """ + + scheduled_time = self.current_time.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT) + if self.is_training_job(): + envs = dict(os.environ) + return envs.get(EXECUTION_TIME_PIPELINE_PARAMETER, scheduled_time) + + return scheduled_time diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_rule_helper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_rule_helper.py new file mode 100644 index 0000000000..250e7d456f --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_rule_helper.py @@ -0,0 +1,305 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains classes for EventBridge Schedule management for a feature processor.""" +from __future__ import absolute_import + +import json +import logging +import re +from typing import Dict, List, Tuple, Optional, Any +import attr +from botocore.exceptions import ClientError +from botocore.paginate import PageIterator +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._feature_processor_pipeline_events import ( + FeatureProcessorPipelineEvents, +) +from sagemaker.mlops.feature_store.feature_processor._constants import ( + RESOURCE_NOT_FOUND_EXCEPTION, + PIPELINE_ARN_REGEX_PATTERN, + BASE_EVENT_PATTERN, +) +from sagemaker.mlops.feature_store.feature_processor._enums import ( + FeatureProcessorPipelineExecutionStatus, +) +from sagemaker.core.common_utils import TagsDict + +logger = logging.getLogger("sagemaker") + + +@attr.s +class EventBridgeRuleHelper: + """Contains helper methods for managing EventBridge rules for a feature processor.""" + + sagemaker_session: Session = attr.ib() + event_bridge_rule_client = attr.ib() + + def put_rule( + self, + source_pipeline_events: List[FeatureProcessorPipelineEvents], + target_pipeline: str, + event_pattern: str, + state: str, + ) -> str: + """Creates an EventBridge Rule for a given target pipeline. + + Args: + source_pipeline_events: The list of pipeline events that trigger the EventBridge Rule. + target_pipeline: The name of the pipeline that is triggered by the EventBridge Rule. + event_pattern: The EventBridge EventPattern that triggers the EventBridge Rule. + If specified, will override source_pipeline_events. + state: Indicates whether the rule is enabled or disabled. + + Returns: + The Amazon Resource Name (ARN) of the rule. + """ + self._validate_feature_processor_pipeline_events(source_pipeline_events) + rule_name = target_pipeline + _event_patterns = ( + event_pattern + or self._generate_event_pattern_from_feature_processor_pipeline_events( + source_pipeline_events + ) + ) + rule_arn = self.event_bridge_rule_client.put_rule( + Name=rule_name, EventPattern=_event_patterns, State=state + )["RuleArn"] + return rule_arn + + def put_target( + self, + rule_name: str, + target_pipeline: str, + target_pipeline_parameters: Dict[str, str], + role_arn: str, + ) -> None: + """Attach target pipeline to an event based trigger. + + Args: + rule_name: The name of the EventBridge Rule. + target_pipeline: The name of the pipeline that is triggered by the EventBridge Rule. + target_pipeline_parameters: The list of parameters to start execution of a pipeline. + role_arn: The Amazon Resource Name (ARN) of the IAM role associated with the rule. + """ + target_pipeline_arn_and_name = self._generate_pipeline_arn_and_name(target_pipeline) + target_pipeline_name = target_pipeline_arn_and_name["pipeline_name"] + target_pipeline_arn = target_pipeline_arn_and_name["pipeline_arn"] + target_request_dict = { + "Id": target_pipeline_name, + "Arn": target_pipeline_arn, + "RoleArn": role_arn, + } + if target_pipeline_parameters: + target_request_dict["SageMakerPipelineParameters"] = { + "PipelineParameterList": target_pipeline_parameters + } + put_targets_response = self.event_bridge_rule_client.put_targets( + Rule=rule_name, + Targets=[target_request_dict], + ) + if put_targets_response["FailedEntryCount"] != 0: + error_msg = put_targets_response["FailedEntries"][0]["ErrorMessage"] + raise Exception(f"Failed to add target pipeline to rule. 
Failure reason: {error_msg}") + + def delete_rule(self, rule_name: str) -> None: + """Deletes an EventBridge Rule of a given pipeline if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + """ + self.event_bridge_rule_client.delete_rule(Name=rule_name) + + def remove_targets(self, rule_name: str, ids: List[str]) -> None: + """Deletes an EventBridge Targets of a given rule if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + ids: The ids of the EventBridge Target. + """ + self.event_bridge_rule_client.remove_targets(Rule=rule_name, Ids=ids) + + def list_targets_by_rule(self, rule_name: str) -> PageIterator: + """List EventBridge Targets of a given rule. + + Args: + rule_name: The name of the EventBridge Rule. + + Returns: + The page iterator of list_targets_by_rule call. + """ + return self.event_bridge_rule_client.get_paginator("list_targets_by_rule").paginate( + Rule=rule_name + ) + + def describe_rule(self, rule_name: str) -> Optional[Dict[str, Any]]: + """Describe the EventBridge Rule ARN corresponding to a sagemaker pipeline + + Args: + rule_name: The name of the EventBridge Rule. + Returns: + Optional[Dict[str, str]] : Describe EventBridge Rule response if exists. + """ + try: + event_bridge_rule_response = self.event_bridge_rule_client.describe_rule(Name=rule_name) + return event_bridge_rule_response + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + logger.info("No EventBridge Rule found for pipeline %s.", rule_name) + return None + raise e + + def enable_rule(self, rule_name: str) -> None: + """Enables an EventBridge Rule of a given pipeline if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + """ + self.event_bridge_rule_client.enable_rule(Name=rule_name) + logger.info("Enabled EventBridge Rule for pipeline %s.", rule_name) + + def disable_rule(self, rule_name: str) -> None: + """Disables an EventBridge Rule of a given pipeline if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + """ + self.event_bridge_rule_client.disable_rule(Name=rule_name) + logger.info("Disabled EventBridge Rule for pipeline %s.", rule_name) + + def add_tags(self, rule_arn: str, tags: List[TagsDict]) -> None: + """Adds tags to the EventBridge Rule. + + Args: + rule_arn: The ARN of the EventBridge Rule. + tags: List of tags to be added. + """ + self.event_bridge_rule_client.tag_resource(ResourceARN=rule_arn, Tags=tags) + + def _generate_event_pattern_from_feature_processor_pipeline_events( + self, pipeline_events: List[FeatureProcessorPipelineEvents] + ) -> str: + """Generates the event pattern json string from the pipeline events. + + Args: + pipeline_events: List of pipeline events. + Returns: + str: The event pattern json string. + + Raises: + ValueError: If pipeline events contain duplicate pipeline names. 
+ """ + + result_event_pattern = { + "detail-type": ["SageMaker Model Building Pipeline Execution Status Change"], + } + filters = [] + desired_status_to_pipeline_names_map = ( + self._aggregate_pipeline_events_with_same_desired_status(pipeline_events) + ) + for desired_status in desired_status_to_pipeline_names_map: + pipeline_arns = [ + self._generate_pipeline_arn_and_name(pipeline_name)["pipeline_arn"] + for pipeline_name in desired_status_to_pipeline_names_map[desired_status] + ] + curr_filter = BASE_EVENT_PATTERN.copy() + curr_filter["detail"]["pipelineArn"] = pipeline_arns + curr_filter["detail"]["currentPipelineExecutionStatus"] = [ + status_enum.value for status_enum in desired_status + ] + filters.append(curr_filter) + if len(filters) > 1: + result_event_pattern["$or"] = filters + else: + result_event_pattern.update(filters[0]) + return json.dumps(result_event_pattern) + + def _validate_feature_processor_pipeline_events( + self, pipeline_events: List[FeatureProcessorPipelineEvents] + ) -> None: + """Validates the pipeline events. + + Args: + pipeline_events: List of pipeline events. + Raises: + ValueError: If pipeline events contain duplicate pipeline names. + """ + + unique_pipelines = {event.pipeline_name for event in pipeline_events} + potential_infinite_loop = [] + if len(unique_pipelines) != len(pipeline_events): + raise ValueError("Pipeline names in pipeline_events must be unique.") + + for event in pipeline_events: + if FeatureProcessorPipelineExecutionStatus.EXECUTING in event.pipeline_execution_status: + potential_infinite_loop.append(event.pipeline_name) + if potential_infinite_loop: + logger.warning( + "Potential infinite loop detected for pipelines %s. " + "Setting pipeline_execution_status to EXECUTING might cause infinite loop. " + "Please consider a terminal status instead.", + potential_infinite_loop, + ) + + def _aggregate_pipeline_events_with_same_desired_status( + self, pipeline_events: List[FeatureProcessorPipelineEvents] + ) -> Dict[Tuple, List[str]]: + """Aggregate pipeline events with same desired status. + + e.g. + { + (FeatureProcessorPipelineExecutionStatus.FAILED, + FeatureProcessorPipelineExecutionStatus.STOPPED): + ["pipeline_name_1", "pipeline_name_2"], + (FeatureProcessorPipelineExecutionStatus.STOPPED, + FeatureProcessorPipelineExecutionStatus.STOPPED): + ["pipeline_name_3"], + } + Args: + pipeline_events: List of pipeline events. + Returns: + Dict[Tuple, List[str]]: A dictionary of desired status keys and corresponding pipeline + names. + """ + events_by_desired_status = {} + + for event in pipeline_events: + sorted_execution_status = sorted(event.pipeline_execution_status, key=lambda x: x.value) + desired_status_keys = tuple(sorted_execution_status) + + if desired_status_keys not in events_by_desired_status: + events_by_desired_status[desired_status_keys] = [] + events_by_desired_status[desired_status_keys].append(event.pipeline_name) + + return events_by_desired_status + + def _generate_pipeline_arn_and_name(self, pipeline_uri: str) -> Dict[str, str]: + """Generate pipeline arn and pipeline name from pipeline uri. + + Args: + pipeline_uri: The name or arn of the pipeline. + Returns: + Dict[str, str]: The arn and name of the pipeline. 
+ """ + match = re.match(PIPELINE_ARN_REGEX_PATTERN, pipeline_uri) + pipeline_arn = "" + pipeline_name = "" + if not match: + pipeline_name = pipeline_uri + describe_pipeline_response = self.sagemaker_session.sagemaker_client.describe_pipeline( + PipelineName=pipeline_name + ) + pipeline_arn = describe_pipeline_response["PipelineArn"] + else: + pipeline_arn = pipeline_uri + pipeline_name = match.group(4) + return dict(pipeline_arn=pipeline_arn, pipeline_name=pipeline_name) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_scheduler_helper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_scheduler_helper.py new file mode 100644 index 0000000000..f454a217e2 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_scheduler_helper.py @@ -0,0 +1,118 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes for EventBridge Schedule management for a feature processor.""" +from __future__ import absolute_import +import logging +from datetime import datetime +from typing import Dict, Optional, Any +import attr +from botocore.exceptions import ClientError +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER, + EVENT_BRIDGE_INVOCATION_TIME, + NO_FLEXIBLE_TIME_WINDOW, + RESOURCE_NOT_FOUND_EXCEPTION, +) + +logger = logging.getLogger("sagemaker") + + +@attr.s +class EventBridgeSchedulerHelper: + """Contains helper methods for scheduling events to EventBridge""" + + sagemaker_session: Session = attr.ib() + event_bridge_scheduler_client = attr.ib() + + def upsert_schedule( + self, + schedule_name: str, + pipeline_arn: str, + schedule_expression: str, + state: str, + start_date: datetime, + role: str, + ) -> Dict: + """Creates or updates a Schedule for the given pipeline_arn and schedule_expression. + + Args: + schedule_name: The name of the schedule. + pipeline_arn: The ARN of the sagemaker pipeline that needs to scheduled. + schedule_expression: The schedule expression. + state: Specifies whether the schedule is enabled or disabled. Can only + be ENABLED or DISABLED. + start_date: The date, in UTC, after which the schedule can begin invoking its target. + role: The RoleArn used to execute the scheduled events. + + Returns: + schedule_arn: The arn of the schedule. 
+ """ + pipeline_parameter = dict( + PipelineParameterList=[ + dict( + Name=EXECUTION_TIME_PIPELINE_PARAMETER, + Value=EVENT_BRIDGE_INVOCATION_TIME, + ) + ] + ) + create_or_update_schedule_request_dict = dict( + Name=schedule_name, + ScheduleExpression=schedule_expression, + FlexibleTimeWindow=NO_FLEXIBLE_TIME_WINDOW, + Target=dict( + Arn=pipeline_arn, + SageMakerPipelineParameters=pipeline_parameter, + RoleArn=role, + ), + State=state, + StartDate=start_date, + ) + try: + return self.event_bridge_scheduler_client.update_schedule( + **create_or_update_schedule_request_dict + ) + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + return self.event_bridge_scheduler_client.create_schedule( + **create_or_update_schedule_request_dict + ) + raise e + + def delete_schedule(self, schedule_name: str) -> None: + """Deletes an EventBridge Schedule of a given pipeline if there is one. + + Args: + schedule_name: The name of the EventBridge Schedule. + """ + logger.info("Deleting EventBridge Schedule for pipeline %s.", schedule_name) + self.event_bridge_scheduler_client.delete_schedule(Name=schedule_name) + + def describe_schedule(self, schedule_name) -> Optional[Dict[str, Any]]: + """Describe the EventBridge Schedule ARN corresponding to a sagemaker pipeline + + Args: + schedule_name: The name of the EventBridge Schedule. + Returns: + Optional[Dict[str, str]] : Describe EventBridge Schedule response if exists. + """ + try: + event_bridge_scheduler_response = self.event_bridge_scheduler_client.get_schedule( + Name=schedule_name + ) + return event_bridge_scheduler_response + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + logger.info("No EventBridge Schedule found for pipeline %s.", schedule_name) + return None + raise e diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_exceptions.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_exceptions.py new file mode 100644 index 0000000000..0b21d10ab9 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_exceptions.py @@ -0,0 +1,18 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores exceptions related to the feature_store.feature_processor module.""" +from __future__ import absolute_import + + +class IngestionError(Exception): + """Exception raised to indicate that ingestion did not complete successfully.""" diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_factory.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_factory.py new file mode 100644 index 0000000000..f205c32665 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_factory.py @@ -0,0 +1,167 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains static factory classes to instantiate complex objects for the FeatureProcessor.""" +from __future__ import absolute_import + +from typing import Dict +from pyspark.sql import DataFrame + +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._input_loader import ( + SparkDataFrameInputLoader, +) +from sagemaker.mlops.feature_store.feature_processor._params_loader import ( + ParamsLoader, + SystemParamsLoader, +) +from sagemaker.mlops.feature_store.feature_processor._spark_factory import ( + FeatureStoreManagerFactory, + SparkSessionFactory, +) +from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import SparkArgProvider +from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import ( + SparkOutputReceiver, +) +from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper +from sagemaker.mlops.feature_store.feature_processor._validation import ( + FeatureProcessorArgValidator, + InputValidator, + SparkUDFSignatureValidator, + InputOffsetValidator, + BaseDataSourceValidator, + ValidatorChain, +) + + +class ValidatorFactory: + """Static factory to handle ValidationChain instantiation.""" + + @staticmethod + def get_validation_chain(fp_config: FeatureProcessorConfig) -> ValidatorChain: + """Instantiate a ValidationChain""" + base_validators = [ + InputValidator(), + FeatureProcessorArgValidator(), + InputOffsetValidator(), + BaseDataSourceValidator(), + ] + + mode = fp_config.mode + if FeatureProcessorMode.PYSPARK == mode: + base_validators.append(SparkUDFSignatureValidator()) + return ValidatorChain(validators=base_validators) + + raise ValueError(f"FeatureProcessorMode {mode} is not supported.") + + +class UDFWrapperFactory: + """Static factory to handle UDFWrapper instantiation at runtime.""" + + @staticmethod + def get_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper: + """Instantiate a UDFWrapper based on the FeatureProcessingMode. + + Args: + fp_config (FeatureProcessorConfig): the configuration values for the + feature_processor decorator. + + Raises: + ValueError: if an unsupported FeatureProcessorMode is provided in fp_config. + + Returns: + UDFWrapper: An instance of UDFWrapper to decorate the UDF. + """ + mode = fp_config.mode + + if FeatureProcessorMode.PYSPARK == mode: + return UDFWrapperFactory._get_spark_udf_wrapper(fp_config) + + raise ValueError(f"FeatureProcessorMode {mode} is not supported.") + + @staticmethod + def _get_spark_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper[DataFrame]: + """Instantiate a new UDFWrapper for PySpark functions. + + Args: + fp_config (FeatureProcessorConfig): the configuration values for the feature_processor + decorator. 
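+
+        Returns:
+            UDFWrapper[DataFrame]: a wrapper composed of a SparkArgProvider and a
+            SparkOutputReceiver built from the lazily loaded Spark session and
+            FeatureStoreManager factories.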
+ """ + spark_session_factory = UDFWrapperFactory._get_spark_session_factory(fp_config.spark_config) + feature_store_manager_factory = UDFWrapperFactory._get_feature_store_manager_factory() + + output_manager = UDFWrapperFactory._get_spark_output_receiver(feature_store_manager_factory) + arg_provider = UDFWrapperFactory._get_spark_arg_provider(spark_session_factory) + + return UDFWrapper[DataFrame](arg_provider, output_manager) + + @staticmethod + def _get_spark_arg_provider( + spark_session_factory: SparkSessionFactory, + ) -> SparkArgProvider: + """Instantiate a new SparkArgProvider for PySpark functions. + + Args: + spark_session_factory (SparkSessionFactory): A factory to provide a reference to the + SparkSession initialized for the feature_processor wrapped function. The factory + lazily loads the SparkSession, i.e. defers to function execution time. + + Returns: + SparkArgProvider: An instance that generates arguments to provide to the + feature_processor wrapped function. + """ + environment_helper = EnvironmentHelper() + + system_parameters_arg_provider = SystemParamsLoader(environment_helper) + params_arg_provider = ParamsLoader(system_parameters_arg_provider) + input_loader = SparkDataFrameInputLoader(spark_session_factory, environment_helper) + + return SparkArgProvider(params_arg_provider, input_loader, spark_session_factory) + + @staticmethod + def _get_spark_output_receiver( + feature_store_manager_factory: FeatureStoreManagerFactory, + ) -> SparkOutputReceiver: + """Instantiate a new SparkOutputManager for PySpark functions. + + Args: + feature_store_manager_factory (FeatureStoreManagerFactory): A factory to provide + that provides a FeatureStoreManager that handles data ingestion to a Feature Group. + The factory lazily loads the FeatureStoreManager. + + Returns: + SparkOutputReceiver: An instance that handles outputs of the wrapped function. + """ + return SparkOutputReceiver(feature_store_manager_factory) + + @staticmethod + def _get_spark_session_factory(spark_config: Dict[str, str]) -> SparkSessionFactory: + """Instantiate a new SparkSessionFactory + + Args: + spark_config (Dict[str, str]): The Spark configuration that will be passed to the + initialization of Spark session. + + Returns: + SparkSessionFactory: A Spark session factory instance. + """ + environment_helper = EnvironmentHelper() + return SparkSessionFactory(environment_helper, spark_config) + + @staticmethod + def _get_feature_store_manager_factory() -> FeatureStoreManagerFactory: + """Instantiate a new FeatureStoreManagerFactory""" + return FeatureStoreManagerFactory() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_config.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_config.py new file mode 100644 index 0000000000..f5d4dd91f3 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_config.py @@ -0,0 +1,72 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +"""Contains data classes for the FeatureProcessor.""" +from __future__ import absolute_import + +from typing import Dict, List, Optional, Sequence, Union + +import attr + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode + + +@attr.s(frozen=True) +class FeatureProcessorConfig: + """Immutable data class containing the arguments for a FeatureProcessor. + + This class is used throughout sagemaker.mlops.feature_store.feature_processor module. Documentation + for each field can be be found in the feature_processor decorator. + + Defaults are defined as literals in the feature_processor decorator's parameters for usability + (i.e. literals in docs). Defaults, or any business logic, should not be added to this class. + It only serves as an immutable data class. + """ + + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ] = attr.ib() + output: str = attr.ib() + mode: FeatureProcessorMode = attr.ib() + target_stores: Optional[List[str]] = attr.ib() + parameters: Optional[Dict[str, Union[str, Dict]]] = attr.ib() + enable_ingestion: bool = attr.ib() + spark_config: Dict[str, str] = attr.ib() + + @staticmethod + def create( + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ], + output: str, + mode: FeatureProcessorMode, + target_stores: Optional[List[str]], + parameters: Optional[Dict[str, Union[str, Dict]]], + enable_ingestion: bool, + spark_config: Dict[str, str], + ) -> "FeatureProcessorConfig": + """Static initializer.""" + return FeatureProcessorConfig( + inputs=inputs, + output=output, + mode=mode, + target_stores=target_stores, + parameters=parameters, + enable_ingestion=enable_ingestion, + spark_config=spark_config, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_pipeline_events.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_pipeline_events.py new file mode 100644 index 0000000000..4ce9fb1b76 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_pipeline_events.py @@ -0,0 +1,29 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains data classes for the Feature Processor Pipeline Events.""" +from __future__ import absolute_import + +from typing import List +import attr +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorPipelineExecutionStatus + + +@attr.s(frozen=True) +class FeatureProcessorPipelineEvents: + """Immutable data class containing the execution events for a FeatureProcessor pipeline. + + This class is used for creating event based triggers for feature processor pipelines. 
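+
+    Example:
+        Illustrative only; the pipeline name is hypothetical:
+
+        FeatureProcessorPipelineEvents(
+            pipeline_name="upstream-pipeline",
+            pipeline_execution_status=[
+                FeatureProcessorPipelineExecutionStatus.FAILED,
+                FeatureProcessorPipelineExecutionStatus.STOPPED,
+            ],
+        )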
+ """ + + pipeline_name: str = attr.ib() + pipeline_execution_status: List[FeatureProcessorPipelineExecutionStatus] = attr.ib() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py new file mode 100644 index 0000000000..627de943c1 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py @@ -0,0 +1,366 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes that loads user specified input sources (e.g. Feature Groups, S3 URIs, etc).""" +from __future__ import absolute_import + +import logging +import re +from abc import ABC, abstractmethod +from typing import Generic, Optional, TypeVar, Union + +import attr +from pyspark.sql import DataFrame + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._constants import FEATURE_GROUP_ARN_REGEX_PATTERN +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + IcebergTableDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory +from sagemaker.mlops.feature_store.feature_processor._input_offset_parser import ( + InputOffsetParser, +) +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper + +T = TypeVar("T") + +logger = logging.getLogger("sagemaker") + + +class InputLoader(Generic[T], ABC): + """Loads the contents of a Feature Group's offline store or contents at an S3 URI.""" + + @abstractmethod + def load_from_feature_group(self, feature_group_data_source: FeatureGroupDataSource) -> T: + """Load the data from a Feature Group's offline store. + + Args: + feature_group_data_source (FeatureGroupDataSource): the feature group source. + + Returns: + T: The contents of the offline store as an instance of type T. + """ + + @abstractmethod + def load_from_s3(self, s3_data_source: Union[CSVDataSource, ParquetDataSource]) -> T: + """Load the contents from an S3 based data source. + + Args: + s3_data_source (Union[CSVDataSource, ParquetDataSource]): a data source that is based + in S3. + + Returns: + T: The contents stored at the data source as an instance of type T. + """ + + +@attr.s +class SparkDataFrameInputLoader(InputLoader[DataFrame]): + """InputLoader that reads data in as a Spark DataFrame.""" + + spark_session_factory: SparkSessionFactory = attr.ib() + environment_helper: EnvironmentHelper = attr.ib() + sagemaker_session: Optional[Session] = attr.ib(default=None) + + _supported_table_format = ["Iceberg", "Glue", None] + + def load_from_feature_group( + self, feature_group_data_source: FeatureGroupDataSource + ) -> DataFrame: + """Load the contents of a Feature Group's offline store as a DataFrame. + + Args: + feature_group_data_source (FeatureGroupDataSource): the Feature Group source. 
+ + Raises: + ValueError: If the Feature Group does not have an Offline Store. + ValueError: If the Feature Group's Table Type is not supported by the feature_processor. + + Returns: + DataFrame: A Spark DataFrame containing the contents of the Feature Group's + offline store. + """ + sagemaker_session: Session = self.sagemaker_session or Session() + + feature_group_name = feature_group_data_source.name + feature_group = sagemaker_session.describe_feature_group( + self._parse_name_from_arn(feature_group_name) + ) + logger.debug( + "Called describe_feature_group with %s and received: %s", + feature_group_name, + feature_group, + ) + + if "OfflineStoreConfig" not in feature_group: + raise ValueError( + f"Input Feature Groups must have an enabled Offline Store." + f" Feature Group: {feature_group_name} does not have an Offline Store enabled." + ) + + offline_store_uri = feature_group["OfflineStoreConfig"]["S3StorageConfig"][ + "ResolvedOutputS3Uri" + ] + + table_format = feature_group["OfflineStoreConfig"].get("TableFormat", None) + + if table_format not in self._supported_table_format: + raise ValueError( + f"Feature group with table format {table_format} is not supported. " + f"The table format should be one of {self._supported_table_format}." + ) + + start_offset = feature_group_data_source.input_start_offset + end_offset = feature_group_data_source.input_end_offset + + if table_format == "Iceberg": + data_catalog_config = feature_group["OfflineStoreConfig"]["DataCatalogConfig"] + return self.load_from_iceberg_table( + IcebergTableDataSource( + offline_store_uri, + data_catalog_config["Catalog"], + data_catalog_config["Database"], + data_catalog_config["TableName"], + ), + feature_group["EventTimeFeatureName"], + start_offset, + end_offset, + ) + + return self.load_from_date_partitioned_s3( + ParquetDataSource(offline_store_uri), start_offset, end_offset + ) + + def load_from_date_partitioned_s3( + self, + s3_data_source: ParquetDataSource, + input_start_offset: str, + input_end_offset: str, + ) -> DataFrame: + """Load the contents from a Feature Group's partitioned offline S3 as a DataFrame. + + Args: + s3_data_source (ParquetDataSource): + A data source that is based in S3. + input_start_offset (str): Start offset that is used to calculate the input start date. + input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + DataFrame: Contents of the data loaded from S3. + """ + + spark_session = self.spark_session_factory.spark_session + s3a_uri = s3_data_source.s3_uri.replace("s3://", "s3a://") + filter_condition = self._get_s3_partitions_offset_filter_condition( + input_start_offset, input_end_offset + ) + + logger.info( + "Loading data from %s with filtering condition %s.", + s3a_uri, + filter_condition, + ) + input_df = spark_session.read.parquet(s3a_uri) + if filter_condition: + input_df = input_df.filter(filter_condition) + + return input_df + + def load_from_s3(self, s3_data_source: Union[CSVDataSource, ParquetDataSource]) -> DataFrame: + """Load the contents from an S3 based data source as a DataFrame. + + Args: + s3_data_source (Union[CSVDataSource, ParquetDataSource]): + A data source that is based in S3. + + Raises: + ValueError: If an invalid DataSource is provided. + + Returns: + DataFrame: Contents of the data loaded from S3. + """ + spark_session = self.spark_session_factory.spark_session + s3a_uri = s3_data_source.s3_uri.replace("s3://", "s3a://") + + if isinstance(s3_data_source, CSVDataSource): + # TODO: Accept `schema` parameter. 
(Inferring schema requires a pass through every row) + logger.info("Loading data from %s.", s3a_uri) + return spark_session.read.csv( + s3a_uri, + header=s3_data_source.csv_header, + inferSchema=s3_data_source.csv_infer_schema, + ) + + if isinstance(s3_data_source, ParquetDataSource): + logger.info("Loading data from %s.", s3a_uri) + return spark_session.read.parquet(s3a_uri) + + raise ValueError("An invalid data source was provided.") + + def load_from_iceberg_table( + self, + iceberg_table_data_source: IcebergTableDataSource, + event_time_feature_name: str, + input_start_offset: str, + input_end_offset: str, + ) -> DataFrame: + """Load the contents from an Iceberg table as a DataFrame. + + Args: + iceberg_table_data_source (IcebergTableDataSource): An Iceberg Table source. + event_time_feature_name (str): Event time feature's name of feature group. + input_start_offset (str): Start offset that is used to calculate the input start date. + input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + DataFrame: Contents of the Iceberg Table as a Spark DataFrame. + """ + catalog = iceberg_table_data_source.catalog.lower() + database = iceberg_table_data_source.database.lower() + table = iceberg_table_data_source.table.lower() + iceberg_table = f"{catalog}.{database}.{table}" + + spark_session = self.spark_session_factory.get_spark_session_with_iceberg_config( + iceberg_table_data_source.warehouse_s3_uri, catalog + ) + + filter_condition = self._get_iceberg_offset_filter_condition( + event_time_feature_name, + input_start_offset, + input_end_offset, + ) + + iceberg_df = spark_session.table(iceberg_table) + + if filter_condition: + logger.info( + "The filter condition for iceberg feature group is %s.", + filter_condition, + ) + iceberg_df = iceberg_df.filter(filter_condition) + + return iceberg_df + + def _get_iceberg_offset_filter_condition( + self, + event_time_feature_name: str, + input_start_offset: str, + input_end_offset: str, + ): + """Load the contents from an Iceberg table as a DataFrame. + + Args: + iceberg_table_data_source (IcebergTableDataSource): An Iceberg Table source. + input_start_offset (str): Start offset that is used to calculate the input start date. + input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + DataFrame: Contents of the Iceberg Table as a Spark DataFrame. + """ + if input_start_offset is None and input_end_offset is None: + return None + + offset_parser = InputOffsetParser(self.environment_helper.get_job_scheduled_time()) + start_offset_time = offset_parser.get_iso_format_offset_date(input_start_offset) + end_offset_time = offset_parser.get_iso_format_offset_date(input_end_offset) + + start_condition = ( + f"{event_time_feature_name} >= '{start_offset_time}'" if input_start_offset else None + ) + end_condition = ( + f"{event_time_feature_name} < '{end_offset_time}'" if input_end_offset else None + ) + + conditions = filter(None, [start_condition, end_condition]) + return " AND ".join(conditions) + + def _get_s3_partitions_offset_filter_condition( + self, input_start_offset: str, input_end_offset: str + ) -> str: + """Get s3 partitions filter condition based on input offsets. + + Args: + input_start_offset (str): Start offset that is used to calculate the input start date. + input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + str: A SQL string that defines the condition of time range filter. 
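+
+        Example:
+            Illustrative shape only; the concrete dates depend on the scheduled
+            execution time. For input_start_offset="1 day" and input_end_offset=None
+            the returned condition looks like:
+
+            (year >= '2023') AND NOT ((year = '2023' AND month < '06')
+            OR (year = '2023' AND month = '06' AND day < '14')
+            OR (year = '2023' AND month = '06' AND day = '14' AND hour < '09'))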
+ """ + if input_start_offset is None and input_end_offset is None: + return None + + offset_parser = InputOffsetParser(self.environment_helper.get_job_scheduled_time()) + ( + start_year, + start_month, + start_day, + start_hour, + ) = offset_parser.get_offset_date_year_month_day_hour(input_start_offset) + ( + end_year, + end_month, + end_day, + end_hour, + ) = offset_parser.get_offset_date_year_month_day_hour(input_end_offset) + + # Include all records that event time is between start_year and end_year + start_year_include_condition = f"year >= '{start_year}'" if input_start_offset else None + end_year_include_condition = f"year <= '{end_year}'" if input_end_offset else None + year_include_condition = " AND ".join( + filter(None, [start_year_include_condition, end_year_include_condition]) + ) + + # Exclude all records that the event time is earlier than the start or later than the end + start_offset_exclude_condition = ( + f"(year = '{start_year}' AND month < '{start_month}') " + f"OR (year = '{start_year}' AND month = '{start_month}' AND day < '{start_day}') " + f"OR (year = '{start_year}' AND month = '{start_month}' AND day = '{start_day}' " + f"AND hour < '{start_hour}')" + if input_start_offset + else None + ) + end_offset_exclude_condition = ( + f"(year = '{end_year}' AND month > '{end_month}') " + f"OR (year = '{end_year}' AND month = '{end_month}' AND day > '{end_day}') " + f"OR (year = '{end_year}' AND month = '{end_month}' AND day = '{end_day}' " + f"AND hour >= '{end_hour}')" + if input_end_offset + else None + ) + offset_exclude_condition = " OR ".join( + filter(None, [start_offset_exclude_condition, end_offset_exclude_condition]) + ) + + filter_condition = f"({year_include_condition}) AND NOT ({offset_exclude_condition})" + + logger.info("The filter condition for hive feature group is %s.", filter_condition) + + return filter_condition + + def _parse_name_from_arn(self, fg_uri: str) -> str: + """Parse a Feature Group's name from an arn. + + Args: + fg_uri (str): a string identifier of the Feature Group. + + Returns: + str: the name of the feature group. + """ + match = re.match(FEATURE_GROUP_ARN_REGEX_PATTERN, fg_uri) + if match: + feature_group_name = match.group(4) + return feature_group_name + return fg_uri diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_offset_parser.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_offset_parser.py new file mode 100644 index 0000000000..89d816af49 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_offset_parser.py @@ -0,0 +1,129 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains class that parse the input data start and end offset""" +from __future__ import absolute_import + +import re +from typing import Optional, Tuple, Union +from datetime import datetime, timezone +from dateutil.relativedelta import relativedelta +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, +) + +UNIT_RE = r"(\d+?)\s+([a-z]+?)s?" +VALID_UNITS = ["hour", "day", "week", "month", "year"] + + +class InputOffsetParser: + """Contains methods to parse the input offset to different formats. + + Args: + now (datetime): + The point of time that the parser should calculate offset against. + """ + + def __init__(self, now: Union[datetime, str] = None) -> None: + if now is None: + self.now = datetime.now(timezone.utc) + elif isinstance(now, datetime): + self.now = now + else: + self.now = datetime.strptime(now, EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT) + + def get_iso_format_offset_date(self, offset: Optional[str]) -> str: + """Get the iso format of target date based on offset diff. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Returns: + str: ISO-8061 formatted string of the offset date. + """ + if offset is None: + return None + + offset_datetime = self.get_offset_datetime(offset) + return offset_datetime.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT) + + def get_offset_datetime(self, offset: Optional[str]) -> datetime: + """Get the datetime format of target date based on offset diff. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Returns: + datetime: datetime instance of the offset date. + """ + if offset is None: + return None + + offset_td = InputOffsetParser.parse_offset_to_timedelta(offset) + + return self.now + offset_td + + def get_offset_date_year_month_day_hour( + self, offset: Optional[str] + ) -> Tuple[str, str, str, str]: + """Get the year, month, day and hour based on offset diff. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Returns: + Tuple[str, str, str, str]: A tuple that consists of extracted year, month, day, hour from offset date. + """ + if offset is None: + return (None, None, None, None) + + offset_dt = self.get_offset_datetime(offset) + return ( + offset_dt.strftime("%Y"), + offset_dt.strftime("%m"), + offset_dt.strftime("%d"), + offset_dt.strftime("%H"), + ) + + @staticmethod + def parse_offset_to_timedelta(offset: Optional[str]) -> relativedelta: + """Parse the offset to time delta. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Raises: + ValueError: If an offset is provided in a unrecognizable format. + ValueError: If an invalid offset unit is provided. + + Returns: + reletivedelta: Time delta representation of the time offset. + """ + if offset is None: + return None + + unit_match = re.fullmatch(UNIT_RE, offset) + + if not unit_match: + raise ValueError( + f"[{offset}] is not in a valid offset format. " + "Please pass a valid offset e.g '1 day'." + ) + + multiple, unit = unit_match.groups() + + if unit not in VALID_UNITS: + raise ValueError(f"[{unit}] is not a valid offset unit. 
Supported units: {VALID_UNITS}") + + shift_args = {f"{unit}s": -int(multiple)} + + return relativedelta(**shift_args) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_params_loader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_params_loader.py new file mode 100644 index 0000000000..f5be546e86 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_params_loader.py @@ -0,0 +1,83 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes for loading the 'params' argument for the UDF.""" +from __future__ import absolute_import + +from typing import Dict, Union + +import attr + +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + + +@attr.s +class SystemParamsLoader: + """Provides the fields for the params['system'] namespace. + + These are the parameters that the feature_processor automatically loads from various SageMaker + resources. + """ + + _SYSTEM_PARAMS_KEY = "system" + + environment_helper: EnvironmentHelper = attr.ib() + + def get_system_args(self) -> Dict[str, Union[str, Dict]]: + """Generates the system generated parameters for the feature_processor wrapped function. + + Args: + fp_config (FeatureProcessorConfig): The configuration values for the + feature_processor decorator. + + Returns: + Dict[str, Union[str, Dict]]: The system parameters. + """ + + return { + self._SYSTEM_PARAMS_KEY: { + "scheduled_time": self.environment_helper.get_job_scheduled_time(), + } + } + + +@attr.s +class ParamsLoader: + """Provides 'params' argument for the FeatureProcessor.""" + + _PARAMS_KEY = "params" + + system_parameters_arg_provider: SystemParamsLoader = attr.ib() + + def get_parameter_args( + self, + fp_config: FeatureProcessorConfig, + ) -> Dict[str, Union[str, Dict]]: + """Loads the 'params' argument for the FeatureProcessor. + + Args: + fp_config (FeatureProcessorConfig): The configuration values for the + feature_processor decorator. + + Returns: + Dict[str, Union[str, Dict]]: A dictionary that contains both user provided + parameters (feature_processor argument) and system parameters. + """ + return { + self._PARAMS_KEY: { + **(fp_config.parameters or {}), + **self.system_parameters_arg_provider.get_system_args(), + } + } diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py new file mode 100644 index 0000000000..d304185e85 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py @@ -0,0 +1,202 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains factory classes for instantiating Spark objects.""" +from __future__ import absolute_import + +from functools import lru_cache +from typing import List, Tuple, Dict + +import feature_store_pyspark +import feature_store_pyspark.FeatureStoreManager as fsm +from pyspark.conf import SparkConf +from pyspark.context import SparkContext +from pyspark.sql import SparkSession + +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper + +SPARK_APP_NAME = "FeatureProcessor" + + +class SparkSessionFactory: + """Lazy loading, memoizing, instantiation of SparkSessions. + + Useful when you want to defer SparkSession instantiation and provide access to the same + instance throughout the application. + """ + + def __init__( + self, environment_helper: EnvironmentHelper, spark_config: Dict[str, str] = None + ) -> None: + """Initialize the SparkSessionFactory. + + Args: + environment_helper (EnvironmentHelper): A helper class to determine the current + execution. + spark_config (Dict[str, str]): The Spark configuration that will be passed to the + initialization of Spark session. + """ + self.environment_helper = environment_helper + self.spark_config = spark_config + + @property + @lru_cache() + def spark_session(self) -> SparkSession: + """Instantiate a new SparkSession or return the existing one.""" + is_training_job = self.environment_helper.is_training_job() + instance_count = self.environment_helper.get_instance_count() + + spark_configs = self._get_spark_configs(is_training_job) + spark_conf = SparkConf().setAll(spark_configs).setAppName(SPARK_APP_NAME) + + if instance_count == 1: + spark_conf.setMaster("local[*]") + + sc = SparkContext.getOrCreate(conf=spark_conf) + + jsc = sc._jsc # Java Spark Context (JVM SparkContext) + for cfg in self._get_jsc_hadoop_configs(): + jsc.hadoopConfiguration().set(cfg[0], cfg[1]) + + return SparkSession(sparkContext=sc) + + def _get_spark_configs(self, is_training_job) -> List[Tuple[str, str]]: + """Generate Spark Configurations optimized for feature_processing functionality. + + Args: + is_training_job (bool): a boolean indicating whether the current execution environment + is a training job or not. + + Returns: + List[Tuple[str, str]]: Spark configurations. 
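+
+        Note:
+            User-supplied spark_config entries are appended after the defaults, so a
+            later value for the same key wins when the pairs are applied through
+            SparkConf().setAll(). spark.jars and spark.jars.packages are merged with
+            the feature-store defaults rather than replaced (outside of training jobs).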
+ """ + spark_configs = [ + ( + "spark.hadoop.fs.s3a.aws.credentials.provider", + ",".join( + [ + "com.amazonaws.auth.ContainerCredentialsProvider", + "com.amazonaws.auth.profile.ProfileCredentialsProvider", + "com.amazonaws.auth.DefaultAWSCredentialsProviderChain", + ] + ), + ), + # spark-3.3.1#recommended-settings-for-writing-to-object-stores - https://tinyurl.com/54rkhef6 + ("spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version", "2"), + ( + "spark.hadoop.mapreduce.fileoutputcommitter.cleanup-failures.ignored", + "true", + ), + ("spark.hadoop.parquet.enable.summary-metadata", "false"), + # spark-3.3.1#parquet-io-settings https://tinyurl.com/59a7uhwu + ("spark.sql.parquet.mergeSchema", "false"), + ("spark.sql.parquet.filterPushdown", "true"), + ("spark.sql.hive.metastorePartitionPruning", "true"), + # hadoop-aws#performance - https://tinyurl.com/mutxj96f + ("spark.hadoop.fs.s3a.threads.max", "500"), + ("spark.hadoop.fs.s3a.connection.maximum", "500"), + ("spark.hadoop.fs.s3a.experimental.input.fadvise", "normal"), + ("spark.hadoop.fs.s3a.block.size", "128M"), + ("spark.hadoop.fs.s3a.fast.upload.buffer", "disk"), + ("spark.hadoop.fs.trash.interval", "0"), + ("spark.port.maxRetries", "50"), + ] + + if self.spark_config: + spark_configs.extend(self.spark_config.items()) + + if not is_training_job: + fp_spark_jars = feature_store_pyspark.classpath_jars() + fp_spark_packages = [ + "org.apache.hadoop:hadoop-aws:3.3.1", + "org.apache.hadoop:hadoop-common:3.3.1", + ] + + if self.spark_config and "spark.jars" in self.spark_config: + fp_spark_jars.append(self.spark_config.get("spark.jars")) + + if self.spark_config and "spark.jars.packages" in self.spark_config: + fp_spark_packages.append(self.spark_config.get("spark.jars.packages")) + + spark_configs.extend( + ( + ("spark.jars", ",".join(fp_spark_jars)), + ( + "spark.jars.packages", + ",".join(fp_spark_packages), + ), + ) + ) + + return spark_configs + + def _get_jsc_hadoop_configs(self) -> List[Tuple[str, str]]: + """JVM SparkContext Hadoop configurations.""" + # Skip generation of _SUCCESS files to speed up writes. + return [("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")] + + def _get_iceberg_configs(self, warehouse_s3_uri: str, catalog: str) -> List[Tuple[str, str]]: + """Spark configurations for reading and writing data from Iceberg Table sources. + + Args: + warehouse_s3_uri (str): The S3 URI of the warehouse. + catalog (str): The catalog. + + Returns: + List[Tuple[str, str]]: the Spark configurations. + """ + catalog = catalog.lower() + return [ + (f"spark.sql.catalog.{catalog}", "smfs.shaded.org.apache.iceberg.spark.SparkCatalog"), + (f"spark.sql.catalog.{catalog}.warehouse", warehouse_s3_uri), + ( + f"spark.sql.catalog.{catalog}.catalog-impl", + "smfs.shaded.org.apache.iceberg.aws.glue.GlueCatalog", + ), + ( + f"spark.sql.catalog.{catalog}.io-impl", + "smfs.shaded.org.apache.iceberg.aws.s3.S3FileIO", + ), + (f"spark.sql.catalog.{catalog}.glue.skip-name-validation", "true"), + ] + + def get_spark_session_with_iceberg_config(self, warehouse_s3_uri, catalog) -> SparkSession: + """Get an instance of the SparkSession with Iceberg settings configured. + + Args: + warehouse_s3_uri (str): The S3 URI of the warehouse. + catalog (str): The catalog. + + Returns: + SparkSession: A SparkSession ready to support reading and writing data from an Iceberg + Table. 
+ """ + conf = self.spark_session.conf + + for cfg in self._get_iceberg_configs(warehouse_s3_uri, catalog): + conf.set(cfg[0], cfg[1]) + + return self.spark_session + + +class FeatureStoreManagerFactory: + """Lazy loading, memoizing, instantiation of FeatureStoreManagers. + + Useful when you want to defer FeatureStoreManagers instantiation and provide access to the same + instance throughout the application. + """ + + @property + @lru_cache() + def feature_store_manager(self) -> fsm.FeatureStoreManager: + """Instansiate a new FeatureStoreManager.""" + return fsm.FeatureStoreManager() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py new file mode 100644 index 0000000000..bd21e804eb --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py @@ -0,0 +1,239 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes for loading arguments for the parameters defined in the UDF.""" +from __future__ import absolute_import + +from abc import ABC, abstractmethod +from inspect import signature +from typing import Any, Callable, Dict, Generic, List, OrderedDict, TypeVar, Union, Optional + +import attr +from pyspark.sql import DataFrame, SparkSession + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, + PySparkDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._input_loader import ( + SparkDataFrameInputLoader, +) +from sagemaker.mlops.feature_store.feature_processor._params_loader import ParamsLoader +from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory + +T = TypeVar("T") + + +@attr.s +class UDFArgProvider(Generic[T], ABC): + """Base class for arguments providers for the UDF. + + Args: + Generic (T): The type of the auto-loaded data values. + """ + + @abstractmethod + def provide_input_args( + self, udf: Callable[..., T], fp_config: FeatureProcessorConfig + ) -> OrderedDict[str, T]: + """Provides a dict of (input name, auto-loaded data) using the feature_processor parameters. + + The input name is the udfs parameter name, and the data source is the one defined at the + same index (as the input name) in fp_config.inputs. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Returns: + OrderedDict[str, T]: The loaded data sources, in the same order as fp_config.inputs. + """ + + @abstractmethod + def provide_params_arg( + self, udf: Callable[..., T], fp_config: FeatureProcessorConfig + ) -> Dict[str, Dict]: + """Provides the 'params' argument that is provided to the UDF. 
+ + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Returns: + Dict[str, Dict]: A combination of user defined parameters (in fp_config) and system + provided parameters. + """ + + @abstractmethod + def provide_additional_kwargs(self, udf: Callable[..., T]) -> Dict[str, Any]: + """Provides any additional arguments to be provided to the UDF, dependent on the mode. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + + Returns: + Dict[str, Any]: additional kwargs for the user function. + """ + + +@attr.s +class SparkArgProvider(UDFArgProvider[DataFrame]): + """Provides arguments to Spark UDFs.""" + + PARAMS_ARG_NAME = "params" + SPARK_SESSION_ARG_NAME = "spark" + + params_loader: ParamsLoader = attr.ib() + input_loader: SparkDataFrameInputLoader = attr.ib() + spark_session_factory: SparkSessionFactory = attr.ib() + + def provide_input_args( + self, udf: Callable[..., DataFrame], fp_config: FeatureProcessorConfig + ) -> OrderedDict[str, DataFrame]: + """Provide a DataFrame for each requested input. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises: + ValueError: If the signature of the UDF does not match the fp_config.inputs. + ValueError: If there are no inputs provided to the user defined function. + + Returns: + OrderedDict[str, DataFrame]: The loaded data sources, in the same order as + fp_config.inputs. + """ + udf_parameter_names = list(signature(udf).parameters.keys()) + udf_input_names = self._get_input_parameters(udf_parameter_names) + udf_params = self.params_loader.get_parameter_args(fp_config).get( + self.PARAMS_ARG_NAME, None + ) + + if len(udf_input_names) == 0: + raise ValueError("Expected at least one input to the user defined function.") + + if len(udf_input_names) != len(fp_config.inputs): + raise ValueError( + f"The signature of the user defined function does not match the list of inputs" + f" requested. Expected {len(fp_config.inputs)} parameter(s)." + ) + + return OrderedDict( + (input_name, self._load_data_frame(data_source=input_uri, params=udf_params)) + for (input_name, input_uri) in zip(udf_input_names, fp_config.inputs) + ) + + def provide_params_arg( + self, udf: Callable[..., DataFrame], fp_config: FeatureProcessorConfig + ) -> Dict[str, Union[str, Dict]]: + """Provide params for the UDF. If the udf has a parameter named 'params'. + + Args: + udf (Callable[..., T]): the feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + """ + return ( + self.params_loader.get_parameter_args(fp_config) + if self._has_param(udf, self.PARAMS_ARG_NAME) + else {} + ) + + def provide_additional_kwargs(self, udf: Callable[..., DataFrame]) -> Dict[str, SparkSession]: + """Provide the Spark session. If the udf has a parameter named 'spark'. + + Args: + udf (Callable[..., T]): the feature_processor wrapped user function. + """ + return ( + {self.SPARK_SESSION_ARG_NAME: self.spark_session_factory.spark_session} + if self._has_param(udf, self.SPARK_SESSION_ARG_NAME) + else {} + ) + + def _get_input_parameters(self, udf_parameter_names: List[str]) -> List[str]: + """Parses the parameter names from the UDF that correspond to the input data sources. 
+ + This function assumes that the udf signature's `params` and `spark` parameters are at the + end, in any order, if provided. + + Args: + udf_parameter_names (List[str]): The full list of parameters names in the UDF. + + Returns: + List[str]: A subset of parameter names corresponding to the input data sources. + """ + inputs_end_index = len(udf_parameter_names) - 1 + + # Reduce range based on the position of optional kwargs of the UDF. + if self.PARAMS_ARG_NAME in udf_parameter_names: + inputs_end_index = udf_parameter_names.index(self.PARAMS_ARG_NAME) - 1 + + if self.SPARK_SESSION_ARG_NAME in udf_parameter_names: + inputs_end_index = min( + inputs_end_index, + udf_parameter_names.index(self.SPARK_SESSION_ARG_NAME) - 1, + ) + + return udf_parameter_names[: inputs_end_index + 1] + + def _load_data_frame( + self, + data_source: Union[ + FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource + ], + params: Optional[Dict[str, Union[str, Dict]]] = None, + ) -> DataFrame: + """Given a data source definition, load the data as a Spark DataFrame. + + Args: + data_source (Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, + BaseDataSource]): A user specified data source from the feature_processor + decorator's parameters. + params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the + feature_processor decorator. + + Returns: + DataFrame: The contents of the data source as a Spark DataFrame. + """ + if isinstance(data_source, (CSVDataSource, ParquetDataSource)): + return self.input_loader.load_from_s3(data_source) + + if isinstance(data_source, FeatureGroupDataSource): + return self.input_loader.load_from_feature_group(data_source) + + if isinstance(data_source, PySparkDataSource): + spark_session = self.spark_session_factory.spark_session + return data_source.read_data(spark=spark_session, params=params) + + if isinstance(data_source, BaseDataSource): + return data_source.read_data(params=params) + + raise ValueError(f"Unknown data source type: {type(data_source)}") + + def _has_param(self, udf: Callable, name: str) -> bool: + """Determine if a function has a parameter with a given name. + + Args: + udf (Callable): the user defined function. + name (str): the name of the parameter. + + Returns: + bool: True if the udf contains a parameter with the name. + """ + return name in list(signature(udf).parameters.keys()) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_output_receiver.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_output_receiver.py new file mode 100644 index 0000000000..a037e837c2 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_output_receiver.py @@ -0,0 +1,98 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains classes for handling UDF outputs""" +from __future__ import absolute_import + +import logging +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +import attr +from py4j.protocol import Py4JJavaError +from pyspark.sql import DataFrame + +from sagemaker.mlops.feature_store.feature_processor import IngestionError +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._spark_factory import ( + FeatureStoreManagerFactory, +) + +T = TypeVar("T") + +logger = logging.getLogger("sagemaker") + + +class UDFOutputReceiver(Generic[T], ABC): + """Base class for handling outputs of the UDF.""" + + @abstractmethod + def ingest_udf_output(self, output: T, fp_config: FeatureProcessorConfig) -> None: + """Ingests data to the output feature group. + + Args: + output (T): The output of the feature_processor wrapped function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + """ + + +@attr.s +class SparkOutputReceiver(UDFOutputReceiver[DataFrame]): + """Handles the Spark DataFrame the output from the UDF""" + + feature_store_manager_factory: FeatureStoreManagerFactory = attr.ib() + + def ingest_udf_output(self, output: DataFrame, fp_config: FeatureProcessorConfig) -> None: + """Ingests UDF to the output Feature Group. + + Args: + output (T): The output of the feature_processor wrapped function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises: + Py4JError: If there is a problem with Py4J, including client code errors. + IngestionError: If any rows are not ingested successfully then a sample of the records, + with failure reasons, is logged. + """ + if fp_config.enable_ingestion is False: + logging.info("Ingestion is disabled. Skipping ingestion.") + return + + logger.info( + "Ingesting transformed data to %s with target_stores: %s", + fp_config.output, + fp_config.target_stores, + ) + + feature_store_manager = self.feature_store_manager_factory.feature_store_manager + try: + feature_store_manager.ingest_data( + input_data_frame=output, + feature_group_arn=fp_config.output, + target_stores=fp_config.target_stores, + ) + except Py4JJavaError as e: + if e.java_exception.getClass().getSimpleName() == "StreamIngestionFailureException": + logger.warning( + "Ingestion did not complete successfully. Failed records and error messages" + " have been printed to the console." + ) + feature_store_manager.get_failed_stream_ingestion_data_frame().show( + n=20, truncate=False + ) + raise IngestionError(e.java_exception) + + raise e + + logger.info("Ingestion to %s complete.", fp_config.output) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_wrapper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_wrapper.py new file mode 100644 index 0000000000..95b07de7c1 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_wrapper.py @@ -0,0 +1,88 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides a wrapper for user provided functions.""" +from __future__ import absolute_import + +import functools +from typing import Any, Callable, Dict, Generic, Tuple, TypeVar + +import attr + +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import UDFArgProvider +from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import ( + UDFOutputReceiver, +) + +T = TypeVar("T") + + +@attr.s +class UDFWrapper(Generic[T]): + """Class that wraps a user provided function.""" + + udf_arg_provider: UDFArgProvider[T] = attr.ib() + udf_output_receiver: UDFOutputReceiver[T] = attr.ib() + + def wrap(self, udf: Callable[..., T], fp_config: FeatureProcessorConfig) -> Callable[..., None]: + """Wrap the provided UDF with the logic defined by the FeatureProcessorConfig. + + General functionality of the wrapper function includes but is not limited to loading data + sources and ingesting output data to a Feature Group. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Returns: + Callable[..., None]: the user provided function wrapped with feature_processor logic. + """ + + @functools.wraps(udf) + def wrapper() -> None: + udf_args, udf_kwargs = self._prepare_udf_args( + udf=udf, + fp_config=fp_config, + ) + + output = udf(*udf_args, **udf_kwargs) + + self.udf_output_receiver.ingest_udf_output(output, fp_config) + + return wrapper + + def _prepare_udf_args( + self, + udf: Callable[..., T], + fp_config: FeatureProcessorConfig, + ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + """Generate the arguments for the user defined function, provided by the wrapper function. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Returns: + Tuple[Tuple[Any, ...], Dict[str, Any]]: A tuple positional arguments and keyword + arguments for the UDF. + """ + args = () + kwargs = { + **self.udf_arg_provider.provide_input_args(udf, fp_config), + **self.udf_arg_provider.provide_params_arg(udf, fp_config), + **self.udf_arg_provider.provide_additional_kwargs(udf), + } + + return (args, kwargs) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_validation.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_validation.py new file mode 100644 index 0000000000..307838be0c --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_validation.py @@ -0,0 +1,210 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Module that contains validators and a validation chain""" +from __future__ import absolute_import + +import inspect +import re +from abc import ABC, abstractmethod +from typing import Any, Callable, List + +import attr + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + FeatureGroupDataSource, + BaseDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._input_offset_parser import ( + InputOffsetParser, +) + + +@attr.s +class Validator(ABC): + """Base class for all validators. Errors are raised if validation fails.""" + + @abstractmethod + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validates FeatureProcessorConfig and a UDF.""" + + +@attr.s +class ValidatorChain: + """Executes a series of validators.""" + + validators: List[Validator] = attr.ib() + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validates a value using the list of validators. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises: + ValueError: If there are any validation errors raised by the validators in this chain. + """ + for validator in self.validators: + validator.validate(udf, fp_config) + + +class FeatureProcessorArgValidator(Validator): + """A validator for arguments provided to FeatureProcessor.""" + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Temporary validator for unsupported feature_processor parameters. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + """ + # TODO: Validate target_stores values. + + +class InputValidator(Validator): + """A validator for the 'input' parameter.""" + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validate the arguments provided to the decorator's input parameter. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises: + ValueError: If no inputs are provided. + """ + + inputs = fp_config.inputs + if inputs is None or len(inputs) == 0: + raise ValueError("At least one input is required.") + + +class SparkUDFSignatureValidator(Validator): + """A validator for PySpark UDF signatures.""" + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validate the signature of the UDF based on the configurations provided to the decorator. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises (ValueError): raises ValueError when any of the following scenario happen: + 1. No input provided to feature_processor. + 2. Number of provided parameters does not match with that of provided inputs. + 3. Required parameters are not provided in the right order. 
+        """
+        parameters = list(inspect.signature(udf).parameters.keys())
+        input_parameters = self._get_input_params(udf)
+        if len(input_parameters) < 1:
+            raise ValueError("feature_processor expects at least 1 input parameter.")
+
+        # Validate count of input parameters against requested inputs.
+        num_data_sources = len(fp_config.inputs)
+        if len(input_parameters) != num_data_sources:
+            raise ValueError(
+                f"feature_processor expected a function with ({num_data_sources}) parameter(s)"
+                f" before any optional 'params' or 'spark' parameters for the ({num_data_sources})"
+                f" requested data source(s)."
+            )
+
+        # Validate position of non-input parameters.
+        if "params" in parameters and parameters[-1] != "params" and parameters[-2] != "params":
+            raise ValueError(
+                "feature_processor expected the 'params' parameter to be the last or second last"
+                " parameter after input parameters."
+            )
+
+        if "spark" in parameters and parameters[-1] != "spark" and parameters[-2] != "spark":
+            raise ValueError(
+                "feature_processor expected the 'spark' parameter to be the last or second last"
+                " parameter after input parameters."
+            )
+
+    def _get_input_params(self, udf: Callable[..., Any]) -> List[str]:
+        """Get the parameters that correspond to the inputs for a UDF.
+
+        Args:
+            udf (Callable[..., Any]): the user provided function.
+        """
+        parameters = list(inspect.signature(udf).parameters.keys())
+
+        # Remove non-input parameter names.
+        if "params" in parameters:
+            parameters.remove("params")
+        if "spark" in parameters:
+            parameters.remove("spark")
+
+        return parameters
+
+
+class InputOffsetValidator(Validator):
+    """A Validator for input offsets."""
+
+    def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None:
+        """Validate the start and end input offsets provided to the decorator.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+
+        Raises (ValueError): raises ValueError when input_start_offset is later than
+            input_end_offset.
+        """
+
+        for config_input in fp_config.inputs:
+            if isinstance(config_input, FeatureGroupDataSource):
+                input_start_offset = config_input.input_start_offset
+                input_end_offset = config_input.input_end_offset
+                start_td = InputOffsetParser.parse_offset_to_timedelta(input_start_offset)
+                end_td = InputOffsetParser.parse_offset_to_timedelta(input_end_offset)
+                if start_td and end_td and start_td > end_td:
+                    raise ValueError("input_start_offset must always be before input_end_offset.")
+
+
+class BaseDataSourceValidator(Validator):
+    """A Validator for BaseDataSource."""
+
+    def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None:
+        """Validate the BaseDataSource provided to the decorator.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+
+        Raises (ValueError): raises ValueError when data_source_unique_id or data_source_name
+            of the input data source is not valid.
+        """
+
+        for config_input in fp_config.inputs:
+            if isinstance(config_input, BaseDataSource):
+                source_name = config_input.data_source_name
+                source_id = config_input.data_source_unique_id
+
+                source_name_pattern = r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,119}$"
+                source_id_pattern = r"^.{1,2048}$"
+
+                if not re.match(source_name_pattern, source_name):
+                    raise ValueError(
+                        f"data_source_name of input does not match pattern '{source_name_pattern}'."
+                    )
+
+                if not re.match(source_id_pattern, source_id):
+                    raise ValueError(
+                        f"data_source_unique_id of input does not match "
+                        f"pattern '{source_id_pattern}'."
+                    )
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_processor.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_processor.py
new file mode 100644
index 0000000000..31593a3f1c
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_processor.py
@@ -0,0 +1,129 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Feature Processor decorator for feature transformation functions."""
+from __future__ import absolute_import
+
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+
+from sagemaker.mlops.feature_store.feature_processor import (
+    CSVDataSource,
+    FeatureGroupDataSource,
+    ParquetDataSource,
+    BaseDataSource,
+)
+from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode
+from sagemaker.mlops.feature_store.feature_processor._factory import (
+    UDFWrapperFactory,
+    ValidatorFactory,
+)
+from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import (
+    FeatureProcessorConfig,
+)
+
+
+def feature_processor(
+    inputs: Sequence[
+        Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource]
+    ],
+    output: str,
+    target_stores: Optional[List[str]] = None,
+    parameters: Optional[Dict[str, Union[str, Dict]]] = None,
+    enable_ingestion: bool = True,
+    spark_config: Dict[str, str] = None,
+) -> Callable:
+    """Decorator to facilitate feature engineering for Feature Groups.
+
+    If the decorated function is executed without arguments, then its arguments are automatically
+    loaded from the input data sources and the output is ingested to the output Feature Group. If
+    arguments are provided to the function, they are not automatically loaded (for testing).
+
+    Decorated functions must conform to the expected signature. Parameters: one parameter of type
+    pyspark.sql.DataFrame for each DataSource in 'inputs'; followed by the optional parameters with
+    names and types in [params: Dict[str, Any], spark: SparkSession]. Outputs: a single return
+    value of type pyspark.sql.DataFrame. The function can have any name.
+
+    **Example:**
+
+    .. code-block:: python
+
+        @feature_processor(
+            inputs=[FeatureGroupDataSource("input-fg"), CSVDataSource("s3://bucket/prefix")],
+            output='arn:aws:sagemaker:us-west-2:123456789012:feature-group/output-fg'
+        )
+        def transform(
+            input_feature_group: DataFrame, input_csv: DataFrame, params: Dict[str, Any],
+            spark: SparkSession
+        ) -> DataFrame:
+            return ...
+
+    **More concisely:**
+
+    .. code-block:: python
+
+        @feature_processor(
+            inputs=[FeatureGroupDataSource("input-fg"), CSVDataSource("s3://bucket/prefix")],
+            output='arn:aws:sagemaker:us-west-2:123456789012:feature-group/output-fg'
+        )
+        def transform(input_feature_group, input_csv):
+            return ...
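+
+    **Running the decorated function** (an illustrative sketch of the behavior described above:
+    calling the wrapped function with no arguments loads the configured inputs, runs the
+    transformation, and ingests the returned DataFrame unless ``enable_ingestion=False``):
+
+    .. code-block:: python
+
+        transform()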
+
+    Args:
+        inputs (Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource,\
+            BaseDataSource]]): A list of data sources.
+        output (str): A Feature Group ARN to write results of this function to.
+        target_stores (Optional[list[str]], optional): A list containing at least one of
+            'OnlineStore' or 'OfflineStore'. If unspecified, data will be ingested to the enabled
+            stores of the output feature group. Defaults to None.
+        parameters (Optional[Dict[str, Union[str, Dict]]], optional): Parameters to be provided to
+            the decorated function, available as the 'params' argument. Useful for parameterized
+            functions. The params argument also contains the set of system provided parameters
+            under the key 'system'. For example, 'scheduled_time': a timestamp representing the
+            time that the execution was scheduled to run, if triggered by a Scheduler; otherwise,
+            the current time.
+        enable_ingestion (bool, optional): A boolean indicating whether the decorated function's
+            return value is ingested to the 'output' Feature Group. This flag is useful during the
+            development phase to ensure that data is not used until the function is ready. It is
+            also useful for users that want to manage their own data ingestion. Defaults to True.
+        spark_config (Dict[str, str]): A dict containing the key-value pairs for Spark
+            configurations.
+
+    Raises:
+        IngestionError: If any rows are not ingested successfully then a sample of the records,
+            with failure reasons, is logged.
+
+    Returns:
+        Callable: The decorated function.
+    """
+
+    def decorator(udf: Callable[..., Any]) -> Callable:
+        fp_config = FeatureProcessorConfig.create(
+            inputs=inputs,
+            output=output,
+            mode=FeatureProcessorMode.PYSPARK,
+            target_stores=target_stores,
+            parameters=parameters,
+            enable_ingestion=enable_ingestion,
+            spark_config=spark_config,
+        )
+
+        validator_chain = ValidatorFactory.get_validation_chain(fp_config)
+        udf_wrapper = UDFWrapperFactory.get_udf_wrapper(fp_config)
+
+        validator_chain.validate(udf=udf, fp_config=fp_config)
+        wrapped_function = udf_wrapper.wrap(udf=udf, fp_config=fp_config)
+
+        wrapped_function.feature_processor_config = fp_config
+
+        return wrapped_function
+
+    return decorator
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py
new file mode 100644
index 0000000000..c9039d982c
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py
@@ -0,0 +1,1100 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Feature Processor schedule APIs.""" +from __future__ import absolute_import +import logging +import json +import re +from datetime import datetime +from typing import Callable, List, Optional, Dict, Sequence, Union, Any, Tuple + +import pytz +from botocore.exceptions import ClientError + +from sagemaker.mlops.feature_store.feature_processor._config_uploader import ConfigUploader +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import ( + EventBridgeRuleHelper, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_pipeline_events import ( + FeatureProcessorPipelineEvents, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_group_lineage_context_name, + _get_feature_group_pipeline_lineage_context_name, + _get_feature_group_pipeline_version_lineage_context_name, + _get_feature_processor_pipeline_lineage_context_name, + _get_feature_processor_pipeline_version_lineage_context_name, +) +from sagemaker.core.lineage import context +from sagemaker.core.lineage._utils import get_resource_name_from_arn +from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentManager, +) + +from sagemaker.core.remote_function.spark_config import SparkConfig +from sagemaker.core.network import SUBNETS_KEY, SECURITY_GROUP_IDS_KEY + +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER, + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, + RESOURCE_NOT_FOUND_EXCEPTION, + SPARK_JAR_FILES_PATH, + SPARK_PY_FILES_PATH, + SPARK_FILES_PATH, + FEATURE_PROCESSOR_TAG_KEY, + FEATURE_PROCESSOR_TAG_VALUE, + PIPELINE_CONTEXT_TYPE, + DEFAULT_SCHEDULE_STATE, + SCHEDULED_TIME_PIPELINE_PARAMETER, + PIPELINE_CONTEXT_NAME_TAG_KEY, + PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY, + PIPELINE_NAME_MAXIMUM_LENGTH, + RESOURCE_NOT_FOUND, + FEATURE_GROUP_ARN_REGEX_PATTERN, + TO_PIPELINE_RESERVED_TAG_KEYS, + DEFAULT_TRIGGER_STATE, + EVENTBRIDGE_RULE_ARN_REGEX_PATTERN, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + +from sagemaker.core.s3 import s3_path_join + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import ( + EventBridgeSchedulerHelper, +) +from sagemaker.mlops.workflow.pipeline import Pipeline +from sagemaker.mlops.workflow.retry import ( + StepRetryPolicy, + StepExceptionTypeEnum, + SageMakerJobStepRetryPolicy, + SageMakerJobExceptionTypeEnum, +) + +from sagemaker.mlops.workflow.steps import TrainingStep + +from sagemaker.train.model_trainer import ModelTrainer +from sagemaker.train.configs import Compute, Networking, StoppingCondition, SourceCode, Tag +from sagemaker.core.shapes import OutputDataConfig +from sagemaker.core.workflow.pipeline_context import PipelineSession + +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage import ( + FeatureProcessorLineageHandler, + TransformationCode, +) + +from sagemaker.core.remote_function.job import ( + _JobSettings, + JOBS_CONTAINER_ENTRYPOINT, + SPARK_APP_SCRIPT_PATH, +) + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, +) + +logger = 
logging.getLogger("sagemaker")
+
+
+def to_pipeline(
+    pipeline_name: str,
+    step: Callable,
+    role: Optional[str] = None,
+    transformation_code: Optional[TransformationCode] = None,
+    max_retries: Optional[int] = None,
+    tags: Optional[List[Tuple[str, str]]] = None,
+    sagemaker_session: Optional[Session] = None,
+) -> str:
+    """Creates a SageMaker Pipeline that takes in a callable as a training step.
+
+    To configure the training step used in the SageMaker Pipeline, the input argument `step` needs
+    to be wrapped by the remote decorator in module sagemaker.remote_function. If it is not wrapped
+    by the remote decorator, the default configurations in sagemaker.remote_function.job._JobSettings
+    will be used to create the training step.
+
+    Args:
+        pipeline_name (str): The name of the pipeline.
+        step (Callable): A user provided function wrapped by feature_processor and optionally
+            wrapped by remote_decorator.
+        role (Optional[str]): The Amazon Resource Name (ARN) of the role used by the pipeline to
+            access and create resources. If not specified, it will default to the credentials
+            provided by the AWS configuration chain.
+        transformation_code (Optional[TransformationCode]): A reference to the transformation code
+            for Lineage tracking. This code is not used for the actual transformation.
+        max_retries (Optional[int]): The number of times to retry the SageMaker Pipeline step.
+            If not specified, the pipeline step will not be retried.
+        tags (Optional[List[Tuple[str, str]]]): A list of tags attached to the pipeline and all
+            corresponding lineage resources that support tags. If not specified, no custom tags
+            will be attached.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    Returns:
+        str: SageMaker Pipeline ARN.
+ """ + + _validate_input_for_to_pipeline_api(pipeline_name, step) + if tags: + _validate_tags_for_to_pipeline_api(tags) + + _sagemaker_session = sagemaker_session or Session() + + _validate_lineage_resources_for_to_pipeline_api( + step.feature_processor_config, _sagemaker_session + ) + + remote_decorator_config = _get_remote_decorator_config_from_input( + wrapped_func=step, sagemaker_session=_sagemaker_session + ) + _role = role or get_execution_role(_sagemaker_session) + + runtime_env_manager = RuntimeEnvironmentManager() + client_python_version = runtime_env_manager._current_python_version() + config_uploader = ConfigUploader(remote_decorator_config, runtime_env_manager) + + s3_base_uri = s3_path_join(remote_decorator_config.s3_root_uri, pipeline_name) + + ( + input_data_config, + spark_dependency_paths, + ) = config_uploader.prepare_step_input_channel_for_spark_mode( + func=getattr(step, "wrapped_func", step), + s3_base_uri=s3_base_uri, + sagemaker_session=_sagemaker_session, + ) + + pipeline_session = PipelineSession( + boto_session=_sagemaker_session.boto_session, + default_bucket=_sagemaker_session.default_bucket(), + default_bucket_prefix=_sagemaker_session.default_bucket_prefix, + ) + logger.info("Created PipelineSession for pipeline %s", pipeline_name) + + model_trainer = _prepare_model_trainer_from_remote_decorator_config( + remote_decorator_config=remote_decorator_config, + s3_base_uri=s3_base_uri, + client_python_version=client_python_version, + spark_dependency_paths=spark_dependency_paths, + pipeline_session=pipeline_session, + role=_role, + ) + + step_args = model_trainer.train(input_data_config=input_data_config) + logger.info("Obtained step_args from ModelTrainer.train() for pipeline %s", pipeline_name) + + step_name = "-".join([pipeline_name, "feature-processor"]) + training_step_request_dict = dict( + name=step_name, + step_args=step_args, + ) + logger.info("Created TrainingStep '%s' with step_args", step_name) + + if max_retries: + training_step_request_dict["retry_policies"] = [ + StepRetryPolicy( + exception_types=[ + StepExceptionTypeEnum.SERVICE_FAULT, + StepExceptionTypeEnum.THROTTLING, + ], + max_attempts=max_retries, + ), + SageMakerJobStepRetryPolicy( + exception_types=[ + SageMakerJobExceptionTypeEnum.INTERNAL_ERROR, + SageMakerJobExceptionTypeEnum.CAPACITY_ERROR, + SageMakerJobExceptionTypeEnum.RESOURCE_LIMIT, + ], + max_attempts=max_retries, + ), + ] + + pipeline_request_dict = dict( + name=pipeline_name, + steps=[TrainingStep(**training_step_request_dict)], + sagemaker_session=_sagemaker_session, + parameters=[SCHEDULED_TIME_PIPELINE_PARAMETER], + ) + pipeline_tags = [dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE)] + if tags: + pipeline_tags.extend([dict(Key=k, Value=v) for k, v in tags]) + + pipeline = Pipeline(**pipeline_request_dict) + logger.info("Creating/Updating sagemaker pipeline %s", pipeline_name) + pipeline.upsert( + role_arn=_role, + tags=pipeline_tags, + ) + logger.info("Created sagemaker pipeline %s", pipeline_name) + + describe_pipeline_response = pipeline.describe() + pipeline_arn = describe_pipeline_response["PipelineArn"] + tags_propagate_to_lineage_resources = _get_tags_from_pipeline_to_propagate_to_lineage_resources( + pipeline_arn, _sagemaker_session + ) + + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=pipeline_name, + pipeline_arn=pipeline_arn, + pipeline=describe_pipeline_response, + inputs=_get_feature_processor_inputs(wrapped_func=step), + 
output=_get_feature_processor_outputs(wrapped_func=step), + transformation_code=transformation_code, + sagemaker_session=_sagemaker_session, + ) + lineage_handler.create_lineage(tags_propagate_to_lineage_resources) + lineage_handler.upsert_tags_for_lineage_resources(tags_propagate_to_lineage_resources) + + pipeline_lineage_names: Dict[str, str] = lineage_handler.get_pipeline_lineage_names() + + if pipeline_lineage_names is None: + raise RuntimeError("Failed to retrieve pipeline lineage. Pipeline Lineage does not exist") + + pipeline.upsert( + role_arn=_role, + tags=[ + { + "Key": PIPELINE_CONTEXT_NAME_TAG_KEY, + "Value": pipeline_lineage_names["pipeline_context_name"], + }, + { + "Key": PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY, + "Value": pipeline_lineage_names["pipeline_version_context_name"], + }, + ], + ) + return pipeline_arn + + +def schedule( + pipeline_name: str, + schedule_expression: str, + role_arn: Optional[str] = None, + state: Optional[str] = DEFAULT_SCHEDULE_STATE, + start_date: Optional[datetime] = None, + sagemaker_session: Optional[Session] = None, +) -> str: + """Creates an EventBridge Schedule that schedules executions of a sagemaker pipeline. + + The pipeline created will also have a pipeline parameter `scheduled-time` indicating when the + pipeline is scheduled to run. + + Args: + pipeline_name (str): The SageMaker Pipeline name that will be scheduled. + schedule_expression (str): The expression that defines when the schedule runs. It supports + at expression, rate expression and cron expression. See the + `CreateSchedule API + `_ + for more details. + state (str): Specifies whether the schedule is enabled or disabled. Valid values are + ENABLED and DISABLED. See the `State request parameter + `_ + for more details. If not specified, it will default to ENABLED. + start_date (Optional[datetime]): The date, in UTC, after which the schedule can begin + invoking its target. Depending on the schedule’s recurrence expression, invocations + might occur on, or after, the StartDate you specify. + role_arn (Optional[str]): The Amazon Resource Name (ARN) of the IAM role that EventBridge + Scheduler will assume for this target when the schedule is invoked. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + str: The EventBridge Schedule ARN. 
+ """ + + _sagemaker_session = sagemaker_session or Session() + _validate_pipeline_lineage_resources(pipeline_name, _sagemaker_session) + _start_date = start_date or datetime.now(tz=pytz.utc) + _role_arn = role_arn or get_execution_role(_sagemaker_session) + event_bridge_scheduler_helper = EventBridgeSchedulerHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("scheduler"), + ) + describe_pipeline_response = _sagemaker_session.sagemaker_client.describe_pipeline( + PipelineName=pipeline_name + ) + pipeline_arn = describe_pipeline_response["PipelineArn"] + tags_propagate_to_lineage_resources = _get_tags_from_pipeline_to_propagate_to_lineage_resources( + pipeline_arn, _sagemaker_session + ) + + logger.info("Creating/Updating EventBridge Schedule for pipeline %s.", pipeline_name) + event_bridge_schedule_arn = event_bridge_scheduler_helper.upsert_schedule( + schedule_name=pipeline_name, + pipeline_arn=pipeline_arn, + schedule_expression=schedule_expression, + state=state, + start_date=_start_date, + role=_role_arn, + ) + logger.info("Created/Updated EventBridge Schedule for pipeline %s.", pipeline_name) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=pipeline_name, + pipeline_arn=describe_pipeline_response["PipelineArn"], + pipeline=describe_pipeline_response, + sagemaker_session=_sagemaker_session, + ) + lineage_handler.create_schedule_lineage( + pipeline_name=pipeline_name, + schedule_arn=event_bridge_schedule_arn["ScheduleArn"], + schedule_expression=schedule_expression, + state=state, + start_date=_start_date, + tags=tags_propagate_to_lineage_resources, + ) + return event_bridge_schedule_arn["ScheduleArn"] + + +def put_trigger( + source_pipeline_events: List[FeatureProcessorPipelineEvents], + target_pipeline: str, + target_pipeline_parameters: Optional[Dict[str, str]] = None, + state: Optional[str] = DEFAULT_TRIGGER_STATE, + event_pattern: Optional[str] = None, + role_arn: Optional[str] = None, + sagemaker_session: Optional[Session] = None, +) -> str: + """Creates an event based trigger that triggers executions of a sagemaker pipeline. + + Args: + source_pipeline_events (List[FeatureProcessorPipelineEvents]): The list of + FeatureProcessorPipelineEvents that will trigger the target_pipeline. + target_pipeline (str): The name of the SageMaker Pipeline that will be triggered. + target_pipeline_parameters (Optional[Dict[str, str]]): The list of parameters to start + execution of a pipeline. + state (Optional[str]): Indicates whether the rule is enabled or disabled. + If not specified, it will default to ENABLED. + event_pattern (Optional[str]): The EventBridge EventPattern that triggers the + target_pipeline. If specified, will override source_pipeline_events. For more + information, see + https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-event-patterns.html + in the Amazon EventBridge User Guide. + role_arn (Optional[str]): The Amazon Resource Name (ARN) of the IAM role that EventBridge + Scheduler will assume for this target when the schedule is invoked. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + str: The EventBridge Rule ARN. 
+ """ + _sagemaker_session = sagemaker_session or Session() + _role_arn = role_arn or get_execution_role(_sagemaker_session) + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + logger.info("Creating/Updating EventBridge Rule for pipeline %s.", target_pipeline) + rule_arn = event_bridge_rule_helper.put_rule( + source_pipeline_events=source_pipeline_events, + target_pipeline=target_pipeline, + event_pattern=event_pattern, + state=state, + ) + rule_name = _parse_name_from_arn(rule_arn, EVENTBRIDGE_RULE_ARN_REGEX_PATTERN) + logger.info("Created/Updated EventBridge Rule for pipeline %s.", target_pipeline) + + logger.info("Attaching pipeline %s to EventBridge Rule %s as target", target_pipeline, rule_arn) + event_bridge_rule_helper.put_target( + rule_name=rule_name, + target_pipeline=target_pipeline, + target_pipeline_parameters=target_pipeline_parameters, + role_arn=_role_arn, + ) + logger.info("Attached pipeline %s to EventBridge Rule %s as target", target_pipeline, rule_arn) + + describe_pipeline_response = _sagemaker_session.sagemaker_client.describe_pipeline( + PipelineName=target_pipeline + ) + describe_rule_response = event_bridge_rule_helper.describe_rule(rule_name=rule_name) + pipeline_arn = describe_pipeline_response["PipelineArn"] + tags_propagate_to_lineage_resources = _get_tags_from_pipeline_to_propagate_to_lineage_resources( + pipeline_arn, _sagemaker_session + ) + + event_bridge_rule_helper.add_tags(rule_arn=rule_arn, tags=tags_propagate_to_lineage_resources) + + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=target_pipeline, + pipeline_arn=describe_pipeline_response["PipelineArn"], + pipeline=describe_pipeline_response, + sagemaker_session=_sagemaker_session, + ) + lineage_handler.create_trigger_lineage( + pipeline_name=target_pipeline, + trigger_arn=rule_arn, + state=state, + tags=tags_propagate_to_lineage_resources, + event_pattern=describe_rule_response["EventPattern"], + ) + return rule_arn + + +def enable_trigger( + pipeline_name: str, + sagemaker_session: Optional[Session] = None, +) -> None: + """Enable the EventBridge Rule that is associated with the pipeline. + + Args: + pipeline_name (str): The SageMaker Pipeline name that will be executed. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + _sagemaker_session = sagemaker_session or Session() + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + event_bridge_rule_helper.enable_rule(rule_name=pipeline_name) + logger.info("Enabled EventBridge Rule for pipeline %s.", pipeline_name) + + +def disable_trigger(pipeline_name: str, sagemaker_session: Optional[Session] = None) -> None: + """Disable the EventBridge Rule that is associated with the pipeline. + + Args: + pipeline_name (str): The SageMaker Pipeline name that will be executed. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
+    """
+    _sagemaker_session = sagemaker_session or Session()
+    event_bridge_rule_helper = EventBridgeRuleHelper(
+        _sagemaker_session,
+        _sagemaker_session.boto_session.client("events"),
+    )
+    event_bridge_rule_helper.disable_rule(rule_name=pipeline_name)
+    logger.info("Disabled EventBridge Rule for pipeline %s.", pipeline_name)
+
+
+def execute(
+    pipeline_name: str,
+    execution_time: Optional[datetime] = None,
+    sagemaker_session: Optional[Session] = None,
+) -> str:
+    """Starts an execution of a SageMaker Pipeline created by feature_processor.
+
+    Args:
+        pipeline_name (str): The SageMaker Pipeline name that will be executed.
+        execution_time (Optional[datetime]): The date, in UTC, that will be used as a SageMaker
+            Pipeline parameter indicating the time at which the execution is scheduled to run.
+            If not specified, it will default to the current timestamp.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    Returns:
+        str: The pipeline execution ARN.
+    """
+    _sagemaker_session = sagemaker_session or Session()
+    _validate_pipeline_lineage_resources(pipeline_name, _sagemaker_session)
+    _execution_time = execution_time or datetime.now()
+    start_pipeline_execution_request = dict(
+        PipelineName=pipeline_name,
+        PipelineParameters=[
+            dict(
+                Name=EXECUTION_TIME_PIPELINE_PARAMETER,
+                Value=_execution_time.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT),
+            )
+        ],
+    )
+    logger.info("Starting an execution for pipeline %s", pipeline_name)
+    execution_response = _sagemaker_session.sagemaker_client.start_pipeline_execution(
+        **start_pipeline_execution_request
+    )
+    execution_arn = execution_response["PipelineExecutionArn"]
+    logger.info(
+        "Execution %s for pipeline %s was started successfully.",
+        execution_arn,
+        pipeline_name,
+    )
+    return execution_arn
+
+
+def delete_schedule(pipeline_name: str, sagemaker_session: Optional[Session] = None) -> None:
+    """Delete the EventBridge Schedule corresponding to a SageMaker Pipeline, if there is one.
+
+    Args:
+        pipeline_name (str): The name of the SageMaker Pipeline whose EventBridge Schedule will
+            be deleted.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    """
+    _sagemaker_session = sagemaker_session or Session()
+    event_bridge_scheduler_helper = EventBridgeSchedulerHelper(
+        _sagemaker_session, _sagemaker_session.boto_session.client("scheduler")
+    )
+    try:
+        event_bridge_scheduler_helper.delete_schedule(pipeline_name)
+        logger.info("Deleted EventBridge Schedule for pipeline %s.", pipeline_name)
+    except ClientError as e:
+        if RESOURCE_NOT_FOUND_EXCEPTION != e.response["Error"]["Code"]:
+            raise e
+
+
+def delete_trigger(pipeline_name: str, sagemaker_session: Optional[Session] = None) -> None:
+    """Delete the EventBridge Rule corresponding to a SageMaker Pipeline, if there is one.
+
+    Args:
+        pipeline_name (str): The name of the SageMaker Pipeline whose EventBridge Rule will
+            be deleted.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+ """ + _sagemaker_session = sagemaker_session or Session() + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + try: + target_ids = [] + for page in event_bridge_rule_helper.list_targets_by_rule(pipeline_name): + target_ids.extend([target["Id"] for target in page["Targets"]]) + event_bridge_rule_helper.remove_targets(rule_name=pipeline_name, ids=target_ids) + event_bridge_rule_helper.delete_rule(pipeline_name) + logger.info("Deleted EventBridge Rule for pipeline %s.", pipeline_name) + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION != e.response["Error"]["Code"]: + raise e + + +def describe( + pipeline_name: str, sagemaker_session: Optional[Session] = None +) -> Dict[str, Union[int, str]]: + """Describe feature processor and other related resources. + + This API will include details related to the feature processor including SageMaker Pipeline and + EventBridge Schedule. + + Args: + pipeline_name (str): Name of the pipeline. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + Dict[str, Union[int, str]]: Return information for resources related to feature processor. + """ + + _sagemaker_session = sagemaker_session or Session() + describe_response_dict = {} + + try: + describe_pipeline_response = _sagemaker_session.sagemaker_client.describe_pipeline( + PipelineName=pipeline_name + ) + pipeline_definition = json.loads(describe_pipeline_response["PipelineDefinition"]) + pipeline_step = pipeline_definition["Steps"][0] + describe_response_dict = dict( + pipeline_arn=describe_pipeline_response["PipelineArn"], + pipeline_execution_role_arn=describe_pipeline_response["RoleArn"], + ) + + if "RetryPolicies" in pipeline_step: + describe_response_dict["max_retries"] = pipeline_step["RetryPolicies"][0]["MaxAttempts"] + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + logger.info("Pipeline %s does not exist.", pipeline_name) + + event_bridge_scheduler_helper = EventBridgeSchedulerHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("scheduler"), + ) + + event_bridge_schedule = event_bridge_scheduler_helper.describe_schedule(pipeline_name) + if event_bridge_schedule: + describe_response_dict.update( + dict( + schedule_arn=event_bridge_schedule["Arn"], + schedule_expression=event_bridge_schedule["ScheduleExpression"], + schedule_state=event_bridge_schedule["State"], + schedule_start_date=event_bridge_schedule["StartDate"].strftime( + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT + ), + schedule_role=event_bridge_schedule["Target"]["RoleArn"], + ) + ) + + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + event_based_trigger = event_bridge_rule_helper.describe_rule(pipeline_name) + if event_based_trigger: + describe_response_dict.update( + dict( + trigger=event_based_trigger["Arn"], + event_pattern=event_based_trigger["EventPattern"], + trigger_state=event_based_trigger["State"], + ) + ) + + return describe_response_dict + + +def list_pipelines(sagemaker_session: Optional[Session] = None) -> List[Dict[str, Any]]: + """Lists all SageMaker Pipelines created by Feature Processor SDK. 
+ + Args: + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + List[Dict[str, Any]]: Return list of SageMaker Pipeline metadata created for + feature_processor. + """ + + _sagemaker_session = sagemaker_session or Session() + next_token = None + list_response = [] + pipeline_names_so_far = set([]) + while True: + list_contexts_request = dict(ContextType=PIPELINE_CONTEXT_TYPE) + if next_token: + list_contexts_request["NextToken"] = next_token + list_contexts_response = _sagemaker_session.sagemaker_client.list_contexts( + **list_contexts_request + ) + for _context in list_contexts_response["ContextSummaries"]: + pipeline_name = get_resource_name_from_arn(_context["Source"]["SourceUri"]) + if pipeline_name not in pipeline_names_so_far: + list_response.append(dict(pipeline_name=pipeline_name)) + pipeline_names_so_far.add(pipeline_name) + next_token = list_contexts_response.get("NextToken") + if not next_token: + break + + return list_response + + +def _validate_input_for_to_pipeline_api(pipeline_name: str, step: Callable) -> None: + """Validate input to to_pipeline API. + + The provided callable is considered valid if it's wrapped by feature_processor decorator + and uses pyspark mode. + + Args: + pipeline_name (str): The name of the pipeline. + step (Callable): A user provided function wrapped by feature_processor and optionally + wrapped by remote_decorator. + + Raises (ValueError): raises ValueError when any of the following scenario happen: + 1. pipeline name is longer than 80 characters. + 2. function is not annotated with either feature_processor or remote decorator. + 3. provides a mode other than pyspark. + """ + if len(pipeline_name) > PIPELINE_NAME_MAXIMUM_LENGTH: + raise ValueError( + "Pipeline name used by feature processor should be less than 80 " + "characters. Please choose another pipeline name." + ) + + if not hasattr(step, "feature_processor_config") or not step.feature_processor_config: + raise ValueError( + "Please wrap step parameter with feature_processor decorator" + " in order to use to_pipeline API." + ) + + if not hasattr(step, "job_settings") or not step.job_settings: + raise ValueError( + "Please wrap step parameter with remote decorator in order to use to_pipeline API." + ) + + if FeatureProcessorMode.PYSPARK != step.feature_processor_config.mode: + raise ValueError( + f"Mode {step.feature_processor_config.mode} is not supported by to_pipeline API." + ) + + +def _validate_tags_for_to_pipeline_api(tags: List[Tuple[str, str]]) -> None: + """Validate tags provided to to_pipeline API. + + Args: + tags (List[Tuple[str, str]]): A list of tags attached to the pipeline. + + Raises (ValueError): raises ValueError when any of the following scenario happen: + 1. reserved tag keys are provided to API. + """ + provided_tag_keys = [tag_key_value_pair[0] for tag_key_value_pair in tags] + for reserved_tag_key in TO_PIPELINE_RESERVED_TAG_KEYS: + if reserved_tag_key in provided_tag_keys: + raise ValueError( + f"{reserved_tag_key} is a reserved tag key for to_pipeline API. Please choose another tag." + ) + + +def _validate_lineage_resources_for_to_pipeline_api( + feature_processor_config: FeatureProcessorConfig, sagemaker_session: Session +) -> None: + """Validate existence of feature group lineage resources for to_pipeline API. 
+ + Args: + feature_processor_config (FeatureProcessorConfig): The configuration values for the + feature_processor decorator. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. + """ + inputs = feature_processor_config.inputs + output = feature_processor_config.output + for ds in inputs: + if isinstance(ds, FeatureGroupDataSource): + fg_name = _parse_name_from_arn(ds.name) + _validate_fg_lineage_resources(fg_name, sagemaker_session) + output_fg_name = _parse_name_from_arn(output) + _validate_fg_lineage_resources(output_fg_name, sagemaker_session) + + +def _validate_fg_lineage_resources(feature_group_name: str, sagemaker_session: Session) -> None: + """Validate existence of feature group lineage resources. + + Args: + feature_group_name (str): The name or arn of the feature group. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. + + Raises (ValueError): raises ValueError when lineage resources are not created for feature + groups. + """ + + feature_group = sagemaker_session.describe_feature_group(feature_group_name=feature_group_name) + feature_group_creation_time = feature_group["CreationTime"].strftime("%s") + feature_group_context = _get_feature_group_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + feature_group_pipeline_context = _get_feature_group_pipeline_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + feature_group_pipeline_version_context = ( + _get_feature_group_pipeline_version_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + ) + for context_name in [ + feature_group_context, + feature_group_pipeline_context, + feature_group_pipeline_version_context, + ]: + try: + logger.info("Verifying existence of context %s.", context_name) + context.Context.load(context_name=context_name, sagemaker_session=sagemaker_session) + except ClientError as e: + if RESOURCE_NOT_FOUND == e.response["Error"]["Code"]: + raise ValueError( + f"Lineage resource {context_name} has not yet been created for feature group" + f" {feature_group_name} or has already been deleted. Please try again later." + ) + raise e + + +def _validate_pipeline_lineage_resources(pipeline_name: str, sagemaker_session: Session) -> None: + """Validate existence of pipeline lineage resources. + + Args: + pipeline_name (str): The name of the pipeline. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. 
+ """ + pipeline = sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=pipeline_name) + pipeline_creation_time = pipeline["CreationTime"].strftime("%s") + pipeline_context_name = _get_feature_processor_pipeline_lineage_context_name( + pipeline_name=pipeline_name, pipeline_creation_time=pipeline_creation_time + ) + try: + pipeline_context = context.Context.load( + context_name=pipeline_context_name, sagemaker_session=sagemaker_session + ) + last_update_time = pipeline_context.properties["LastUpdateTime"] + pipeline_version_context_name = ( + _get_feature_processor_pipeline_version_lineage_context_name( + pipeline_name=pipeline_name, pipeline_last_update_time=last_update_time + ) + ) + context.Context.load( + context_name=pipeline_version_context_name, sagemaker_session=sagemaker_session + ) + except ClientError as e: + if RESOURCE_NOT_FOUND == e.response["Error"]["Code"]: + raise ValueError( + "Pipeline lineage resources have not been created yet or have already been deleted" + ". Please try again later." + ) + raise e + + +def _prepare_model_trainer_from_remote_decorator_config( + remote_decorator_config: _JobSettings, + s3_base_uri: str, + client_python_version: str, + spark_dependency_paths: Dict[str, Optional[str]], + pipeline_session: PipelineSession, + role: str, +) -> ModelTrainer: + """Prepares a ModelTrainer instance from remote decorator configuration. + + Args: + remote_decorator_config (_JobSettings): Configurations used for setting up + SageMaker Pipeline Step. + s3_base_uri (str): S3 URI used as destination for dependencies upload. + client_python_version (str): Python version used on client side. + spark_dependency_paths (Dict[str, Optional[str]]): A dictionary contains S3 paths spark + dependency files get uploaded to if present. + pipeline_session (PipelineSession): Pipeline-aware session that causes + ModelTrainer.train() to return step arguments instead of launching a job. + role (str): The IAM role ARN for the training job. + Returns: + ModelTrainer: A configured ModelTrainer instance. 
+ """ + logger.info("Mapping remote decorator config to ModelTrainer params") + + # Build environment dict from remote_decorator_config (strings only for Pydantic validation) + environment = dict(remote_decorator_config.environment_variables or {}) + + # Build command from container entry point and arguments + entry_point_and_args = _get_container_entry_point_and_arguments( + remote_decorator_config=remote_decorator_config, + s3_base_uri=s3_base_uri, + client_python_version=client_python_version, + spark_dependency_paths=spark_dependency_paths, + ) + joined_command = " ".join( + entry_point_and_args["container_entry_point"] + + entry_point_and_args["container_arguments"] + ) + source_code = SourceCode(command=joined_command) + logger.info("SourceCode command: %s", joined_command) + + # Create Compute config + compute = Compute( + instance_type=remote_decorator_config.instance_type, + instance_count=remote_decorator_config.instance_count, + volume_size_in_gb=remote_decorator_config.volume_size, + volume_kms_key_id=remote_decorator_config.volume_kms_key, + ) + logger.info( + "Compute: instance_type=%s, instance_count=%s, volume_size=%s", + remote_decorator_config.instance_type, + remote_decorator_config.instance_count, + remote_decorator_config.volume_size, + ) + + # Create Networking config if VPC config is present + networking = None + if remote_decorator_config.vpc_config: + networking = Networking( + subnets=remote_decorator_config.vpc_config[SUBNETS_KEY], + security_group_ids=remote_decorator_config.vpc_config[SECURITY_GROUP_IDS_KEY], + enable_inter_container_traffic_encryption=( + remote_decorator_config.encrypt_inter_container_traffic + ), + ) + logger.info( + "Networking: subnets=%s, security_groups=%s, encrypt=%s", + remote_decorator_config.vpc_config[SUBNETS_KEY], + remote_decorator_config.vpc_config[SECURITY_GROUP_IDS_KEY], + remote_decorator_config.encrypt_inter_container_traffic, + ) + + # Create StoppingCondition if max_runtime_in_seconds is configured + stopping_condition = None + if remote_decorator_config.max_runtime_in_seconds: + stopping_condition = StoppingCondition( + max_runtime_in_seconds=remote_decorator_config.max_runtime_in_seconds, + ) + + # Create OutputDataConfig + output_data_config = OutputDataConfig( + s3_output_path=s3_base_uri, + kms_key_id=remote_decorator_config.s3_kms_key, + ) + + # Convert tags from List[Tuple[str, str]] to List[Tag] + tags = None + if remote_decorator_config.tags: + tags = [Tag(key=k, value=v) for k, v in remote_decorator_config.tags] + logger.info("Tags count: %d", len(tags) if tags else 0) + + logger.info("Environment keys: %s", list(environment.keys())) + + model_trainer = ModelTrainer( + training_image=remote_decorator_config.image_uri, + role=role, + sagemaker_session=pipeline_session, + compute=compute, + networking=networking, + stopping_condition=stopping_condition, + output_data_config=output_data_config, + source_code=source_code, + training_input_mode="File", + environment=environment, + tags=tags, + ) + + # Inject SCHEDULED_TIME_PIPELINE_PARAMETER after construction to bypass Pydantic + # validation (Parameter is not a string). The @runnable_by_pipeline decorator resolves + # Parameter objects to strings during pipeline definition serialization. 
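+    # At execution time that Parameter resolves to the scheduled execution timestamp (the
+    # `scheduled-time` pipeline parameter passed by schedule()/execute()), so the training job
+    # environment should carry the time the run was scheduled for.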
+ model_trainer.environment[EXECUTION_TIME_PIPELINE_PARAMETER] = SCHEDULED_TIME_PIPELINE_PARAMETER + + logger.info( + "Created ModelTrainer with image=%s, instance_type=%s, instance_count=%s", + remote_decorator_config.image_uri, + remote_decorator_config.instance_type, + remote_decorator_config.instance_count, + ) + + return model_trainer + + +def _get_container_entry_point_and_arguments( + remote_decorator_config: _JobSettings, + s3_base_uri: str, + client_python_version: str, + spark_dependency_paths: Dict[str, Optional[str]], +) -> Dict[str, List[str]]: + """Extracts the container entry point and container arguments from remote decorator configs + + Args: + remote_decorator_config (_JobSettings): Configurations used for setting up + SageMaker Pipeline Step. + s3_base_uri (str): S3 URI used as destination for dependencies upload. + client_python_version (str): Python version used on client side. + spark_dependency_paths (Dict[str, Optional[str]]): A dictionary contains S3 paths spark + dependency files get uploaded to if present. + Returns: + Dict[str, List[str]]: Request dictionary containing container entry point and + arguments setup. + """ + + spark_config = remote_decorator_config.spark_config + jobs_container_entrypoint = JOBS_CONTAINER_ENTRYPOINT.copy() + + if spark_dependency_paths[SPARK_JAR_FILES_PATH]: + jobs_container_entrypoint.extend(["--jars", spark_dependency_paths[SPARK_JAR_FILES_PATH]]) + + if spark_dependency_paths[SPARK_PY_FILES_PATH]: + jobs_container_entrypoint.extend( + ["--py-files", spark_dependency_paths[SPARK_PY_FILES_PATH]] + ) + + if spark_dependency_paths[SPARK_FILES_PATH]: + jobs_container_entrypoint.extend(["--files", spark_dependency_paths[SPARK_FILES_PATH]]) + + if spark_config and spark_config.spark_event_logs_uri: + jobs_container_entrypoint.extend( + ["--spark-event-logs-s3-uri", spark_config.spark_event_logs_uri] + ) + + if spark_config: + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + + container_args = ["--s3_base_uri", s3_base_uri] + container_args.extend(["--region", remote_decorator_config.sagemaker_session.boto_region_name]) + container_args.extend(["--client_python_version", client_python_version]) + + if remote_decorator_config.s3_kms_key: + container_args.extend(["--s3_kms_key", remote_decorator_config.s3_kms_key]) + + return dict( + container_entry_point=jobs_container_entrypoint, + container_arguments=container_args, + ) + + +def _get_remote_decorator_config_from_input( + wrapped_func: Callable, sagemaker_session: Session +) -> _JobSettings: + """Extracts the remote decorator configuration from the wrapped function and other inputs. + + Args: + wrapped_func (Callable): Wrapped user defined function. If it contains remote decorator + job settings, configs will be used to construct remote_decorator_config, otherwise + default job settings will be used. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + _JobSettings: Configurations used for creating sagemaker pipeline step. + """ + remote_decorator_config = getattr( + wrapped_func, + "job_settings", + ) + # TODO: Remove this after GA + remote_decorator_config.sagemaker_session = sagemaker_session + + # TODO: This needs to be removed when new mode is introduced. 
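+    # Until the new mode mentioned in the TODO above is introduced, a missing spark_config
+    # falls back to an empty SparkConfig() and the default SageMaker Spark image resolved
+    # via _JobSettings._get_default_spark_image(sagemaker_session), as done below.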
+    if remote_decorator_config.spark_config is None:
+        remote_decorator_config.spark_config = SparkConfig()
+        remote_decorator_config.image_uri = _JobSettings._get_default_spark_image(sagemaker_session)
+
+    return remote_decorator_config
+
+
+def _get_feature_processor_inputs(
+    wrapped_func: Callable,
+) -> Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]]:
+    """Retrieve Feature Processor Config Inputs"""
+    feature_processor_config: FeatureProcessorConfig = wrapped_func.feature_processor_config
+    return feature_processor_config.inputs
+
+
+def _get_feature_processor_outputs(
+    wrapped_func: Callable,
+) -> str:
+    """Retrieve Feature Processor Config Output"""
+    feature_processor_config: FeatureProcessorConfig = wrapped_func.feature_processor_config
+    return feature_processor_config.output
+
+
+def _parse_name_from_arn(
+    name_or_arn: str, regex_pattern: str = FEATURE_GROUP_ARN_REGEX_PATTERN
+) -> str:
+    """Parse the name from a string, if it's an ARN. Otherwise, return the string.
+
+    Args:
+        name_or_arn (str): The Feature Group Name or ARN.
+        regex_pattern (str): The regex used to extract the name from an ARN
+            (default: FEATURE_GROUP_ARN_REGEX_PATTERN).
+
+    Returns:
+        str: The Feature Group Name.
+    """
+    match = re.match(regex_pattern, name_or_arn)
+    if match:
+        name = match.group(4)
+        return name
+    return name_or_arn
+
+
+def _get_tags_from_pipeline_to_propagate_to_lineage_resources(
+    pipeline_arn: str, sagemaker_session: Session
+) -> List[Dict[str, str]]:
+    """Retrieve custom tags attached to the SageMaker Pipeline.
+
+    Args:
+        pipeline_arn (str): SageMaker Pipeline Arn.
+        sagemaker_session (Session): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+
+    Returns:
+        List[Dict[str, str]]: List of custom tags to be propagated to lineage resources.
+    """
+    tags_in_pipeline = sagemaker_session.sagemaker_client.list_tags(ResourceArn=pipeline_arn)[
+        "Tags"
+    ]
+    return [d for d in tags_in_pipeline if d["Key"] not in TO_PIPELINE_RESERVED_TAG_KEYS]
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_contexts.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_contexts.py
new file mode 100644
index 0000000000..2b4f134f0a
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_contexts.py
@@ -0,0 +1,31 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to store Feature Group Contexts"""
+from __future__ import absolute_import
+import attr
+
+
+@attr.s
+class FeatureGroupContexts:
+    """A Feature Group Context data source.
+
+    Attributes:
+        feature_group_name (str): The name of the Feature Group.
+ feature_group_pipeline_context_arn (str): The ARN of the Feature Group Pipeline Context. + feature_group_pipeline_version_context_arn (str): + The ARN of the Feature Group Versions Context + """ + + name: str = attr.ib() + pipeline_context_arn: str = attr.ib() + pipeline_version_context_arn: str = attr.ib() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py new file mode 100644 index 0000000000..55230d7c1c --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py @@ -0,0 +1,182 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle Feature Processor Lineage""" +from __future__ import absolute_import + +import re +from typing import Dict, Any +import logging + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._constants import FEATURE_GROUP_ARN_REGEX_PATTERN +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import ( + FeatureGroupContexts, +) +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + SAGEMAKER, + FEATURE_GROUP, + CREATION_TIME, +) +from sagemaker.core.lineage.context import Context + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_group_pipeline_lineage_context_name, + _get_feature_group_pipeline_version_lineage_context_name, +) + +logger = logging.getLogger(SAGEMAKER) + + +class FeatureGroupLineageEntityHandler: + """Class for handling Feature Group Lineage""" + + @staticmethod + def retrieve_feature_group_context_arns( + feature_group_name: str, sagemaker_session: Session + ) -> FeatureGroupContexts: + """Retrieve Feature Group Contexts. + + Arguments: + feature_group_name (str): The Feature Group Name. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + FeatureGroupContexts: The Feature Group Pipeline and Version Context. 
+ """ + feature_group = FeatureGroupLineageEntityHandler._describe_feature_group( + feature_group_name=FeatureGroupLineageEntityHandler.parse_name_from_arn( + feature_group_name + ), + sagemaker_session=sagemaker_session, + ) + feature_group_name = feature_group[FEATURE_GROUP] + feature_group_creation_time = feature_group[CREATION_TIME].strftime("%s") + feature_group_pipeline_context = ( + FeatureGroupLineageEntityHandler._load_feature_group_pipeline_context( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + sagemaker_session=sagemaker_session, + ) + ) + feature_group_pipeline_version_context = ( + FeatureGroupLineageEntityHandler._load_feature_group_pipeline_version_context( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + sagemaker_session=sagemaker_session, + ) + ) + return FeatureGroupContexts( + name=feature_group_name, + pipeline_context_arn=feature_group_pipeline_context.context_arn, + pipeline_version_context_arn=feature_group_pipeline_version_context.context_arn, + ) + + @staticmethod + def _describe_feature_group( + feature_group_name: str, sagemaker_session: Session + ) -> Dict[str, Any]: + """Retrieve the Feature Group. + + Arguments: + feature_group_name (str): The Feature Group Name. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Dict[str, Any]: The Feature Group details. + """ + feature_group = sagemaker_session.describe_feature_group(feature_group_name) + logger.debug( + "Called describe_feature_group with %s and received: %s", + feature_group_name, + feature_group, + ) + return feature_group + + @staticmethod + def _load_feature_group_pipeline_context( + feature_group_name: str, + feature_group_creation_time: str, + sagemaker_session: Session, + ) -> Context: + """Retrieve Feature Group Pipeline Context + + Arguments: + feature_group_name (str): The Feature Group Name. + feature_group_creation_time (str): The Feature Group Creation Time, + in long epoch seconds. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The Feature Group Pipeline Context. + """ + feature_group_pipeline_context = _get_feature_group_pipeline_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + return Context.load( + context_name=feature_group_pipeline_context, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def _load_feature_group_pipeline_version_context( + feature_group_name: str, + feature_group_creation_time: str, + sagemaker_session: Session, + ) -> Context: + """Retrieve Feature Group Pipeline Version Context + + Arguments: + feature_group_name (str): The Feature Group Name. + feature_group_creation_time (str): The Feature Group Creation Time, + in long epoch seconds. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The Feature Group Pipeline Version Context. 
+ """ + feature_group_pipeline_version_context = ( + _get_feature_group_pipeline_version_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + ) + return Context.load( + context_name=feature_group_pipeline_version_context, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def parse_name_from_arn(fg_uri: str) -> str: + """Parse the name from a string, if it's an ARN. Otherwise, return the string. + + Arguments: + fg_uri (str): The Feature Group Name or ARN. + + Returns: + str: The Feature Group Name. + """ + match = re.match(FEATURE_GROUP_ARN_REGEX_PATTERN, fg_uri) + if match: + feature_group_name = match.group(4) + return feature_group_name + return fg_uri diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py new file mode 100644 index 0000000000..d706b3b441 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py @@ -0,0 +1,759 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle Lineage Associations""" +from __future__ import absolute_import +import logging +from datetime import datetime +from typing import Optional, Iterator, List, Dict, Set, Sequence, Union +import attr +from botocore.exceptions import ClientError + +from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import ( + EventBridgeRuleHelper, +) +from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import ( + EventBridgeSchedulerHelper, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._lineage_association_handler import ( + LineageAssociationHandler, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_lineage_entity_handler import ( + FeatureGroupLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import ( + FeatureGroupContexts, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_lineage_entity_handler import ( + PipelineLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import ( + PipelineTrigger, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_version_lineage_entity_handler import ( + PipelineVersionLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._s3_lineage_entity_handler import ( + S3LineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import 
( + SAGEMAKER, + LAST_UPDATE_TIME, + PIPELINE_CONTEXT_NAME_KEY, + PIPELINE_CONTEXT_VERSION_NAME_KEY, + FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, + DATA_SET, + TRANSFORMATION_CODE, + CREATION_TIME, + RESOURCE_NOT_FOUND, + ERROR, + CODE, + LAST_MODIFIED_TIME, + TRANSFORMATION_CODE_STATUS_INACTIVE, + TRANSFORMATION_CODE_STATUS_ACTIVE, + CONTRIBUTED_TO, +) +from sagemaker.core.lineage.context import Context +from sagemaker.core.lineage.artifact import Artifact +from sagemaker.core.lineage.association import AssociationSummary +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, +) + +logger = logging.getLogger(SAGEMAKER) + + +@attr.s +class FeatureProcessorLineageHandler: + """Class to Create and Update FeatureProcessor Lineage Entities. + + Attributes: + pipeline_name (str): Pipeline Name. + pipeline_arn (str): The ARN of the Pipeline. + pipeline (str): The details of the Pipeline. + inputs (Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, + BaseDataSource]]): The inputs to the Feature processor. + output (str): The output Feature Group. + transformation_code (TransformationCode): The Transformation Code for Feature Processor. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + + pipeline_name: str = attr.ib() + pipeline_arn: str = attr.ib() + pipeline: Dict = attr.ib() + sagemaker_session: Session = attr.ib() + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ] = attr.ib(default=None) + output: str = attr.ib(default=None) + transformation_code: TransformationCode = attr.ib(default=None) + + def create_lineage(self, tags: Optional[List[Dict[str, str]]] = None) -> None: + """Create and Update Feature Processor Lineage""" + input_feature_group_contexts: List[FeatureGroupContexts] = ( + self._retrieve_input_feature_group_contexts() + ) + output_feature_group_contexts: FeatureGroupContexts = ( + self._retrieve_output_feature_group_contexts() + ) + input_raw_data_artifacts: List[Artifact] = self._retrieve_input_raw_data_artifacts() + transformation_code_artifact: Optional[Artifact] = ( + S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=self.transformation_code, + pipeline_last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + ) + if transformation_code_artifact is not None: + logger.info("Created Transformation Code Artifact: %s", transformation_code_artifact) + if tags: + transformation_code_artifact.set_tags(tags) # pylint: disable=E1101 + # Create the Pipeline Lineage for the first time + if not self._check_if_pipeline_lineage_exists(): + self._create_new_pipeline_lineage( + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + ) + else: + self._update_pipeline_lineage( + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + 
) + + def get_pipeline_lineage_names(self) -> Optional[Dict[str, str]]: + """Retrieve Pipeline Lineage Names. + + Returns: + Optional[Dict[str, str]]: Pipeline and Pipeline version lineage names. + """ + if not self._check_if_pipeline_lineage_exists(): + return None + pipeline_context: Context = self._get_pipeline_context() + current_pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + return { + PIPELINE_CONTEXT_NAME_KEY: pipeline_context.context_name, + PIPELINE_CONTEXT_VERSION_NAME_KEY: current_pipeline_version_context.context_name, + } + + def create_schedule_lineage( + self, + pipeline_name: str, + schedule_arn, + schedule_expression, + state, + start_date: datetime, + tags: Optional[List[Dict[str, str]]] = None, + ) -> None: + """Class to Create and Update FeatureProcessor Lineage Entities. + + Arguments: + pipeline_name (str): Pipeline Name. + schedule_arn (str): The ARN of the Schedule. + schedule_expression (str): The expression that defines when the schedule runs. + It supports at expression, rate expression and cron expression. + state (str):Specifies whether the schedule is enabled or disabled. Valid values are + ENABLED and DISABLED. See https://docs.aws.amazon.com/scheduler/latest/APIReference/ + API_CreateSchedule.html#scheduler-CreateSchedule-request-State for more details. + If not specified, it will default to ENABLED. + start_date (Optional[datetime]): The date, in UTC, after which the schedule can begin + invoking its target. Depending on the schedule’s recurrence expression, invocations + might occur on, or after, the StartDate you specify. + tags (Optional[List[Dict[str, str]]]): Custom tags to be attached to schedule + lineage resource. + """ + pipeline_context: Context = self._get_pipeline_context() + pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + pipeline_schedule: PipelineSchedule = PipelineSchedule( + schedule_name=pipeline_name, + schedule_arn=schedule_arn, + schedule_expression=schedule_expression, + pipeline_name=pipeline_name, + state=state, + start_date=start_date.strftime("%s"), + ) + schedule_artifact: Artifact = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=pipeline_schedule, + sagemaker_session=self.sagemaker_session, + ) + if tags: + schedule_artifact.set_tags(tags) + + LineageAssociationHandler.add_upstream_schedule_associations( + schedule_artifact=schedule_artifact, + pipeline_version_context_arn=pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + + def create_trigger_lineage( + self, + pipeline_name: str, + trigger_arn: str, + event_pattern: str, + state: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> None: + """Class to Create and Update FeatureProcessor Pipeline Trigger Lineage Entities. + + Arguments: + pipeline_name (str): Pipeline Name. + trigger_arn (str): The ARN of the EventBridge Rule. + event_pattern (str): The event pattern for the rule. + state (str): Specifies whether the trigger is enabled or disabled. Valid values are + ENABLED and DISABLED. If not specified, it will default to ENABLED. + tags (Optional[List[Dict[str, str]]]): Custom tags to be attached to trigger + lineage resource. 
+ """ + pipeline_context: Context = self._get_pipeline_context() + pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + pipeline_trigger: PipelineTrigger = PipelineTrigger( + trigger_name=pipeline_name, + trigger_arn=trigger_arn, + event_pattern=event_pattern, + pipeline_name=pipeline_name, + state=state, + ) + trigger_artifact: Artifact = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=pipeline_trigger, + sagemaker_session=self.sagemaker_session, + ) + if tags: + trigger_artifact.set_tags(tags) + + LineageAssociationHandler._add_association( + source_arn=trigger_artifact.artifact_arn, + destination_arn=pipeline_version_context.context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=self.sagemaker_session, + ) + + def upsert_tags_for_lineage_resources(self, tags: List[Dict[str, str]]) -> None: + """Add or update tags for lineage resources using tags attached to sagemaker pipeline as + + source of truth. + + Args: + tags (List[Dict[str, str]]): Custom tags to be attached to lineage resources. + """ + if not tags: + return + pipeline_context: Context = self._get_pipeline_context() + current_pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + input_raw_data_artifacts: List[Artifact] = self._retrieve_input_raw_data_artifacts() + pipeline_context.set_tags(tags) + current_pipeline_version_context.set_tags(tags) + for input_raw_data_artifact in input_raw_data_artifacts: + input_raw_data_artifact.set_tags(tags) + + event_bridge_scheduler_helper = EventBridgeSchedulerHelper( + self.sagemaker_session, + self.sagemaker_session.boto_session.client("scheduler"), + ) + event_bridge_schedule = event_bridge_scheduler_helper.describe_schedule(self.pipeline_name) + + event_bridge_rule_helper = EventBridgeRuleHelper( + self.sagemaker_session, + self.sagemaker_session.boto_session.client("events"), + ) + event_bridge_rule = event_bridge_rule_helper.describe_rule(self.pipeline_name) + + if event_bridge_schedule: + schedule_artifact_summary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=event_bridge_schedule["Arn"], + sagemaker_session=self.sagemaker_session, + ) + if schedule_artifact_summary is not None: + pipeline_schedule_artifact: Artifact = ( + S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=schedule_artifact_summary.artifact_arn, + sagemaker_session=self.sagemaker_session, + ) + ) + pipeline_schedule_artifact.set_tags(tags) + + if event_bridge_rule: + rule_artifact_summary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=event_bridge_rule["Arn"], + sagemaker_session=self.sagemaker_session, + ) + if rule_artifact_summary: + pipeline_trigger_artifact: Artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=rule_artifact_summary.artifact_arn, + sagemaker_session=self.sagemaker_session, + ) + pipeline_trigger_artifact.set_tags(tags) + + def _create_new_pipeline_lineage( + self, + input_feature_group_contexts: List[FeatureGroupContexts], + input_raw_data_artifacts: List[Artifact], + output_feature_group_contexts: FeatureGroupContexts, + transformation_code_artifact: Optional[Artifact], + ) -> None: + """Create pipeline lineage resources.""" + + pipeline_context = self._create_pipeline_lineage_for_new_pipeline() + pipeline_version_context = self._create_pipeline_version_lineage() + self._add_associations_for_pipeline( + # pylint: 
disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_versions_context_arn=pipeline_version_context.context_arn, + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + ) + LineageAssociationHandler.add_pipeline_and_pipeline_version_association( + # pylint: disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_version_context_arn=pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + + def _update_pipeline_lineage( + self, + input_feature_group_contexts: List[FeatureGroupContexts], + input_raw_data_artifacts: List[Artifact], + output_feature_group_contexts: FeatureGroupContexts, + transformation_code_artifact: Optional[Artifact], + ) -> None: + """Update pipeline lineage resources.""" + + # If pipeline lineage exists then determine whether to create a new version. + pipeline_context: Context = self._get_pipeline_context() + current_pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + upstream_feature_group_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, + sagemaker_session=self.sagemaker_session, + ) + ) + + upstream_raw_data_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=DATA_SET, + sagemaker_session=self.sagemaker_session, + ) + ) + + upstream_transformation_code: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=TRANSFORMATION_CODE, + sagemaker_session=self.sagemaker_session, + ) + ) + + downstream_feature_group_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_downstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + destination_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, + sagemaker_session=self.sagemaker_session, + ) + ) + + is_upstream_feature_group_equal: bool = self._compare_upstream_feature_groups( + upstream_feature_group_associations=upstream_feature_group_associations, + input_feature_group_contexts=input_feature_group_contexts, + ) + is_downstream_feature_group_equal: bool = self._compare_downstream_feature_groups( + downstream_feature_group_associations=downstream_feature_group_associations, + output_feature_group_contexts=output_feature_group_contexts, + ) + is_upstream_raw_data_equal: bool = self._compare_upstream_raw_data( + upstream_raw_data_associations=upstream_raw_data_associations, + input_raw_data_artifacts=input_raw_data_artifacts, + ) + + self._update_last_transformation_code( + upstream_transformation_code_associations=upstream_transformation_code + ) + if ( + not is_upstream_feature_group_equal + or not is_downstream_feature_group_equal + or not is_upstream_raw_data_equal + ): + if not is_upstream_raw_data_equal: + logger.info("Raw data inputs have changed from 
the last pipeline configuration.") + if not is_upstream_feature_group_equal: + logger.info( + "Feature group inputs have changed from the last pipeline configuration." + ) + if not is_downstream_feature_group_equal: + logger.info( + "Feature Group output has changed from the last pipeline configuration." + ) + pipeline_context.properties["LastUpdateTime"] = self.pipeline[ + "LastModifiedTime" + ].strftime("%s") + PipelineLineageEntityHandler.update_pipeline_context(pipeline_context=pipeline_context) + new_pipeline_version_context: Context = self._create_pipeline_version_lineage() + self._add_associations_for_pipeline( + # pylint: disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_versions_context_arn=new_pipeline_version_context.context_arn, + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + ) + LineageAssociationHandler.add_pipeline_and_pipeline_version_association( + # pylint: disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_version_context_arn=new_pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + elif transformation_code_artifact is not None: + # We will append the new transformation code artifact + # to the existing pipeline version. + LineageAssociationHandler.add_upstream_transformation_code_associations( + transformation_code_artifact=transformation_code_artifact, + # pylint: disable=no-member + pipeline_version_context_arn=current_pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + + def _retrieve_input_raw_data_artifacts(self) -> List[Artifact]: + """Retrieve input Raw Data Artifacts. + + Returns: + List[Artifact]: List of Raw Data Artifacts. + """ + raw_data_artifacts: List[Artifact] = list() + raw_data_uri_set: Set[str] = set() + + for data_source in self.inputs: + if isinstance(data_source, (CSVDataSource, ParquetDataSource, BaseDataSource)): + data_source_uri = ( + data_source.s3_uri + if isinstance(data_source, (CSVDataSource, ParquetDataSource)) + else data_source.data_source_unique_id + ) + if data_source_uri not in raw_data_uri_set: + raw_data_uri_set.add(data_source_uri) + raw_data_artifacts.append( + S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=data_source, + sagemaker_session=self.sagemaker_session, + ) + ) + + return raw_data_artifacts + + def _compare_upstream_raw_data( + self, + upstream_raw_data_associations: Iterator[AssociationSummary], + input_raw_data_artifacts: List[Artifact], + ) -> bool: + """Compare the existing and the new upstream Raw Data. + + Arguments: + upstream_raw_data_associations (Iterator[AssociationSummary]): + Upstream existing raw data associations for the pipeline. + input_raw_data_artifacts (List[Artifact]): + New Upstream raw data for the pipeline. + + Returns: + bool: Boolean if old and new upstream is same. 
+        """
+        raw_data_association_set = {
+            raw_data_association.source_arn
+            for raw_data_association in upstream_raw_data_associations
+        }
+        if len(raw_data_association_set) != len(input_raw_data_artifacts):
+            return False
+        for raw_data in input_raw_data_artifacts:
+            if raw_data.artifact_arn not in raw_data_association_set:
+                return False
+        return True
+
+    def _compare_downstream_feature_groups(
+        self,
+        downstream_feature_group_associations: Iterator[AssociationSummary],
+        output_feature_group_contexts: FeatureGroupContexts,
+    ) -> bool:
+        """Compare the existing and the new downstream Feature Groups.
+
+        Arguments:
+            downstream_feature_group_associations (Iterator[AssociationSummary]):
+                Downstream existing Feature Group association for the pipeline.
+            output_feature_group_contexts (FeatureGroupContexts):
+                New Downstream Feature Group for the pipeline.
+
+        Returns:
+            bool: True if the existing and new downstream Feature Group are the same.
+        """
+        feature_group_association_set = set()
+        for feature_group_association in downstream_feature_group_associations:
+            feature_group_association_set.add(feature_group_association.destination_arn)
+        if len(feature_group_association_set) != 1:
+            raise ValueError(
+                f"There should only be one Feature Group as output, "
+                f"instead we got {len(feature_group_association_set)}. "
+                f"With Feature Group Versions Contexts: {feature_group_association_set}"
+            )
+        return (
+            output_feature_group_contexts.pipeline_version_context_arn
+            in feature_group_association_set
+        )
+
+    def _compare_upstream_feature_groups(
+        self,
+        upstream_feature_group_associations: Iterator[AssociationSummary],
+        input_feature_group_contexts: List[FeatureGroupContexts],
+    ) -> bool:
+        """Compare the existing and the new upstream Feature Groups.
+
+        Arguments:
+            upstream_feature_group_associations (Iterator[AssociationSummary]):
+                Upstream existing Feature Group association for the pipeline.
+            input_feature_group_contexts (List[FeatureGroupContexts]):
+                New Upstream Feature Groups for the pipeline.
+
+        Returns:
+            bool: True if the existing and new upstream Feature Groups are the same.
+        """
+        feature_group_association_set = set()
+        for feature_group_association in upstream_feature_group_associations:
+            feature_group_association_set.add(feature_group_association.source_arn)
+        if len(feature_group_association_set) != len(input_feature_group_contexts):
+            return False
+        for feature_group in input_feature_group_contexts:
+            if feature_group.pipeline_version_context_arn not in feature_group_association_set:
+                return False
+        return True
+
+    def _update_last_transformation_code(
+        self, upstream_transformation_code_associations: Iterator[AssociationSummary]
+    ) -> None:
+        """Mark the previous upstream Transformation Code artifact as inactive.
+
+        If the pipeline version already has a transformation code association, the
+        corresponding artifact is loaded and, if still active, closed with an exclusive
+        end date equal to the pipeline's last modified time.
+
+        Arguments:
+            upstream_transformation_code_associations (Iterator[AssociationSummary]):
+                Upstream existing transformation code associations for the pipeline.
+ """ + upstream_transformation_code = next(upstream_transformation_code_associations, None) + if upstream_transformation_code is None: + return + + last_transformation_code_artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=upstream_transformation_code.source_arn, + sagemaker_session=self.sagemaker_session, + ) + logger.info( + "Retrieved previous transformation code artifact: %s", last_transformation_code_artifact + ) + if ( + last_transformation_code_artifact.properties["state"] + == TRANSFORMATION_CODE_STATUS_ACTIVE + ): + last_transformation_code_artifact.properties["state"] = ( + TRANSFORMATION_CODE_STATUS_INACTIVE + ) + last_transformation_code_artifact.properties["exclusive_end_date"] = self.pipeline[ + LAST_MODIFIED_TIME + ].strftime("%s") + S3LineageEntityHandler.update_transformation_code_artifact( + transformation_code_artifact=last_transformation_code_artifact + ) + logger.info("Updated the last transformation artifact") + + def _get_pipeline_context(self) -> Context: + """Retrieve Pipeline Context. + + Returns: + Context: The Pipeline Context. + """ + return PipelineLineageEntityHandler.load_pipeline_context( + pipeline_name=self.pipeline_name, + creation_time=self.pipeline[CREATION_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + + def _get_pipeline_version_context(self, last_update_time: str) -> Context: + """Retrieve Pipeline Version Context. + + Returns: + Context: The Pipeline Version Context. + """ + return PipelineVersionLineageEntityHandler.load_pipeline_version_context( + pipeline_name=self.pipeline_name, + last_update_time=last_update_time, + sagemaker_session=self.sagemaker_session, + ) + + def _check_if_pipeline_lineage_exists(self) -> bool: + """Check if Pipeline Lineage exists. + + Returns: + bool: Check if pipeline lineage exists. + """ + try: + PipelineLineageEntityHandler.load_pipeline_context( + pipeline_name=self.pipeline_name, + creation_time=self.pipeline[CREATION_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + return True + except ClientError as e: + if e.response[ERROR][CODE] == RESOURCE_NOT_FOUND: + return False + raise e + + def _retrieve_input_feature_group_contexts(self) -> List[FeatureGroupContexts]: + """Retrieve input Feature Groups' Context ARNs. + + Returns: + List[FeatureGroupContexts]: List of Input Feature Groups for the pipeline. + """ + feature_group_contexts: List[FeatureGroupContexts] = list() + feature_group_input_set: Set[str] = set() + for data_source in self.inputs: + if isinstance(data_source, FeatureGroupDataSource): + feature_group_name: str = FeatureGroupLineageEntityHandler.parse_name_from_arn( + data_source.name + ) + if feature_group_name not in feature_group_input_set: + feature_group_input_set.add(feature_group_name) + feature_group_contexts.append( + FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns( + feature_group_name=data_source.name, + sagemaker_session=self.sagemaker_session, + ) + ) + return feature_group_contexts + + def _retrieve_output_feature_group_contexts(self) -> FeatureGroupContexts: + """Retrieve output Feature Group's Context ARNs. + + Returns: + FeatureGroupContexts: The output Feature Group for the pipeline. + """ + return FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns( + feature_group_name=self.output, sagemaker_session=self.sagemaker_session + ) + + def _create_pipeline_lineage_for_new_pipeline(self) -> Context: + """Create Pipeline Context for a new pipeline. 
+ + Returns: + Context: The Pipeline Context. + """ + return PipelineLineageEntityHandler.create_pipeline_context( + pipeline_name=self.pipeline_name, + pipeline_arn=self.pipeline_arn, + creation_time=self.pipeline[CREATION_TIME].strftime("%s"), + last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + + def _create_pipeline_version_lineage(self) -> Context: + """Create a new Pipeline Version Context. + + Returns: + Context: The Pipeline Versions Context. + """ + return PipelineVersionLineageEntityHandler.create_pipeline_version_context( + pipeline_name=self.pipeline_name, + pipeline_arn=self.pipeline_arn, + last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + + def _add_associations_for_pipeline( + self, + pipeline_context_arn: str, + pipeline_versions_context_arn: str, + input_feature_group_contexts: List[FeatureGroupContexts], + input_raw_data_artifacts: List[Artifact], + output_feature_group_contexts: FeatureGroupContexts, + transformation_code_artifact: Optional[Artifact] = None, + ) -> None: + """Add Feature Processor Lineage Associations for the Pipeline + + Arguments: + pipeline_context_arn (str): The pipeline Context ARN. + pipeline_versions_context_arn (str): The pipeline Version Context ARN. + input_feature_group_contexts (List[FeatureGroupContexts]): List of input FeatureGroups. + input_raw_data_artifacts (List[Artifact]): List of input raw data. + output_feature_group_contexts (FeatureGroupContexts): Output Feature Group + transformation_code_artifact (Optional[Artifact]): The transformation Code. + """ + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=input_feature_group_contexts, + pipeline_context_arn=pipeline_context_arn, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) + + LineageAssociationHandler.add_downstream_feature_group_data_associations( + feature_group_output=output_feature_group_contexts, + pipeline_context_arn=pipeline_context_arn, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) + + LineageAssociationHandler.add_upstream_raw_data_associations( + raw_data_inputs=input_raw_data_artifacts, + pipeline_context_arn=pipeline_context_arn, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) + + if transformation_code_artifact is not None: + LineageAssociationHandler.add_upstream_transformation_code_associations( + transformation_code_artifact=transformation_code_artifact, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage_name_helper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage_name_helper.py new file mode 100644 index 0000000000..1a4e9ed04f --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage_name_helper.py @@ -0,0 +1,101 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle lineage resource name generation.""" +from __future__ import absolute_import + +FEATURE_PROCESSOR_CREATED_PREFIX = "sm-fs-fe" +FEATURE_PROCESSOR_CREATED_TRIGGER_PREFIX = "sm-fs-fe-trigger" +FEATURE_GROUP_PIPELINE_CONTEXT_SUFFIX = "feature-group-pipeline" +FEATURE_GROUP_PIPELINE_CONTEXT_VERSION_SUFFIX = "feature-group-pipeline-version" +FEATURE_PROCESSOR_PIPELINE_CONTEXT_SUFFIX = "fep" +FEATURE_PROCESSOR_PIPELINE_VERSION_CONTEXT_SUFFIX = "fep-ver" + + +def _get_feature_processor_lineage_context_name( + resource_name: str, + resource_creation_time: str, + lineage_context_prefix: str = None, + lineage_context_suffix: str = None, +) -> str: + """Generic naming generation function for lineage resources used by feature_processor.""" + context_name_base = [f"{resource_name}-{resource_creation_time}"] + if lineage_context_prefix: + context_name_base.insert(0, lineage_context_prefix) + if lineage_context_suffix: + context_name_base.append(lineage_context_suffix) + return "-".join(context_name_base) + + +def _get_feature_group_lineage_context_name( + feature_group_name: str, feature_group_creation_time: str +) -> str: + """Generate context name for feature group contexts.""" + return _get_feature_processor_lineage_context_name( + resource_name=feature_group_name, resource_creation_time=feature_group_creation_time + ) + + +def _get_feature_group_pipeline_lineage_context_name( + feature_group_name: str, feature_group_creation_time: str +) -> str: + """Generate context name for feature group pipeline.""" + return _get_feature_processor_lineage_context_name( + resource_name=feature_group_name, + resource_creation_time=feature_group_creation_time, + lineage_context_suffix=FEATURE_GROUP_PIPELINE_CONTEXT_SUFFIX, + ) + + +def _get_feature_group_pipeline_version_lineage_context_name( + feature_group_name: str, feature_group_creation_time: str +) -> str: + """Generate context name for feature group pipeline version.""" + return _get_feature_processor_lineage_context_name( + resource_name=feature_group_name, + resource_creation_time=feature_group_creation_time, + lineage_context_suffix=FEATURE_GROUP_PIPELINE_CONTEXT_VERSION_SUFFIX, + ) + + +def _get_feature_processor_pipeline_lineage_context_name( + pipeline_name: str, pipeline_creation_time: str +) -> str: + """Generate context name for feature processor pipeline.""" + return _get_feature_processor_lineage_context_name( + resource_name=pipeline_name, + resource_creation_time=pipeline_creation_time, + lineage_context_prefix=FEATURE_PROCESSOR_CREATED_PREFIX, + lineage_context_suffix=FEATURE_PROCESSOR_PIPELINE_CONTEXT_SUFFIX, + ) + + +def _get_feature_processor_pipeline_version_lineage_context_name( + pipeline_name: str, pipeline_last_update_time: str +) -> str: + """Generate context name for feature processor pipeline version.""" + return _get_feature_processor_lineage_context_name( + resource_name=pipeline_name, + resource_creation_time=pipeline_last_update_time, + lineage_context_prefix=FEATURE_PROCESSOR_CREATED_PREFIX, + lineage_context_suffix=FEATURE_PROCESSOR_PIPELINE_VERSION_CONTEXT_SUFFIX, + ) + + +def _get_feature_processor_schedule_lineage_artifact_name(schedule_name: 
str) -> str: + """Generate artifact name for feature processor pipeline schedule.""" + return "-".join([FEATURE_PROCESSOR_CREATED_PREFIX, schedule_name]) + + +def _get_feature_processor_trigger_lineage_artifact_name(trigger_name: str) -> str: + """Generate artifact name for feature processor pipeline trigger.""" + return "-".join([FEATURE_PROCESSOR_CREATED_TRIGGER_PREFIX, trigger_name]) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_lineage_association_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_lineage_association_handler.py new file mode 100644 index 0000000000..0413b5d7c1 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_lineage_association_handler.py @@ -0,0 +1,300 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle Lineage Associations""" +from __future__ import absolute_import +import logging +from typing import List, Optional, Iterator +from botocore.exceptions import ClientError + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import ( + FeatureGroupContexts, +) +from sagemaker.mlops.feature_store.feature_processor._constants import VALIDATION_EXCEPTION +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + CONTRIBUTED_TO, + ERROR, + CODE, + SAGEMAKER, + ASSOCIATED_WITH, +) +from sagemaker.core.lineage.artifact import Artifact +from sagemaker.core.lineage.association import Association, AssociationSummary + +logger = logging.getLogger(SAGEMAKER) + + +class LineageAssociationHandler: + """Class to handler the FeatureProcessor Lineage Associations""" + + @staticmethod + def add_upstream_feature_group_data_associations( + feature_group_inputs: List[FeatureGroupContexts], + pipeline_context_arn: str, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Upstream Feature Group Lineage Associations. + + Arguments: + feature_group_inputs (List[FeatureGroupContexts]): The input Feature Group List. + pipeline_context_arn (str): The pipeline context arn. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
+ """ + for feature_group in feature_group_inputs: + LineageAssociationHandler._add_association( + source_arn=feature_group.pipeline_context_arn, + destination_arn=pipeline_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + LineageAssociationHandler._add_association( + source_arn=feature_group.pipeline_version_context_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_upstream_raw_data_associations( + raw_data_inputs: List[Artifact], + pipeline_context_arn: str, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Upstream Raw Data Lineage Associations. + + Arguments: + raw_data_inputs (List[Artifact]): The input raw data List. + pipeline_context_arn (str): The pipeline context arn. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + for raw_data_artifact in raw_data_inputs: + LineageAssociationHandler._add_association( + source_arn=raw_data_artifact.artifact_arn, + destination_arn=pipeline_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + LineageAssociationHandler._add_association( + source_arn=raw_data_artifact.artifact_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_upstream_transformation_code_associations( + transformation_code_artifact: Artifact, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Upstream Transformation Code Lineage Associations. + + Arguments: + transformation_code_artifact (Artifact): The transformation Code Artifact. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + LineageAssociationHandler._add_association( + source_arn=transformation_code_artifact.artifact_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_upstream_schedule_associations( + schedule_artifact: Artifact, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Upstream Schedule Lineage Associations. + + Arguments: + schedule_artifact (Artifact): The schedule Artifact. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
+ """ + LineageAssociationHandler._add_association( + source_arn=schedule_artifact.artifact_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_downstream_feature_group_data_associations( + feature_group_output: FeatureGroupContexts, + pipeline_context_arn: str, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Downstream Feature Group Lineage Associations. + + Arguments: + feature_group_output (FeatureGroupContexts): The output Feature Group. + pipeline_context_arn (str): The pipeline context arn. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + LineageAssociationHandler._add_association( + source_arn=pipeline_context_arn, + destination_arn=feature_group_output.pipeline_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + LineageAssociationHandler._add_association( + source_arn=pipeline_version_context_arn, + destination_arn=feature_group_output.pipeline_version_context_arn, + association_type="ContributedTo", + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_pipeline_and_pipeline_version_association( + pipeline_context_arn: str, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Lineage Association + + between the Pipeline and the Pipeline Versions. + + Arguments: + pipeline_context_arn (str): The pipeline context arn. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + LineageAssociationHandler._add_association( + source_arn=pipeline_context_arn, + destination_arn=pipeline_version_context_arn, + association_type=ASSOCIATED_WITH, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def list_upstream_associations( + entity_arn: str, source_type: str, sagemaker_session: Session + ) -> Iterator[AssociationSummary]: + """List Upstream Lineage Associations. + + Arguments: + entity_arn (str): The Lineage Entity ARN. + source_type (str): The Source Type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + return LineageAssociationHandler._list_association( + destination_arn=entity_arn, + source_type=source_type, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def list_downstream_associations( + entity_arn: str, destination_type: str, sagemaker_session: Session + ) -> Iterator[AssociationSummary]: + """List Downstream Lineage Associations. + + Arguments: + entity_arn (str): The Lineage Entity ARN. + destination_type (str): The Destination Type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
+ """ + return LineageAssociationHandler._list_association( + source_arn=entity_arn, + destination_type=destination_type, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def _add_association( + source_arn: str, + destination_arn: str, + association_type: str, + sagemaker_session: Session, + ) -> None: + """Add Lineage Association. + + Arguments: + source_arn (str): The source ARN. + destination_arn (str): The destination ARN. + association_type (str): The association type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + try: + logger.info( + "Adding association with source_arn: " + "%s, destination_arn: %s and association_type: %s.", + source_arn, + destination_arn, + association_type, + ) + Association.create( + source_arn=source_arn, + destination_arn=destination_arn, + association_type=association_type, + sagemaker_session=sagemaker_session, + ) + except ClientError as e: + if e.response[ERROR][CODE] == VALIDATION_EXCEPTION: + logger.info("Association already exists") + else: + raise e + + @staticmethod + def _list_association( + sagemaker_session: Session, + source_arn: Optional[str] = None, + source_type: Optional[str] = None, + destination_arn: Optional[str] = None, + destination_type: Optional[str] = None, + ) -> Iterator[AssociationSummary]: + """List Lineage Associations. + + Arguments: + source_arn (str): The source ARN. + source_type (str): The source type. + destination_arn (str): The destination ARN. + destination_type (str): The destination type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + return Association.list( + source_arn=source_arn, + source_type=source_type, + destination_arn=destination_arn, + destination_type=destination_type, + sagemaker_session=sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_lineage_entity_handler.py new file mode 100644 index 0000000000..3bf80e9d95 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_lineage_entity_handler.py @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains class to handle Pipeline Lineage""" +from __future__ import absolute_import +import logging + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + SAGEMAKER, + PIPELINE_NAME_KEY, + PIPELINE_CREATION_TIME_KEY, + LAST_UPDATE_TIME_KEY, +) +from sagemaker.core.lineage.context import Context + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_processor_pipeline_lineage_context_name, +) +from sagemaker.core.lineage import context + +logger = logging.getLogger(SAGEMAKER) + + +class PipelineLineageEntityHandler: + """Class for handling FeatureProcessor Pipeline Lineage""" + + @staticmethod + def create_pipeline_context( + pipeline_name: str, + pipeline_arn: str, + creation_time: str, + last_update_time: str, + sagemaker_session: Session, + ) -> Context: + """Create the FeatureProcessor Pipeline context. + + Arguments: + pipeline_name (str): The pipeline name. + pipeline_arn (str): The pipeline ARN. + creation_time (str): The pipeline creation time. + last_update_time (str): The pipeline last update time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline context. + """ + return context.Context.create( + context_name=_get_feature_processor_pipeline_lineage_context_name( + pipeline_name, creation_time + ), + context_type="FeatureEngineeringPipeline", + source_uri=pipeline_arn, + source_type=creation_time, + properties={ + PIPELINE_NAME_KEY: pipeline_name, + PIPELINE_CREATION_TIME_KEY: creation_time, + LAST_UPDATE_TIME_KEY: last_update_time, + }, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def load_pipeline_context( + pipeline_name: str, creation_time: str, sagemaker_session: Session + ) -> Context: + """Load the FeatureProcessor Pipeline context. + + Arguments: + pipeline_name (str): The pipeline name. + creation_time (str): The pipeline creation time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline context. + """ + return Context.load( + context_name=_get_feature_processor_pipeline_lineage_context_name( + pipeline_name, creation_time + ), + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def update_pipeline_context(pipeline_context: Context) -> None: + """Update the FeatureProcessor Pipeline context + + Arguments: + pipeline_context (Context): The pipeline context. + """ + pipeline_context.save() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_schedule.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_schedule.py new file mode 100644 index 0000000000..08f10fb8fb --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_schedule.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to store the Pipeline Schedule"""
+from __future__ import absolute_import
+import attr
+
+
+@attr.s
+class PipelineSchedule:
+ """A Schedule definition for FeatureProcessor Lineage.
+
+ Attributes:
+ schedule_name (str): Schedule Name.
+ schedule_arn (str): The ARN of the Schedule.
+ schedule_expression (str): The expression that defines when the schedule runs. It supports
+ at expression, rate expression and cron expression. See https://docs.aws.amazon.com/
+ scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedule-request
+ -ScheduleExpression for more details.
+ pipeline_name (str): The SageMaker Pipeline name that will be scheduled.
+ state (str): Specifies whether the schedule is enabled or disabled. Valid values are
+ ENABLED and DISABLED. See https://docs.aws.amazon.com/scheduler/latest/APIReference/
+ API_CreateSchedule.html#scheduler-CreateSchedule-request-State for more details.
+ If not specified, it will default to DISABLED.
+ start_date (Optional[datetime]): The date, in UTC, after which the schedule can begin
+ invoking its target. Depending on the schedule’s recurrence expression, invocations
+ might occur on, or after, the StartDate you specify.
+ """
+
+ schedule_name: str = attr.ib()
+ schedule_arn: str = attr.ib()
+ schedule_expression: str = attr.ib()
+ pipeline_name: str = attr.ib()
+ state: str = attr.ib()
+ start_date: str = attr.ib()
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_trigger.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_trigger.py
new file mode 100644
index 0000000000..e58003f396
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_trigger.py
@@ -0,0 +1,36 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to store the Pipeline Trigger"""
+from __future__ import absolute_import
+import attr
+
+
+@attr.s
+class PipelineTrigger:
+ """An event-based trigger definition for FeatureProcessor Lineage.
+
+ Attributes:
+ trigger_name (str): Trigger Name.
+ trigger_arn (str): The ARN of the Trigger.
+ event_pattern (str): The event pattern. For more information, see Amazon EventBridge
+ event patterns in the Amazon EventBridge User Guide.
+ pipeline_name (str): The SageMaker Pipeline name that will be triggered.
+ state (str): Specifies whether the trigger is enabled or disabled. Valid values are
+ ENABLED and DISABLED.
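+
+ Example: illustrative values only; the ARN and event pattern are placeholders.
+
+ trigger = PipelineTrigger(
+ trigger_name="my-trigger",
+ trigger_arn="arn:aws:events:us-west-2:111122223333:rule/my-trigger",
+ event_pattern='{"source": ["aws.s3"]}',
+ pipeline_name="my-pipeline",
+ state="ENABLED",
+ )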
+ """ + + trigger_name: str = attr.ib() + trigger_arn: str = attr.ib() + event_pattern: str = attr.ib() + pipeline_name: str = attr.ib() + state: str = attr.ib() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_version_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_version_lineage_entity_handler.py new file mode 100644 index 0000000000..5d0b4c979b --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_version_lineage_entity_handler.py @@ -0,0 +1,92 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle Pipeline Version Lineage""" +from __future__ import absolute_import +import logging + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + SAGEMAKER, + PIPELINE_VERSION_CONTEXT_TYPE, + PIPELINE_NAME_KEY, + LAST_UPDATE_TIME_KEY, +) +from sagemaker.core.lineage.context import Context + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_processor_pipeline_version_lineage_context_name, +) + +logger = logging.getLogger(SAGEMAKER) + + +class PipelineVersionLineageEntityHandler: + """Class for handling FeatureProcessor Pipeline Version Lineage""" + + @staticmethod + def create_pipeline_version_context( + pipeline_name: str, + pipeline_arn: str, + last_update_time: str, + sagemaker_session: Session, + ) -> Context: + """Create the FeatureProcessor Pipeline Version context. + + Arguments: + pipeline_name (str): The pipeline name. + pipeline_arn (str): The pipeline ARN. + last_update_time (str): The pipeline last update time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline version context. + """ + return Context.create( + context_name=_get_feature_processor_pipeline_version_lineage_context_name( + pipeline_name, last_update_time + ), + context_type=f"{PIPELINE_VERSION_CONTEXT_TYPE}-{pipeline_name}", + source_uri=pipeline_arn, + source_type=last_update_time, + properties={ + PIPELINE_NAME_KEY: pipeline_name, + LAST_UPDATE_TIME_KEY: last_update_time, + }, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def load_pipeline_version_context( + pipeline_name: str, last_update_time: str, sagemaker_session: Session + ) -> Context: + """Load the FeatureProcessor Pipeline Version context. + + Arguments: + pipeline_name (str): The pipeline name. + last_update_time (str): The pipeline last update time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. 
If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline version context. + """ + return Context.load( + context_name=_get_feature_processor_pipeline_version_lineage_context_name( + pipeline_name, last_update_time + ), + sagemaker_session=sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py new file mode 100644 index 0000000000..78a0f18c7c --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py @@ -0,0 +1,316 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle S3 Lineage""" +from __future__ import absolute_import +import logging +from typing import Union, Optional, List + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor import ( + CSVDataSource, + ParquetDataSource, + BaseDataSource, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_processor_schedule_lineage_artifact_name, + _get_feature_processor_trigger_lineage_artifact_name, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import PipelineTrigger +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + TRANSFORMATION_CODE_STATUS_ACTIVE, + FEP_LINEAGE_PREFIX, + TRANSFORMATION_CODE_ARTIFACT_NAME, +) +from sagemaker.core.lineage.artifact import Artifact, ArtifactSummary + +logger = logging.getLogger("sagemaker") + + +class S3LineageEntityHandler: + """Class for handling FeatureProcessor S3 Artifact Lineage""" + + @staticmethod + def retrieve_raw_data_artifact( + raw_data: Union[CSVDataSource, ParquetDataSource, BaseDataSource], + sagemaker_session: Session, + ) -> Artifact: + """Load or create the FeatureProcessor Pipeline's raw data Artifact. + + Arguments: + raw_data (Union[CSVDataSource, ParquetDataSource, BaseDataSource]): The raw data to be + retrieved. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The raw data artifact. 
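+
+ Example: an illustrative sketch; the S3 URI and session are placeholders.
+
+ raw_data_artifact = S3LineageEntityHandler.retrieve_raw_data_artifact(
+ raw_data=CSVDataSource(s3_uri="s3://my-bucket/raw/data.csv"),
+ sagemaker_session=session,
+ )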
+ """ + raw_data_uri = ( + raw_data.s3_uri + if isinstance(raw_data, (CSVDataSource, ParquetDataSource)) + else raw_data.data_source_unique_id + ) + raw_data_artifact_name = ( + "sm-fs-fe-raw-data" + if isinstance(raw_data, (CSVDataSource, ParquetDataSource)) + else raw_data.data_source_name + ) + + load_artifact: ArtifactSummary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=raw_data_uri, sagemaker_session=sagemaker_session + ) + if load_artifact is not None: + return S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=load_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ) + + return S3LineageEntityHandler._create_artifact( + s3_uri=raw_data_uri, + artifact_type="DataSet", + artifact_name=raw_data_artifact_name, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def update_transformation_code_artifact( + transformation_code_artifact: Artifact, + ) -> None: + """Update Pipeline's transformation code Artifact. + + Arguments: + transformation_code_artifact (TransformationCode): The transformation code Artifact to be updated. + """ + transformation_code_artifact.save() + + @staticmethod + def create_transformation_code_artifact( + transformation_code: TransformationCode, + pipeline_last_update_time: str, + sagemaker_session: Session, + ) -> Optional[Artifact]: + """Create the FeatureProcessor Pipeline's transformation code Artifact. + + Arguments: + transformation_code (TransformationCode): The transformation code to be retrieved. + pipeline_last_update_time (str): The last update time of the pipeline. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The transformation code artifact. + """ + if transformation_code is None: + return None + + properties = dict( + state=TRANSFORMATION_CODE_STATUS_ACTIVE, + inclusive_start_date=pipeline_last_update_time, + ) + if transformation_code.name is not None: + properties["name"] = transformation_code.name + if transformation_code.author is not None: + properties["author"] = transformation_code.author + + return S3LineageEntityHandler._create_artifact( + s3_uri=transformation_code.s3_uri, + source_types=[dict(SourceIdType="Custom", Value=pipeline_last_update_time)], + properties=properties, + artifact_type="TransformationCode", + artifact_name=f"{FEP_LINEAGE_PREFIX}-" + f"{TRANSFORMATION_CODE_ARTIFACT_NAME}-" + f"{pipeline_last_update_time}", + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def retrieve_pipeline_schedule_artifact( + pipeline_schedule: PipelineSchedule, + sagemaker_session: Session, + _get_feature_processor_schedule_lineage_artifact_namef=None, + ) -> Optional[Artifact]: + """Load or create the FeatureProcessor Pipeline's schedule Artifact + + Arguments: + pipeline_schedule (PipelineSchedule): Class to hold the Pipeline Schedule details + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The Schedule Artifact. 
+ """ + if pipeline_schedule is None: + return None + load_artifact: ArtifactSummary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=pipeline_schedule.schedule_arn, + sagemaker_session=sagemaker_session, + ) + if load_artifact is not None: + pipeline_schedule_artifact: Artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=load_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ) + pipeline_schedule_artifact.properties["pipeline_name"] = pipeline_schedule.pipeline_name + pipeline_schedule_artifact.properties["schedule_expression"] = ( + pipeline_schedule.schedule_expression + ) + pipeline_schedule_artifact.properties["state"] = pipeline_schedule.state + pipeline_schedule_artifact.properties["start_date"] = pipeline_schedule.start_date + pipeline_schedule_artifact.save() + return pipeline_schedule_artifact + + return S3LineageEntityHandler._create_artifact( + s3_uri=pipeline_schedule.schedule_arn, + artifact_type="PipelineSchedule", + artifact_name=_get_feature_processor_schedule_lineage_artifact_name( + schedule_name=pipeline_schedule.schedule_name + ), + properties=dict( + pipeline_name=pipeline_schedule.pipeline_name, + schedule_expression=pipeline_schedule.schedule_expression, + state=pipeline_schedule.state, + start_date=pipeline_schedule.start_date, + ), + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def retrieve_pipeline_trigger_artifact( + pipeline_trigger: PipelineTrigger, + sagemaker_session: Session, + ) -> Optional[Artifact]: + """Load or create the FeatureProcessor Pipeline's trigger Artifact + + Arguments: + pipeline_trigger (PipelineTrigger): Class to hold the Pipeline Trigger details + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The Trigger Artifact. + """ + if pipeline_trigger is None: + return None + load_artifact: ArtifactSummary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=pipeline_trigger.trigger_arn, + sagemaker_session=sagemaker_session, + ) + if load_artifact is not None: + pipeline_trigger_artifact: Artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=load_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ) + pipeline_trigger_artifact.properties["pipeline_name"] = pipeline_trigger.pipeline_name + pipeline_trigger_artifact.properties["event_pattern"] = pipeline_trigger.event_pattern + pipeline_trigger_artifact.properties["state"] = pipeline_trigger.state + pipeline_trigger_artifact.save() + return pipeline_trigger_artifact + + return S3LineageEntityHandler._create_artifact( + s3_uri=pipeline_trigger.trigger_arn, + artifact_type="PipelineTrigger", + artifact_name=_get_feature_processor_trigger_lineage_artifact_name( + trigger_name=pipeline_trigger.trigger_name + ), + properties=dict( + pipeline_name=pipeline_trigger.pipeline_name, + event_pattern=pipeline_trigger.event_pattern, + state=pipeline_trigger.state, + ), + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def load_artifact_from_arn(artifact_arn: str, sagemaker_session: Session) -> Artifact: + """Load Lineage Artifacts from ARN. + + Arguments: + artifact_arn (str): The Artifact ARN. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. 
If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The Artifact for the provided ARN. + """ + return Artifact.load(artifact_arn=artifact_arn, sagemaker_session=sagemaker_session) + + @staticmethod + def _load_artifact_from_s3_uri( + s3_uri: str, sagemaker_session: Session + ) -> Optional[ArtifactSummary]: + """Load FeatureProcessor S3 Lineage Artifacts. + + Arguments: + s3_uri (str): The s3 uri of the Artifact. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + ArtifactSummary: The Artifact Summary for the provided S3 URI. + """ + artifacts = Artifact.list(source_uri=s3_uri, sagemaker_session=sagemaker_session) + for artifact_summary in artifacts: + # We want to make sure that source_type is empty. + # Since SDK will not set it while creating artifacts. + if ( + artifact_summary.source.source_types is None + or len(artifact_summary.source.source_types) == 0 + ): + return artifact_summary + return None + + @staticmethod + def _create_artifact( + s3_uri: str, + artifact_type: str, + sagemaker_session: Session, + properties: Optional[dict] = None, + artifact_name: Optional[str] = None, + source_types: Optional[List[dict]] = None, + ) -> Artifact: + """Create Lineage Artifacts. + + Arguments: + s3_uri (str): The s3 uri of the Artifact. + artifact_type (str): The Artifact type. + properties (Optional[dict]): The properties of the Artifact. + artifact_name (Optional[str]): The name of the Artifact. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The new Artifact. + """ + return Artifact.create( + source_uri=s3_uri, + source_types=source_types, + artifact_type=artifact_type, + artifact_name=artifact_name, + properties=properties, + sagemaker_session=sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_transformation_code.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_transformation_code.py new file mode 100644 index 0000000000..70ce48d910 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_transformation_code.py @@ -0,0 +1,31 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to store Transformation Code""" +from __future__ import absolute_import +from typing import Optional +import attr + + +@attr.s +class TransformationCode: + """A Transformation Code definition for FeatureProcessor Lineage. + + Attributes: + s3_uri (str): The S3 URI of the code. + name (Optional[str]): The name of the code Artifact object. + author (Optional[str]): The author of the code. 
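+
+ Example: illustrative values only; the S3 URI is a placeholder.
+
+ transformation_code = TransformationCode(
+ s3_uri="s3://my-bucket/code/transform.py",
+ name="my-transform",
+ author="data-science-team",
+ )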
+ """ + + s3_uri: str = attr.ib() + name: Optional[str] = attr.ib(default=None) + author: Optional[str] = attr.ib(default=None) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/constants.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/constants.py new file mode 100644 index 0000000000..25f4b04716 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/constants.py @@ -0,0 +1,43 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Module containing constants for feature_processor and feature_scheduler module.""" +from __future__ import absolute_import + +FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE = "FeatureGroupPipelineVersion" +PIPELINE_CONTEXT_TYPE = "FeatureEngineeringPipeline" +PIPELINE_VERSION_CONTEXT_TYPE = "FeatureEngineeringPipelineVersion" +PIPELINE_CONTEXT_NAME_SUFFIX = "fep" +PIPELINE_VERSION_CONTEXT_NAME_SUFFIX = "fep-ver" +FEP_LINEAGE_PREFIX = "sm-fs-fe" +DATA_SET = "DataSet" +TRANSFORMATION_CODE = "TransformationCode" +LAST_UPDATE_TIME = "LastUpdateTime" +LAST_MODIFIED_TIME = "LastModifiedTime" +CREATION_TIME = "CreationTime" +RESOURCE_NOT_FOUND = "ResourceNotFound" +ERROR = "Error" +CODE = "Code" +SAGEMAKER = "sagemaker" +CONTRIBUTED_TO = "ContributedTo" +ASSOCIATED_WITH = "AssociatedWith" +FEATURE_GROUP = "FeatureGroupName" +FEATURE_GROUP_PIPELINE_SUFFIX = "feature-group-pipeline" +FEATURE_GROUP_PIPELINE_VERSION_SUFFIX = "feature-group-pipeline-version" +PIPELINE_CONTEXT_NAME_KEY = "pipeline_context_name" +PIPELINE_CONTEXT_VERSION_NAME_KEY = "pipeline_version_context_name" +PIPELINE_NAME_KEY = "PipelineName" +PIPELINE_CREATION_TIME_KEY = "PipelineCreationTime" +LAST_UPDATE_TIME_KEY = "LastUpdateTime" +TRANSFORMATION_CODE_STATUS_ACTIVE = "Active" +TRANSFORMATION_CODE_STATUS_INACTIVE = "Inactive" +TRANSFORMATION_CODE_ARTIFACT_NAME = "transformation-code" diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py new file mode 100644 index 0000000000..0b7c747515 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py @@ -0,0 +1,461 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 +"""Utilities for working with FeatureGroups and FeatureStores.""" +import logging +import os +import time +from typing import Any, Dict, Sequence, Union + +import boto3 +import pandas as pd +from pandas import DataFrame, Series + +from sagemaker.mlops.feature_store import FeatureGroup as CoreFeatureGroup, FeatureGroup +from sagemaker.core.helper.session_helper import Session +from sagemaker.core.s3.client import S3Uploader, S3Downloader +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FractionalFeatureDefinition, + IntegralFeatureDefinition, + ListCollectionType, + StringFeatureDefinition, +) +from sagemaker.mlops.feature_store.ingestion_manager_pandas import IngestionManagerPandas + +from sagemaker.core.utils import unique_name_from_base + + +logger = logging.getLogger(__name__) + +# --- Constants --- + +_FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP = { + "Integral": "INT", + "Fractional": "FLOAT", + "String": "STRING", +} + +_DTYPE_TO_FEATURE_TYPE_MAP = { + "object": "String", + "string": "String", + "int64": "Integral", + "float64": "Fractional", +} + +_INTEGER_TYPES = {"int_", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"} +_FLOAT_TYPES = {"float_", "float16", "float32", "float64"} + + +def _get_athena_client(session: Session): + """Get Athena client from session.""" + return session.boto_session.client("athena", region_name=session.boto_region_name) + + +def _get_s3_client(session: Session): + """Get S3 client from session.""" + return session.boto_session.client("s3", region_name=session.boto_region_name) + + +def start_query_execution( + session: Session, + catalog: str, + database: str, + query_string: str, + output_location: str, + kms_key: str = None, + workgroup: str = None, +) -> Dict[str, str]: + """Start Athena query execution. + + Args: + session: Session instance for boto calls. + catalog: Name of the data catalog. + database: Name of the database. + query_string: SQL query string. + output_location: S3 URI for query results. + kms_key: KMS key for encryption (default: None). + workgroup: Athena workgroup name (default: None). + + Returns: + Response dict with QueryExecutionId. + """ + kwargs = { + "QueryString": query_string, + "QueryExecutionContext": {"Catalog": catalog, "Database": database}, + "ResultConfiguration": {"OutputLocation": output_location}, + } + if kms_key: + kwargs["ResultConfiguration"]["EncryptionConfiguration"] = { + "EncryptionOption": "SSE_KMS", + "KmsKey": kms_key, + } + if workgroup: + kwargs["WorkGroup"] = workgroup + return _get_athena_client(session).start_query_execution(**kwargs) + + +def get_query_execution(session: Session, query_execution_id: str) -> Dict[str, Any]: + """Get execution status of an Athena query. + + Args: + session: Session instance for boto calls. + query_execution_id: The query execution ID. + + Returns: + Response dict from Athena. + """ + return _get_athena_client(session).get_query_execution(QueryExecutionId=query_execution_id) + + +def wait_for_athena_query(session: Session, query_execution_id: str, poll: int = 5): + """Wait for Athena query to finish. + + Args: + session: Session instance for boto calls. + query_execution_id: The query execution ID. + poll: Polling interval in seconds (default: 5). 
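+
+ Example: an illustrative sketch, assuming a session and an output bucket already
+ exist; the catalog, database and query values are placeholders.
+
+ response = start_query_execution(
+ session=session,
+ catalog="AwsDataCatalog",
+ database="sagemaker_featurestore",
+ query_string="SELECT * FROM my_feature_group_table LIMIT 10",
+ output_location="s3://my-bucket/athena-results/",
+ )
+ wait_for_athena_query(session, response["QueryExecutionId"])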
+ """ + while True: + state = get_query_execution(session, query_execution_id)["QueryExecution"]["Status"]["State"] + if state in ("SUCCEEDED", "FAILED"): + logger.info("Query %s %s.", query_execution_id, state.lower()) + break + logger.info("Query %s is being executed.", query_execution_id) + time.sleep(poll) + + +def run_athena_query( + session: Session, + catalog: str, + database: str, + query_string: str, + output_location: str, + kms_key: str = None, +) -> Dict[str, Any]: + """Execute Athena query, wait for completion, and return result. + + Args: + session: Session instance for boto calls. + catalog: Name of the data catalog. + database: Name of the database. + query_string: SQL query string. + output_location: S3 URI for query results. + kms_key: KMS key for encryption (default: None). + + Returns: + Query execution result dict. + + Raises: + RuntimeError: If query fails. + """ + response = start_query_execution( + session=session, + catalog=catalog, + database=database, + query_string=query_string, + output_location=output_location, + kms_key=kms_key, + ) + query_id = response["QueryExecutionId"] + wait_for_athena_query(session, query_id) + + result = get_query_execution(session, query_id) + if result["QueryExecution"]["Status"]["State"] != "SUCCEEDED": + raise RuntimeError(f"Athena query {query_id} failed.") + return result + + +def download_athena_query_result( + session: Session, + bucket: str, + prefix: str, + query_execution_id: str, + filename: str, +): + """Download query result file from S3. + + Args: + session: Session instance for boto calls. + bucket: S3 bucket name. + prefix: S3 key prefix. + query_execution_id: The query execution ID. + filename: Local filename to save to. + """ + _get_s3_client(session).download_file( + Bucket=bucket, + Key=f"{prefix}/{query_execution_id}.csv", + Filename=filename, + ) + + +def upload_dataframe_to_s3( + data_frame: DataFrame, + output_path: str, + session: Session, + kms_key: str = None, +) -> tuple[str, str]: + """Upload DataFrame to S3 as CSV. + + Args: + data_frame: DataFrame to upload. + output_path: S3 URI base path. + session: Session instance for boto calls. + kms_key: KMS key for encryption (default: None). + + Returns: + Tuple of (s3_folder, temp_table_name). + """ + + temp_id = unique_name_from_base("dataframe-base") + local_file = f"{temp_id}.csv" + s3_folder = os.path.join(output_path, temp_id) + + data_frame.to_csv(local_file, index=False, header=False) + S3Uploader.upload( + local_path=local_file, + desired_s3_uri=s3_folder, + sagemaker_session=session, + kms_key=kms_key, + ) + os.remove(local_file) + + table_name = f'dataframe_{temp_id.replace("-", "_")}' + return s3_folder, table_name + + +def download_csv_from_s3( + s3_uri: str, + session: Session, + kms_key: str = None, +) -> DataFrame: + """Download CSV from S3 and return as DataFrame. + + Args: + s3_uri: S3 URI of the CSV file. + session: Session instance for boto calls. + kms_key: KMS key for decryption (default: None). + + Returns: + DataFrame with CSV contents. + """ + + S3Downloader.download( + s3_uri=s3_uri, + local_path="./", + kms_key=kms_key, + sagemaker_session=session, + ) + + local_file = s3_uri.split("/")[-1] + df = pd.read_csv(local_file) + os.remove(local_file) + + metadata_file = f"{local_file}.metadata" + if os.path.exists(metadata_file): + os.remove(metadata_file) + + return df + + +def get_session_from_role(region: str, assume_role: str = None) -> Session: + """Get a Session from a region and optional IAM role. + + Args: + region: AWS region name. 
+ assume_role: IAM role ARN to assume (default: None). + + Returns: + Session instance. + """ + boto_session = boto3.Session(region_name=region) + + if assume_role: + sts = boto_session.client("sts", region_name=region) + credentials = sts.assume_role( + RoleArn=assume_role, + RoleSessionName="SagemakerExecution", + )["Credentials"] + + boto_session = boto3.Session( + region_name=region, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + return Session( + boto_session=boto_session, + sagemaker_client=boto_session.client("sagemaker"), + sagemaker_runtime_client=boto_session.client("sagemaker-runtime"), + sagemaker_featurestore_runtime_client=boto_session.client("sagemaker-featurestore-runtime"), + ) + + +# --- FeatureDefinition Functions --- + +def _is_collection_column(series: Series, sample_size: int = 1000) -> bool: + """Check if column contains list/set values.""" + sample = series.head(sample_size).dropna() + return sample.apply(lambda x: isinstance(x, (list, set))).any() + + +def _generate_feature_definition( + series: Series, + online_storage_type: str = None, +) -> FeatureDefinition: + """Generate a FeatureDefinition from a pandas Series.""" + dtype = str(series.dtype) + collection_type = None + + if online_storage_type == "InMemory" and _is_collection_column(series): + collection_type = ListCollectionType() + + if dtype in _INTEGER_TYPES: + return IntegralFeatureDefinition(series.name, collection_type) + if dtype in _FLOAT_TYPES: + return FractionalFeatureDefinition(series.name, collection_type) + return StringFeatureDefinition(series.name, collection_type) + + +def load_feature_definitions_from_dataframe( + data_frame: DataFrame, + online_storage_type: str = None, +) -> Sequence[FeatureDefinition]: + """Infer FeatureDefinitions from DataFrame dtypes. + + Column name is used as feature name. Feature type is inferred from the dtype + of the column. Integer dtypes are mapped to Integral feature type. Float dtypes + are mapped to Fractional feature type. All other dtypes are mapped to String. + + For IN_MEMORY online_storage_type, collection type columns within DataFrame + will be inferred as List instead of String. + + Args: + data_frame: DataFrame to infer features from. + online_storage_type: "Standard" or "InMemory" (default: None). + + Returns: + List of FeatureDefinition objects. + """ + return [ + _generate_feature_definition(data_frame[col], online_storage_type) + for col in data_frame.columns + ] + + +# --- FeatureGroup Functions --- + +def create_athena_query(feature_group_name: str, session: Session): + """Create an AthenaQuery for a FeatureGroup. + + Args: + feature_group_name: Name of the FeatureGroup. + session: Session instance for Athena boto calls. + + Returns: + AthenaQuery initialized with data catalog config. + + Raises: + RuntimeError: If no metastore is configured. 
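+
+ Example: an illustrative sketch; the feature group name is a placeholder and the
+ group is assumed to have an offline store with a data catalog configured.
+
+ query = create_athena_query(
+ feature_group_name="my-feature-group",
+ session=session,
+ )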
+ """ + from sagemaker.mlops.feature_store.athena_query import AthenaQuery + + fg = CoreFeatureGroup.get(feature_group_name=feature_group_name) + + if not fg.offline_store_config or not fg.offline_store_config.data_catalog_config: + raise RuntimeError("No metastore is configured with this feature group.") + + catalog_config = fg.offline_store_config.data_catalog_config + disable_glue = catalog_config.disable_glue_table_creation or False + + return AthenaQuery( + catalog=catalog_config.catalog if disable_glue else "AwsDataCatalog", + database=catalog_config.database, + table_name=catalog_config.table_name, + sagemaker_session=session, + ) + + +def as_hive_ddl( + feature_group_name: str, + database: str = "sagemaker_featurestore", + table_name: str = None, +) -> str: + """Generate Hive DDL for a FeatureGroup's offline store table. + + Schema of the table is generated based on the feature definitions. Columns are named + after feature name and data-type are inferred based on feature type. Integral feature + type is mapped to INT data-type. Fractional feature type is mapped to FLOAT data-type. + String feature type is mapped to STRING data-type. + + Args: + feature_group_name: Name of the FeatureGroup. + database: Hive database name (default: "sagemaker_featurestore"). + table_name: Hive table name (default: feature_group_name). + + Returns: + CREATE EXTERNAL TABLE DDL string. + """ + fg = CoreFeatureGroup.get(feature_group_name=feature_group_name) + table_name = table_name or feature_group_name + resolved_output_s3_uri = fg.offline_store_config.s3_storage_config.resolved_output_s3_uri + + ddl = f"CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n" + for fd in fg.feature_definitions: + ddl += f" {fd.feature_name} {_FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP.get(fd.feature_type)}\n" + ddl += " write_time TIMESTAMP\n" + ddl += " event_time TIMESTAMP\n" + ddl += " is_deleted BOOLEAN\n" + ddl += ")\n" + ddl += ( + "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n" + " STORED AS\n" + " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n" + " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n" + f"LOCATION '{resolved_output_s3_uri}'" + ) + return ddl + + +def ingest_dataframe( + feature_group_name: str, + data_frame: DataFrame, + max_workers: int = 1, + max_processes: int = 1, + wait: bool = True, + timeout: Union[int, float] = None, +): + """Ingest a pandas DataFrame to a FeatureGroup. + + Args: + feature_group_name: Name of the FeatureGroup. + data_frame: DataFrame to ingest. + max_workers: Threads per process (default: 1). + max_processes: Number of processes (default: 1). + wait: Wait for ingestion to complete (default: True). + timeout: Timeout in seconds (default: None). + + Returns: + IngestionManagerPandas instance. + + Raises: + ValueError: If max_workers or max_processes <= 0. 
+ """ + + if max_processes <= 0: + raise ValueError("max_processes must be greater than 0.") + if max_workers <= 0: + raise ValueError("max_workers must be greater than 0.") + + fg = CoreFeatureGroup.get(feature_group_name=feature_group_name) + feature_definitions = {fd.feature_name: fd.feature_type for fd in fg.feature_definitions} + + manager = IngestionManagerPandas( + feature_group_name=feature_group_name, + feature_definitions=feature_definitions, + max_workers=max_workers, + max_processes=max_processes, + ) + manager.run(data_frame=data_frame, wait=wait, timeout=timeout) + return manager + diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py new file mode 100644 index 0000000000..4d7b4e5375 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py @@ -0,0 +1,318 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Multi-threaded data ingestion for FeatureStore using SageMaker Core.""" +import logging +import math +import signal +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field +from multiprocessing import Pool +from typing import Any, Dict, Iterable, List, Sequence, Union + +import pandas as pd +from pandas import DataFrame +from pandas.api.types import is_list_like + +from sagemaker.core.resources import FeatureGroup as CoreFeatureGroup +from sagemaker.core.shapes import FeatureValue + +logger = logging.getLogger(__name__) + + +class IngestionError(Exception): + """Exception raised for errors during ingestion. + + Attributes: + failed_rows: List of row indices that failed to ingest. + message: Error message. + """ + + def __init__(self, failed_rows: List[int], message: str): + self.failed_rows = failed_rows + self.message = message + super().__init__(self.message) + + +@dataclass +class IngestionManagerPandas: + """Class to manage the multi-threaded data ingestion process. + + This class will manage the data ingestion process which is multi-threaded. + + Attributes: + feature_group_name (str): name of the Feature Group. + feature_definitions (Dict[str, Dict[Any, Any]]): dictionary of feature definitions + where the key is the feature name and the value is the FeatureDefinition. + The FeatureDefinition contains the data type of the feature. + max_workers (int): number of threads to create. + max_processes (int): number of processes to create. Each process spawns + ``max_workers`` threads. + """ + + feature_group_name: str + feature_definitions: Dict[str, Dict[Any, Any]] + max_workers: int = 1 + max_processes: int = 1 + _async_result: Any = field(default=None, init=False) + _processing_pool: Pool = field(default=None, init=False) + _failed_indices: List[int] = field(default_factory=list, init=False) + + @property + def failed_rows(self) -> List[int]: + """Get rows that failed to ingest. + + Returns: + List of row indices that failed to be ingested. + """ + return self._failed_indices + + def run( + self, + data_frame: DataFrame, + target_stores: List[str] = None, + wait: bool = True, + timeout: Union[int, float] = None, + ): + """Start the ingestion process. + + Args: + data_frame (DataFrame): source DataFrame to be ingested. + target_stores (List[str]): list of target stores ("OnlineStore", "OfflineStore"). + If None, the default target store is used. 
+ wait (bool): whether to wait for the ingestion to finish or not. + timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised + if timeout is reached. + """ + if self.max_workers == 1 and self.max_processes == 1: + self._run_single_process_single_thread(data_frame=data_frame, target_stores=target_stores) + else: + self._run_multi_process(data_frame=data_frame, target_stores=target_stores, wait=wait, timeout=timeout) + + def wait(self, timeout: Union[int, float] = None): + """Wait for the ingestion process to finish. + + Args: + timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised + if timeout is reached. + """ + try: + results = self._async_result.get(timeout=timeout) + except KeyboardInterrupt as e: + self._processing_pool.terminate() + self._processing_pool.join() + raise e + else: + self._processing_pool.close() + self._processing_pool.join() + + self._failed_indices = [idx for failed in results for idx in failed] + + if self._failed_indices: + raise IngestionError( + self._failed_indices, + f"Failed to ingest some data into FeatureGroup {self.feature_group_name}", + ) + + def _run_single_process_single_thread( + self, + data_frame: DataFrame, + target_stores: List[str] = None, + ): + """Ingest utilizing a single process and a single thread.""" + logger.info("Started single-threaded ingestion for %d rows", len(data_frame)) + failed_rows = [] + + fg = CoreFeatureGroup(feature_group_name=self.feature_group_name) + + for row in data_frame.itertuples(): + self._ingest_row( + data_frame=data_frame, + row=row, + feature_group=fg, + feature_definitions=self.feature_definitions, + failed_rows=failed_rows, + target_stores=target_stores, + ) + + self._failed_indices = failed_rows + if self._failed_indices: + raise IngestionError( + self._failed_indices, + f"Failed to ingest some data into FeatureGroup {self.feature_group_name}", + ) + + def _run_multi_process( + self, + data_frame: DataFrame, + target_stores: List[str] = None, + wait: bool = True, + timeout: Union[int, float] = None, + ): + """Start the ingestion process with the specified number of processes.""" + batch_size = math.ceil(data_frame.shape[0] / self.max_processes) + + args = [] + for i in range(self.max_processes): + start_index = min(i * batch_size, data_frame.shape[0]) + end_index = min(i * batch_size + batch_size, data_frame.shape[0]) + args.append(( + self.max_workers, + self.feature_group_name, + self.feature_definitions, + data_frame[start_index:end_index], + target_stores, + start_index, + timeout, + )) + + def init_worker(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + + self._processing_pool = Pool(self.max_processes, init_worker) + + self._async_result = self._processing_pool.starmap_async( + IngestionManagerPandas._run_multi_threaded, + args, + ) + + if wait: + self.wait(timeout=timeout) + + @staticmethod + def _run_multi_threaded( + max_workers: int, + feature_group_name: str, + feature_definitions: Dict[str, Dict[Any, Any]], + data_frame: DataFrame, + target_stores: List[str] = None, + row_offset: int = 0, + timeout: Union[int, float] = None, + ) -> List[int]: + """Start multi-threaded ingestion within a single process.""" + executor = ThreadPoolExecutor(max_workers=max_workers) + batch_size = math.ceil(data_frame.shape[0] / max_workers) + + futures = {} + for i in range(max_workers): + start_index = min(i * batch_size, data_frame.shape[0]) + end_index = min(i * batch_size + batch_size, data_frame.shape[0]) + future = executor.submit( + 
IngestionManagerPandas._ingest_single_batch, + data_frame=data_frame, + feature_group_name=feature_group_name, + feature_definitions=feature_definitions, + start_index=start_index, + end_index=end_index, + target_stores=target_stores, + ) + futures[future] = (start_index + row_offset, end_index + row_offset) + + failed_indices = [] + for future in as_completed(futures, timeout=timeout): + start, end = futures[future] + failed_rows = future.result() + if not failed_rows: + logger.info("Successfully ingested row %d to %d", start, end) + failed_indices.extend(failed_rows) + + executor.shutdown(wait=False) + return failed_indices + + @staticmethod + def _ingest_single_batch( + data_frame: DataFrame, + feature_group_name: str, + feature_definitions: Dict[str, Dict[Any, Any]], + start_index: int, + end_index: int, + target_stores: List[str] = None, + ) -> List[int]: + """Ingest a single batch of DataFrame rows into FeatureStore.""" + logger.info("Started ingesting index %d to %d", start_index, end_index) + failed_rows = [] + + fg = CoreFeatureGroup(feature_group_name=feature_group_name) + + for row in data_frame[start_index:end_index].itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=data_frame, + row=row, + feature_group=fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=target_stores, + ) + + return failed_rows + + @staticmethod + def _ingest_row( + data_frame: DataFrame, + row: Iterable, + feature_group: CoreFeatureGroup, + feature_definitions: Dict[str, Dict[Any, Any]], + failed_rows: List[int], + target_stores: List[str] = None, + ): + """Ingest a single DataFrame row into FeatureStore using SageMaker Core.""" + try: + record = [] + for index in range(1, len(row)): + feature_name = data_frame.columns[index - 1] + feature_value = row[index] + + if not IngestionManagerPandas._feature_value_is_not_none(feature_value): + continue + + if IngestionManagerPandas._is_feature_collection_type(feature_name, feature_definitions): + record.append(FeatureValue( + feature_name=feature_name, + value_as_string_list=IngestionManagerPandas._convert_to_string_list(feature_value), + )) + else: + record.append(FeatureValue( + feature_name=feature_name, + value_as_string=str(feature_value), + )) + + # Use SageMaker Core's put_record directly + feature_group.put_record( + record=record, + target_stores=target_stores, + ) + + except Exception as e: + logger.error("Failed to ingest row %d: %s", row[0], e) + failed_rows.append(row[0]) + + @staticmethod + def _is_feature_collection_type( + feature_name: str, + feature_definitions: Dict[str, Dict[Any, Any]], + ) -> bool: + """Check if the feature is a collection type.""" + feature_def = feature_definitions.get(feature_name) + if feature_def: + return feature_def.get("CollectionType") is not None + return False + + @staticmethod + def _feature_value_is_not_none(feature_value: Any) -> bool: + """Check if the feature value is not None. + + For Collection Type features, we check if the value is not None. + For Scalar values, we use pd.notna() to keep the behavior same. 
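+
+ Example (behaviour implied by the checks above):
+
+ >>> IngestionManagerPandas._feature_value_is_not_none(float("nan"))
+ False
+ >>> IngestionManagerPandas._feature_value_is_not_none(["a", "b"])
+ True
+ >>> IngestionManagerPandas._feature_value_is_not_none(None)
+ False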
+ """ + if not is_list_like(feature_value): + return pd.notna(feature_value) + return feature_value is not None + + @staticmethod + def _convert_to_string_list(feature_value: List[Any]) -> List[str]: + """Convert a list of feature values to a list of strings.""" + if not is_list_like(feature_value): + raise ValueError( + f"Invalid feature value: {feature_value} for a collection type feature " + f"must be an Array, but was {type(feature_value)}" + ) + return [str(v) if v is not None else None for v in feature_value] diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/inputs.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/inputs.py new file mode 100644 index 0000000000..f264059eb3 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/inputs.py @@ -0,0 +1,60 @@ +"""Enums for FeatureStore operations.""" +from enum import Enum + +class TargetStoreEnum(Enum): + """Store types for put_record.""" + ONLINE_STORE = "OnlineStore" + OFFLINE_STORE = "OfflineStore" + +class OnlineStoreStorageTypeEnum(Enum): + """Storage types for online store.""" + STANDARD = "Standard" + IN_MEMORY = "InMemory" + +class TableFormatEnum(Enum): + """Offline store table formats.""" + GLUE = "Glue" + ICEBERG = "Iceberg" + +class ResourceEnum(Enum): + """Resource types for search.""" + FEATURE_GROUP = "FeatureGroup" + FEATURE_METADATA = "FeatureMetadata" + +class SearchOperatorEnum(Enum): + """Search operators.""" + AND = "And" + OR = "Or" + +class SortOrderEnum(Enum): + """Sort orders.""" + ASCENDING = "Ascending" + DESCENDING = "Descending" + +class FilterOperatorEnum(Enum): + """Filter operators.""" + EQUALS = "Equals" + NOT_EQUALS = "NotEquals" + GREATER_THAN = "GreaterThan" + GREATER_THAN_OR_EQUAL_TO = "GreaterThanOrEqualTo" + LESS_THAN = "LessThan" + LESS_THAN_OR_EQUAL_TO = "LessThanOrEqualTo" + CONTAINS = "Contains" + EXISTS = "Exists" + NOT_EXISTS = "NotExists" + IN = "In" + +class DeletionModeEnum(Enum): + """Deletion modes for delete_record.""" + SOFT_DELETE = "SoftDelete" + HARD_DELETE = "HardDelete" + +class ExpirationTimeResponseEnum(Enum): + """ExpiresAt response toggle.""" + DISABLED = "Disabled" + ENABLED = "Enabled" + +class ThroughputModeEnum(Enum): + """Throughput modes for feature group.""" + ON_DEMAND = "OnDemand" + PROVISIONED = "Provisioned" \ No newline at end of file diff --git a/sagemaker-mlops/tests/__init__.py b/sagemaker-mlops/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/__init__.py b/sagemaker-mlops/tests/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/sagemaker/__init__.py b/sagemaker-mlops/tests/unit/sagemaker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/__init__.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/__init__.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/__init__.py new file mode 100644 index 0000000000..f34bf7d447 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/conftest.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/conftest.py new file mode 100644 index 0000000000..9b2ec55895 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/conftest.py @@ -0,0 +1,80 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Conftest for feature_store tests.""" +import pytest +from unittest.mock import Mock, MagicMock +import pandas as pd +import numpy as np + + +@pytest.fixture +def mock_session(): + """Create a mock Session.""" + session = Mock() + session.boto_session = Mock() + session.boto_region_name = "us-west-2" + session.sagemaker_client = Mock() + session.sagemaker_runtime_client = Mock() + session.sagemaker_featurestore_runtime_client = Mock() + return session + + +@pytest.fixture +def sample_dataframe(): + """Create a sample DataFrame for testing.""" + return pd.DataFrame({ + "id": pd.Series([1, 2, 3, 4, 5], dtype="int64"), + "value": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5], dtype="float64"), + "name": pd.Series(["a", "b", "c", "d", "e"], dtype="string"), + "event_time": pd.Series( + ["2024-01-01T00:00:00Z"] * 5, + dtype="string" + ), + }) + + +@pytest.fixture +def dataframe_with_collections(): + """Create a DataFrame with collection type columns.""" + return pd.DataFrame({ + "id": pd.Series([1, 2, 3], dtype="int64"), + "tags": pd.Series([["a", "b"], ["c"], ["d", "e", "f"]], dtype="object"), + "scores": pd.Series([[1.0, 2.0], [3.0], [4.0, 5.0]], dtype="object"), + "event_time": pd.Series(["2024-01-01"] * 3, dtype="string"), + }) + + +@pytest.fixture +def feature_definitions_dict(): + """Create a feature definitions dictionary.""" + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "value": {"FeatureName": "value", "FeatureType": "Fractional"}, + "name": {"FeatureName": "name", "FeatureType": "String"}, + "event_time": {"FeatureName": "event_time", "FeatureType": "String"}, + } + + +@pytest.fixture +def mock_feature_group(): + """Create a mock FeatureGroup from core.""" + fg = MagicMock() + fg.feature_group_name = "test-feature-group" + fg.record_identifier_feature_name = "id" + fg.event_time_feature_name = "event_time" + fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + MagicMock(feature_name="value", feature_type="Fractional"), + MagicMock(feature_name="name", feature_type="String"), + MagicMock(feature_name="event_time", feature_type="String"), + ] + fg.offline_store_config = MagicMock() + fg.offline_store_config.s3_storage_config.s3_uri = "s3://bucket/prefix" + fg.offline_store_config.s3_storage_config.resolved_output_s3_uri = "s3://bucket/prefix/resolved" + fg.offline_store_config.data_catalog_config.catalog = "AwsDataCatalog" + fg.offline_store_config.data_catalog_config.database = "sagemaker_featurestore" + fg.offline_store_config.data_catalog_config.table_name = "test_feature_group" + fg.offline_store_config.data_catalog_config.disable_glue_table_creation = False + fg.online_store_config = MagicMock() + fg.online_store_config.enable_online_store = True + return fg diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py new file mode 100644 index 0000000000..12a323f871 --- /dev/null +++ 
b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py @@ -0,0 +1,401 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains constants of feature processor to be used for unit tests.""" +from __future__ import absolute_import + +import datetime +from typing import List, Sequence, Union + +from botocore.exceptions import ClientError +from mock import Mock +from pyspark.sql import DataFrame + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import ( + FeatureGroupContexts, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import PipelineTrigger +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.core.lineage._api_types import ContextSource +from sagemaker.core.lineage.artifact import Artifact, ArtifactSource, ArtifactSummary +from sagemaker.core.lineage.context import Context + +PIPELINE_NAME = "test-pipeline-01" +PIPELINE_ARN = "arn:aws:sagemaker:us-west-2:12345789012:pipeline/test-pipeline-01" +CREATION_TIME = "123123123" +LAST_UPDATE_TIME = "234234234" +SAGEMAKER_SESSION_MOCK = Mock(Session) +CONTEXT_MOCK_01 = Mock(Context) +CONTEXT_MOCK_02 = Mock(Context) + + +class MockDataSource(BaseDataSource): + + data_source_unique_id = "test_source_unique_id" + data_source_name = "test_source_name" + + def read_data(self, spark, params) -> DataFrame: + return None + + +FEATURE_GROUP_DATA_SOURCE: List[FeatureGroupDataSource] = [ + FeatureGroupDataSource( + name="feature-group-01", + ), + FeatureGroupDataSource( + name="feature-group-02", + ), +] + +FEATURE_GROUP_INPUT: List[FeatureGroupContexts] = [ + FeatureGroupContexts( + name="feature-group-01", + pipeline_context_arn="feature-group-01-pipeline-context-arn", + pipeline_version_context_arn="feature-group-01-pipeline-version-context-arn", + ), + FeatureGroupContexts( + name="feature-group-02", + pipeline_context_arn="feature-group-02-pipeline-context-arn", + pipeline_version_context_arn="feature-group-02-pipeline-version-context-arn", + ), +] + +RAW_DATA_INPUT: Sequence[Union[CSVDataSource, ParquetDataSource, BaseDataSource]] = [ + CSVDataSource(s3_uri="raw-data-uri-01"), + CSVDataSource(s3_uri="raw-data-uri-02"), + ParquetDataSource(s3_uri="raw-data-uri-03"), + MockDataSource(), +] + +RAW_DATA_INPUT_ARTIFACTS: List[Artifact] = [ + Artifact(artifact_arn="artifact-01-arn"), + Artifact(artifact_arn="artifact-02-arn"), + Artifact(artifact_arn="artifact-03-arn"), + Artifact(artifact_arn="artifact-04-arn"), +] + +PIPELINE_SCHEDULE = PipelineSchedule( + schedule_name="schedule-name", + schedule_arn="schedule-arn", 
+ schedule_expression="schedule-expression", + pipeline_name="pipeline-name", + state="state", + start_date="123123123", +) + +PIPELINE_SCHEDULE_2 = PipelineSchedule( + schedule_name="schedule-name-2", + schedule_arn="schedule-arn", + schedule_expression="schedule-expression-2", + pipeline_name="pipeline-name", + state="state-2", + start_date="234234234", +) + +PIPELINE_TRIGGER = PipelineTrigger( + trigger_name="trigger-name", + trigger_arn="trigger-arn", + pipeline_name="pipeline-name", + event_pattern="event-pattern", + state="Enabled", +) + +PIPELINE_TRIGGER_2 = PipelineTrigger( + trigger_name="trigger-name-2", + trigger_arn="trigger-arn", + pipeline_name="pipeline-name", + event_pattern="event-pattern-2", + state="Enabled", +) + +PIPELINE_TRIGGER_ARTIFACT: Artifact = Artifact( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-trigger-trigger-name", + artifact_type="PipelineTrigger", + source={"source_uri": "trigger-arn"}, + properties=dict( + pipeline_name=PIPELINE_TRIGGER.pipeline_name, + event_pattern=PIPELINE_TRIGGER.event_pattern, + state=PIPELINE_TRIGGER.state, + ), +) + +PIPELINE_TRIGGER_ARTIFACT_SUMMARY: ArtifactSummary = ArtifactSummary( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-trigger-trigger-name", + source=ArtifactSource( + source_uri="trigger-arn", + ), + artifact_type="PipelineTrigger", + creation_time=datetime.datetime(2023, 4, 27, 21, 4, 17, 926000), +) + +ARTIFACT_RESULT: Artifact = Artifact( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-raw-data", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz" + }, + artifact_type="DataSet", + creation_time=datetime.datetime(2023, 4, 28, 21, 53, 47, 912000), +) + +SCHEDULE_ARTIFACT_RESULT: Artifact = Artifact( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-raw-data", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz" + }, + properties=dict( + pipeline_name=PIPELINE_SCHEDULE.pipeline_name, + schedule_expression=PIPELINE_SCHEDULE.schedule_expression, + state=PIPELINE_SCHEDULE.state, + start_date=PIPELINE_SCHEDULE.start_date, + ), + artifact_type="DataSet", + creation_time=datetime.datetime(2023, 4, 28, 21, 53, 47, 912000), +) + +ARTIFACT_SUMMARY: ArtifactSummary = ArtifactSummary( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-raw-data", + source=ArtifactSource( + source_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz", + source_types=[], + ), + artifact_type="DataSet", + creation_time=datetime.datetime(2023, 4, 27, 21, 4, 17, 926000), +) + +TRANSFORMATION_CODE_ARTIFACT_1 = Artifact( + artifact_arn="ts-artifact-01-arn", + artifact_name="sm-fs-fe-transformation-code", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz", + "source_types": [{"source_id_type": "Custom", "value": "1684369626"}], + }, + properties={ + "name": 
"test-name", + "author": "test-author", + "inclusive_start_date": "1684369626", + "state": "Active", + }, +) + +TRANSFORMATION_CODE_ARTIFACT_2 = Artifact( + artifact_arn="ts-artifact-02-arn", + artifact_name="sm-fs-fe-transformation-code", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz/2", + "source_types": [{"source_id_type": "Custom", "value": "1684369626"}], + }, + properties={ + "name": "test-name", + "author": "test-author", + "inclusive_start_date": "1684369626", + "state": "Active", + }, +) + +INACTIVE_TRANSFORMATION_CODE_ARTIFACT_1 = Artifact( + artifact_arn="ts-artifact-02-arn", + artifact_name="sm-fs-fe-transformation-code", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz/2", + "source_types": [{"source_id_type": "Custom", "value": "1684369307"}], + }, + Properties={ + "name": "test-name", + "author": "test-author", + "exclusive_end_date": "1684369626", + "inclusive_start_date": "1684369307", + "state": "Inactive", + }, +) + +VALIDATION_EXCEPTION = ClientError( + {"Error": {"Code": "ValidationException", "Message": "AssociationAlreadyExists"}}, + "Operation", +) + +RESOURCE_NOT_FOUND_EXCEPTION = ClientError( + {"Error": {"Code": "ResourceNotFound", "Message": "ResourceDoesNotExists"}}, + "Operation", +) + +NON_VALIDATION_EXCEPTION = ClientError( + {"Error": {"Code": "NonValidationException", "Message": "NonValidationError"}}, + "Operation", +) + +FEATURE_GROUP_NAME = "feature-group-name-01" +FEATURE_GROUP = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:789975069016:feature-group/feature-group-name-01", + "FeatureGroupName": "feature-group-name-01", + "RecordIdentifierFeatureName": "model_year_status", + "EventTimeFeatureName": "ingest_time", + "FeatureDefinitions": [ + {"FeatureName": "model_year_status", "FeatureType": "String"}, + {"FeatureName": "avg_mileage", "FeatureType": "String"}, + {"FeatureName": "max_mileage", "FeatureType": "String"}, + {"FeatureName": "avg_price", "FeatureType": "String"}, + {"FeatureName": "max_price", "FeatureType": "String"}, + {"FeatureName": "avg_msrp", "FeatureType": "String"}, + {"FeatureName": "max_msrp", "FeatureType": "String"}, + {"FeatureName": "ingest_time", "FeatureType": "Fractional"}, + ], + "CreationTime": datetime.datetime(2023, 4, 27, 21, 4, 17, 926000), + "OnlineStoreConfig": {"EnableOnlineStore": True}, + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://sagemaker-us-west-2-789975069016/" + "feature-store/feature-processor/" + "suryans-v2/offline-store", + "ResolvedOutputS3Uri": "s3://sagemaker-us-west-2-" + "789975069016/feature-store/" + "feature-processor/suryans-v2/" + "offline-store/789975069016/" + "sagemaker/us-west-2/" + "offline-store/" + "feature-group-name-01-" + "1682629457/data", + }, + "DisableGlueTableCreation": False, + "DataCatalogConfig": { + "TableName": "feature-group-name-01_1682629457", + "Catalog": "AwsDataCatalog", + "Database": "sagemaker_featurestore", + }, + }, + "RoleArn": "arn:aws:iam::789975069016:role/service-role/AmazonSageMaker-ExecutionRole-20230421T100744", + "FeatureGroupStatus": "Created", + "OnlineStoreTotalSizeBytes": 0, + "ResponseMetadata": { + "RequestId": "8f139791-345d-4388-8d6d-40420495a3c4", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "8f139791-345d-4388-8d6d-40420495a3c4", + "content-type": "application/x-amz-json-1.1", + 
"content-length": "1608", + "date": "Mon, 01 May 2023 21:42:59 GMT", + }, + "RetryAttempts": 0, + }, +} + +PIPELINE = { + "PipelineArn": "arn:aws:sagemaker:us-west-2:597217924798:pipeline/test-pipeline-26", + "PipelineName": "test-pipeline-26", + "PipelineDisplayName": "test-pipeline-26", + "PipelineDefinition": '{"Version": "2020-12-01", "Metadata": {}, ' + '"Parameters": [{"Name": "scheduled-time", "Type": "String"}], ' + '"PipelineExperimentConfig": {"ExperimentName": {"Get": "Execution.PipelineName"}, ' + '"TrialName": {"Get": "Execution.PipelineExecutionId"}}, ' + '"Steps": [{"Name": "test-pipeline-26-training-step", "Type": ' + '"Training", "Arguments": {"AlgorithmSpecification": {"TrainingInputMode": ' + '"File", "TrainingImage": "153931337802.dkr.ecr.us-west-2.amazonaws.com/' + 'sagemaker-spark-processing:3.2-cpu-py39-v1.1", "ContainerEntrypoint": ' + '["/bin/bash", "/opt/ml/input/data/sagemaker_remote_function_bootstrap/' + 'job_driver.sh", "--files", "s3://bugbash-schema-update/temp.sh", ' + '"/opt/ml/input/data/sagemaker_remote_function_bootstrap/spark_app.py"], ' + '"ContainerArguments": ["--s3_base_uri", ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26", ' + '"--region", "us-west-2", "--client_python_version", "3.9"]}, ' + '"OutputDataConfig": {"S3OutputPath": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26"}, ' + '"StoppingCondition": {"MaxRuntimeInSeconds": 86400}, "ResourceConfig": ' + '{"VolumeSizeInGB": 30, "InstanceCount": 1, "InstanceType": "ml.m5.xlarge"}, ' + '"RoleArn": "arn:aws:iam::597217924798:role/Admin", "InputDataConfig": ' + '[{"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26/' + 'sagemaker_remote_function_bootstrap", "S3DataDistributionType": ' + '"FullyReplicated"}}, "ChannelName": "sagemaker_remote_function_bootstrap"}, ' + '{"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": ' + '"s3://bugbash-schema-update/sagemaker-2.142.1.dev0-py2.py3-none-any.whl", ' + '"S3DataDistributionType": "FullyReplicated"}}, "ChannelName": ' + '"sagemaker_wheel_file"}], "Environment": {"AWS_DEFAULT_REGION": "us-west-2"}, ' + '"DebugHookConfig": {"S3OutputPath": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26", ' + '"CollectionConfigurations": []},' + ' "ProfilerConfig": {"S3OutputPath": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26", ' + '"DisableProfiler": false}, "RetryStrategy": {"MaximumRetryAttempts": 1}}}]}', + "RoleArn": "arn:aws:iam::597217924798:role/Admin", + "PipelineStatus": "Active", + "CreationTime": datetime.datetime(2023, 4, 27, 9, 46, 35, 686000), + "LastModifiedTime": datetime.datetime(2023, 4, 27, 20, 27, 36, 648000), + "CreatedBy": {}, + "LastModifiedBy": {}, + "ResponseMetadata": { + "RequestId": "2075bc1c-1b34-4fe5-b7d8-7cfdf784a7d9", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "2075bc1c-1b34-4fe5-b7d8-7cfdf784a7d9", + "content-type": "application/x-amz-json-1.1", + "content-length": "2555", + "date": "Thu, 04 May 2023 00:28:35 GMT", + }, + "RetryAttempts": 0, + }, +} + +PIPELINE_CONTEXT: Context = Context( + context_arn=f"{PIPELINE_NAME}-context-arn", + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{CREATION_TIME}-fep", + context_type="FeatureEngineeringPipeline", + source=ContextSource(source_uri=PIPELINE_ARN, source_types=[]), + properties={ + "PipelineName": PIPELINE_NAME, + "PipelineCreationTime": CREATION_TIME, + "LastUpdateTime": LAST_UPDATE_TIME, + }, +) + +PIPELINE_VERSION_CONTEXT: Context = Context( 
+ context_arn=f"{PIPELINE_NAME}-version-context-arn", + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{LAST_UPDATE_TIME}-fep-ver", + context_type=f"FeatureEngineeringPipelineVersion-{PIPELINE_NAME}", + source=ContextSource(source_uri=PIPELINE_ARN, source_types=LAST_UPDATE_TIME), + properties={"PipelineName": PIPELINE_NAME, "LastUpdateTime": LAST_UPDATE_TIME}, +) + +TRANSFORMATION_CODE_INPUT_1: TransformationCode = TransformationCode( + s3_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz", + author="test-author", + name="test-name", +) + +TRANSFORMATION_CODE_INPUT_2: TransformationCode = TransformationCode( + s3_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz/2", + author="test-author", + name="test-name", +) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py new file mode 100644 index 0000000000..bc725570b9 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py @@ -0,0 +1,62 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +from mock import patch, call + +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_lineage_entity_handler import ( + FeatureGroupLineageEntityHandler, +) +from sagemaker.core.lineage.context import Context + +from test_constants import ( + SAGEMAKER_SESSION_MOCK, + CONTEXT_MOCK_01, + CONTEXT_MOCK_02, + FEATURE_GROUP, + FEATURE_GROUP_NAME, +) + + +def test_retrieve_feature_group_context_arns(): + with patch.object( + SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP + ) as fg_describe_method: + with patch.object( + Context, "load", side_effect=[CONTEXT_MOCK_01, CONTEXT_MOCK_02] + ) as context_load: + type(CONTEXT_MOCK_01).context_arn = "context-arn-fep" + type(CONTEXT_MOCK_02).context_arn = "context-arn-fep-ver" + result = FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns( + feature_group_name=FEATURE_GROUP_NAME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result.name == FEATURE_GROUP_NAME + assert result.pipeline_context_arn == "context-arn-fep" + assert result.pipeline_version_context_arn == "context-arn-fep-ver" + fg_describe_method.assert_called_once_with(FEATURE_GROUP_NAME) + context_load.assert_has_calls( + [ + call( + context_name=f'{FEATURE_GROUP_NAME}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + f"-feature-group-pipeline", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + context_name=f'{FEATURE_GROUP_NAME}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + f"-feature-group-pipeline-version", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == context_load.call_count diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_processor_lineage.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_processor_lineage.py new file mode 100644 index 0000000000..fe5783210a --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_processor_lineage.py @@ -0,0 +1,2966 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
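+"""Unit tests for FeatureProcessorLineageHandler.create_lineage across new and existing lineage scenarios."""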
+from __future__ import absolute_import + +import copy +import datetime +from typing import Iterator, List + +import pytest +from mock import call, patch, Mock + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import ( + EventBridgeSchedulerHelper, +) +from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import ( + EventBridgeRuleHelper, +) +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + TRANSFORMATION_CODE_STATUS_INACTIVE, +) +from sagemaker.core.lineage.context import Context +from sagemaker.core.lineage.artifact import Artifact +from test_constants import ( + FEATURE_GROUP_DATA_SOURCE, + FEATURE_GROUP_INPUT, + LAST_UPDATE_TIME, + PIPELINE, + PIPELINE_ARN, + PIPELINE_CONTEXT, + PIPELINE_NAME, + PIPELINE_VERSION_CONTEXT, + RAW_DATA_INPUT, + RAW_DATA_INPUT_ARTIFACTS, + RESOURCE_NOT_FOUND_EXCEPTION, + SAGEMAKER_SESSION_MOCK, + SCHEDULE_ARTIFACT_RESULT, + PIPELINE_TRIGGER_ARTIFACT, + TRANSFORMATION_CODE_ARTIFACT_1, + TRANSFORMATION_CODE_ARTIFACT_2, + TRANSFORMATION_CODE_INPUT_1, + TRANSFORMATION_CODE_INPUT_2, + ARTIFACT_SUMMARY, + ARTIFACT_RESULT, +) + +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_lineage_entity_handler import ( + FeatureGroupLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage import ( + FeatureProcessorLineageHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._lineage_association_handler import ( + LineageAssociationHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_lineage_entity_handler import ( + PipelineLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import ( + PipelineTrigger, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_version_lineage_entity_handler import ( + PipelineVersionLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._s3_lineage_entity_handler import ( + S3LineageEntityHandler, +) +from sagemaker.core.lineage._api_types import AssociationSummary + +SCHEDULE_ARN = "" +SCHEDULE_EXPRESSION = "" +STATE = "" +TRIGGER_ARN = "" +EVENT_PATTERN = "" +START_DATE = datetime.datetime(2023, 4, 28, 21, 53, 47, 912000) +TAGS = [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + + +@pytest.fixture +def sagemaker_session(): + boto_session = Mock() + boto_session.client("scheduler").return_value = Mock() + return Mock(Session, boto_session=boto_session) + + +@pytest.fixture +def event_bridge_scheduler_helper(sagemaker_session): + return EventBridgeSchedulerHelper( + sagemaker_session, sagemaker_session.boto_session.client("scheduler") + ) + + +def test_create_lineage_when_no_lineage_exists_with_fg_only(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as 
retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + ): + lineage_handler.create_lineage() + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_not_called() + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + 
list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=[], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_lineage_when_no_lineage_exists_with_raw_data_only(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as 
list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_called_once_with( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=[], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + 
add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_no_lineage_exists_with_fg_and_raw_data_with_tags(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + 
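+    # The assertions below verify that feature group inputs, raw data inputs, the transformation
+    # code artifact, and tags are all wired into the newly created lineage.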
retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_no_lineage_exists_with_no_transformation_code(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + 
FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=None, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == 
retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=None, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_not_called() + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_not_called() + + +def test_create_lineage_when_already_exist_with_no_version_change(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + 
generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as create_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + 
load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_pipeline_version_context_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_raw_data(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=[RAW_DATA_INPUT[0], RAW_DATA_INPUT[1]] + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1]], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as 
load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 2 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + 
sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=[RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1]], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert pipeline_context.properties["LastUpdateTime"] == PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_pipeline_context_method.assert_called_once_with(pipeline_context=pipeline_context) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_input_fg(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + 
pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + [FEATURE_GROUP_DATA_SOURCE[0]], + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[FEATURE_GROUP_INPUT[0], FEATURE_GROUP_INPUT[0]], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, 
+ ), + ] + ) + assert 2 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=[FEATURE_GROUP_INPUT[0]], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + 
sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert pipeline_context.properties["LastUpdateTime"] == PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_pipeline_context_method.assert_called_once_with(pipeline_context=pipeline_context) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_output_fg(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[1].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[1], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + 
patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = 
copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[1], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert pipeline_context.properties["LastUpdateTime"] == PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_pipeline_context_method.assert_called_once_with(pipeline_context=pipeline_context) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_transformation_code(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as 
load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_2, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), 
+ ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + assert pipeline_context.properties["LastUpdateTime"] == LAST_UPDATE_TIME + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_2, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_last_transformation_code_as_none(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_1.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_1.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + 
patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_2, + 
pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + update_transformation_code_artifact_method.assert_not_called() + + assert pipeline_context.properties["LastUpdateTime"] == LAST_UPDATE_TIME + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_2, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_all_previous_transformation_code_as_none(): + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + 
return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + iter([]), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_2, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + 
[ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_not_called() + update_transformation_code_artifact_method.assert_not_called() + + assert pipeline_context.properties["LastUpdateTime"] == LAST_UPDATE_TIME + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_2, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_removed_transformation_code(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=None, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + 
return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=None, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + 
sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_upstream_transformation_code_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + artifact_set_tags.assert_not_called() + + +def test_get_pipeline_lineage_names_when_no_lineage_exists(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method: + return_value = lineage_handler.get_pipeline_lineage_names() + + assert return_value is None + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_get_pipeline_lineage_names_when_lineage_exists(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + 
return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + ): + return_value = lineage_handler.get_pipeline_lineage_names() + + assert return_value == dict( + pipeline_context_name=PIPELINE_CONTEXT.context_name, + pipeline_version_context_name=PIPELINE_VERSION_CONTEXT.context_name, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_schedule_lineage(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + S3LineageEntityHandler, + "retrieve_pipeline_schedule_artifact", + return_value=SCHEDULE_ARTIFACT_RESULT, + ) as retrieve_pipeline_schedule_artifact_method, + patch.object( + LineageAssociationHandler, + "add_upstream_schedule_associations", + ) as add_upstream_schedule_associations_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_schedule_lineage( + pipeline_name=PIPELINE_NAME, + schedule_arn=SCHEDULE_ARN, + schedule_expression=SCHEDULE_EXPRESSION, + state=STATE, + start_date=START_DATE, + tags=TAGS, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + retrieve_pipeline_schedule_artifact_method.assert_called_once_with( + pipeline_schedule=PipelineSchedule( + schedule_name=PIPELINE_NAME, + schedule_arn=SCHEDULE_ARN, + schedule_expression=SCHEDULE_EXPRESSION, + pipeline_name=PIPELINE_NAME, + state=STATE, + start_date=START_DATE.strftime("%s"), + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_schedule_associations_method.assert_called_once_with( + schedule_artifact=SCHEDULE_ARTIFACT_RESULT, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_trigger_lineage(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + 
return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + S3LineageEntityHandler, + "retrieve_pipeline_trigger_artifact", + return_value=PIPELINE_TRIGGER_ARTIFACT, + ) as retrieve_pipeline_trigger_artifact_method, + patch.object( + LineageAssociationHandler, + "_add_association", + ) as add_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_trigger_lineage( + pipeline_name=PIPELINE_NAME, + trigger_arn=TRIGGER_ARN, + event_pattern=EVENT_PATTERN, + state=STATE, + tags=TAGS, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + retrieve_pipeline_trigger_artifact_method.assert_called_once_with( + pipeline_trigger=PipelineTrigger( + trigger_name=PIPELINE_NAME, + trigger_arn=TRIGGER_ARN, + pipeline_name=PIPELINE_NAME, + event_pattern=EVENT_PATTERN, + state=STATE, + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_association_method.assert_called_once_with( + source_arn=PIPELINE_TRIGGER_ARTIFACT.artifact_arn, + destination_arn=PIPELINE_VERSION_CONTEXT.context_arn, + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_upsert_tags_for_lineage_resources(): + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + mock_session = Mock(Session) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=mock_session, + ) + lineage_handler.sagemaker_session.boto_session = Mock() + lineage_handler.sagemaker_session.sagemaker_client = Mock() + with ( + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + iter([]), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, "load_artifact_from_arn", 
return_value=ARTIFACT_RESULT + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, "_load_artifact_from_s3_uri", return_value=ARTIFACT_SUMMARY + ) as load_artifact_from_s3_uri_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + patch.object( + Context, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as context_set_tags, + patch.object( + EventBridgeSchedulerHelper, "describe_schedule", return_value=dict(Arn="schedule_arn") + ) as get_event_bridge_schedule, + patch.object( + EventBridgeRuleHelper, "describe_rule", return_value=dict(Arn="rule_arn") + ) as get_event_bridge_rule, + ): + lineage_handler.upsert_tags_for_lineage_resources(TAGS) + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=mock_session), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=mock_session), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=mock_session), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=mock_session), + ] + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=mock_session, + ) + + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=mock_session, + ) + + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + load_artifact_from_s3_uri_method.assert_has_calls( + [ + call(s3_uri="schedule_arn", sagemaker_session=mock_session), + call(s3_uri="rule_arn", sagemaker_session=mock_session), + ] + ) + get_event_bridge_schedule.assert_called_once_with(PIPELINE_NAME) + get_event_bridge_rule.assert_called_once_with(PIPELINE_NAME) + load_artifact_from_arn_method.assert_called_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, sagemaker_session=mock_session + ) + + # three raw data artifacts, one schedule artifact and one trigger artifact + artifact_set_tags.assert_has_calls( + [ + call(TAGS), + call(TAGS), + call(TAGS), + call(TAGS), + call(TAGS), + ] + ) + # pipeline context and current pipeline version context + context_set_tags.assert_has_calls( + [ + call(TAGS), + call(TAGS), + ] + ) + + +def generate_pipeline_version_upstream_feature_group_list() -> Iterator[AssociationSummary]: + pipeline_version_upstream_fg: List[AssociationSummary] = list() + for feature_group in FEATURE_GROUP_INPUT: + pipeline_version_upstream_fg.append( + AssociationSummary( + source_arn=feature_group.pipeline_version_context_arn, + source_name=f"{feature_group.name}-pipeline-version", + destination_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_name=PIPELINE_VERSION_CONTEXT.context_name, + association_type="ContributedTo", + ) + ) + return iter(pipeline_version_upstream_fg) + + +def generate_pipeline_version_upstream_raw_data_list() -> Iterator[AssociationSummary]: + pipeline_version_upstream_fg: List[AssociationSummary] = list() + for raw_data in RAW_DATA_INPUT_ARTIFACTS: + pipeline_version_upstream_fg.append( + AssociationSummary( + source_arn=raw_data.artifact_arn, + source_name="sm-fs-fe-raw-data", + destination_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_name=PIPELINE_VERSION_CONTEXT.context_name, + association_type="ContributedTo", + ) + ) + 
return iter(pipeline_version_upstream_fg) + + +def generate_pipeline_version_upstream_transformation_code() -> Iterator[AssociationSummary]: + return iter( + [ + AssociationSummary( + source_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + source_name=TRANSFORMATION_CODE_ARTIFACT_1.artifact_name, + destination_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_name=PIPELINE_VERSION_CONTEXT.context_name, + association_type="ContributedTo", + ) + ] + ) + + +def generate_pipeline_version_downstream_feature_group() -> Iterator[AssociationSummary]: + return iter( + [ + AssociationSummary( + source_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_name=PIPELINE_VERSION_CONTEXT.context_name, + destination_arn=FEATURE_GROUP_INPUT[0].pipeline_version_context_arn, + destination_name=f"{FEATURE_GROUP_INPUT[0].name}-pipeline-version", + association_type="ContributedTo", + ) + ] + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_lineage_association_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_lineage_association_handler.py new file mode 100644 index 0000000000..a3d24dd0b5 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_lineage_association_handler.py @@ -0,0 +1,224 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
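+# Unit tests for LineageAssociationHandler. Association.create and Association.list are
+# patched so each helper can be checked for the exact source/destination ARNs and
+# association types it passes, and for its error handling: validation errors from
+# already-existing associations are tolerated, while other ClientErrors propagate.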
+from __future__ import absolute_import +from mock import patch, call +import pytest + +from sagemaker.mlops.feature_store.feature_processor.lineage._lineage_association_handler import ( + LineageAssociationHandler, +) +from sagemaker.core.lineage.association import Association +from botocore.exceptions import ClientError + +from test_constants import ( + FEATURE_GROUP_INPUT, + RAW_DATA_INPUT_ARTIFACTS, + VALIDATION_EXCEPTION, + NON_VALIDATION_EXCEPTION, + SAGEMAKER_SESSION_MOCK, + TRANSFORMATION_CODE_ARTIFACT_1, +) + + +def test_add_upstream_feature_group_data_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + for feature_group in FEATURE_GROUP_INPUT: + create_association_method.assert_has_calls( + [ + call( + source_arn=feature_group.pipeline_context_arn, + destination_arn="pipeline-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn=feature_group.pipeline_version_context_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert len(FEATURE_GROUP_INPUT) * 2 == create_association_method.call_count + + +def test_add_upstream_feature_group_data_associations_when_association_already_exists(): + with patch.object( + Association, "create", side_effect=VALIDATION_EXCEPTION + ) as create_association_method: + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + for feature_group in FEATURE_GROUP_INPUT: + create_association_method.assert_has_calls( + [ + call( + source_arn=feature_group.pipeline_context_arn, + destination_arn="pipeline-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn=feature_group.pipeline_version_context_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert len(FEATURE_GROUP_INPUT) * 2 == create_association_method.call_count + + +def test_add_upstream_feature_group_data_associations_when_non_validation_exception(): + with patch.object(Association, "create", side_effect=NON_VALIDATION_EXCEPTION): + with pytest.raises(ClientError): + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_add_upstream_raw_data_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_upstream_raw_data_associations( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + for raw_data in RAW_DATA_INPUT_ARTIFACTS: + create_association_method.assert_has_calls( + [ + call( + 
source_arn=raw_data.artifact_arn, + destination_arn="pipeline-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn=raw_data.artifact_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert len(RAW_DATA_INPUT_ARTIFACTS) * 2 == create_association_method.call_count + + +def test_add_upstream_transformation_code_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_upstream_transformation_code_associations( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_association_method.assert_called_once_with( + source_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_add_downstream_feature_group_data_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_downstream_feature_group_data_associations( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_association_method.assert_has_calls( + [ + call( + source_arn="pipeline-context-arn", + destination_arn=FEATURE_GROUP_INPUT[0].pipeline_context_arn, + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn="pipeline-version-context-arn", + destination_arn=FEATURE_GROUP_INPUT[0].pipeline_version_context_arn, + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == create_association_method.call_count + + +def test_add_pipeline_and_pipeline_version_association(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_pipeline_and_pipeline_version_association( + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_association_method.assert_called_once_with( + source_arn="pipeline-context-arn", + destination_arn="pipeline-version-context-arn", + association_type="AssociatedWith", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_list_upstream_associations(): + with patch.object(Association, "list") as list_association_method: + LineageAssociationHandler.list_upstream_associations( + entity_arn="pipeline-context-arn", + source_type="FeatureEngineeringPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_association_method.assert_called_once_with( + source_arn=None, + source_type="FeatureEngineeringPipelineVersion", + destination_arn="pipeline-context-arn", + destination_type=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_list_downstream_associations(): + with patch.object(Association, "list") as list_association_method: + LineageAssociationHandler.list_downstream_associations( + entity_arn="pipeline-context-arn", + destination_type="FeatureEngineeringPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_association_method.assert_called_once_with( + source_arn="pipeline-context-arn", + 
source_type=None, + destination_arn=None, + destination_type="FeatureEngineeringPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_lineage_entity_handler.py new file mode 100644 index 0000000000..deb76a7748 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_lineage_entity_handler.py @@ -0,0 +1,74 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock import patch + +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_lineage_entity_handler import ( + PipelineLineageEntityHandler, +) +from sagemaker.core.lineage.context import Context +from test_constants import ( + PIPELINE_NAME, + PIPELINE_ARN, + CREATION_TIME, + LAST_UPDATE_TIME, + SAGEMAKER_SESSION_MOCK, + CONTEXT_MOCK_01, +) + + +def test_create_pipeline_context(): + with patch.object(Context, "create", return_value=CONTEXT_MOCK_01) as create_method: + result = PipelineLineageEntityHandler.create_pipeline_context( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + creation_time=CREATION_TIME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + create_method.assert_called_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{CREATION_TIME}-fep", + context_type="FeatureEngineeringPipeline", + source_uri=PIPELINE_ARN, + source_type=CREATION_TIME, + properties={ + "PipelineName": PIPELINE_NAME, + "PipelineCreationTime": CREATION_TIME, + "LastUpdateTime": LAST_UPDATE_TIME, + }, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_load_pipeline_context(): + with patch.object(Context, "load", return_value=CONTEXT_MOCK_01) as load_method: + result = PipelineLineageEntityHandler.load_pipeline_context( + pipeline_name=PIPELINE_NAME, + creation_time=CREATION_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + load_method.assert_called_once_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{CREATION_TIME}-fep", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_update_pipeline_context(): + with patch.object(Context, "save", return_value=CONTEXT_MOCK_01): + PipelineLineageEntityHandler.update_pipeline_context(pipeline_context=CONTEXT_MOCK_01) + CONTEXT_MOCK_01.save.assert_called_once() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_trigger.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_trigger.py new file mode 100644 index 0000000000..c936c3c164 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_trigger.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, 
Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import PipelineTrigger + + +def test_pipeline_trigger(): + + trigger = PipelineTrigger( + trigger_name="test_trigger", + trigger_arn="test_arn", + event_pattern="test_pattern", + pipeline_name="test_pipeline", + state="Enabled", + ) + + assert trigger.trigger_name == "test_trigger" + assert trigger.trigger_arn == "test_arn" + assert trigger.event_pattern == "test_pattern" + assert trigger.pipeline_name == "test_pipeline" + assert trigger.state == "Enabled" diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_version_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_version_lineage_entity_handler.py new file mode 100644 index 0000000000..52e65749be --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_version_lineage_entity_handler.py @@ -0,0 +1,67 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License.
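+# Unit tests for PipelineVersionLineageEntityHandler. Context.create and Context.load are
+# patched to verify the pipeline-version context name, context type, source fields, and
+# properties derived from the pipeline name and its last-update time.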
+from __future__ import absolute_import + +from mock import patch + +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_version_lineage_entity_handler import ( + PipelineVersionLineageEntityHandler, +) + +from sagemaker.core.lineage.context import Context + +from test_constants import ( + PIPELINE_NAME, + PIPELINE_ARN, + LAST_UPDATE_TIME, + SAGEMAKER_SESSION_MOCK, + CONTEXT_MOCK_01, +) + + +def test_create_pipeline_version_context(): + with patch.object(Context, "create", return_value=CONTEXT_MOCK_01) as create_method: + result = PipelineVersionLineageEntityHandler.create_pipeline_version_context( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + create_method.assert_called_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{LAST_UPDATE_TIME}-fep-ver", + context_type=f"FeatureEngineeringPipelineVersion-{PIPELINE_NAME}", + source_uri=PIPELINE_ARN, + source_type=LAST_UPDATE_TIME, + properties={ + "PipelineName": PIPELINE_NAME, + "LastUpdateTime": LAST_UPDATE_TIME, + }, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_load_pipeline_version_context(): + with patch.object(Context, "load", return_value=CONTEXT_MOCK_01) as load_method: + result = PipelineVersionLineageEntityHandler.load_pipeline_version_context( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + load_method.assert_called_once_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{LAST_UPDATE_TIME}-fep-ver", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py new file mode 100644 index 0000000000..8b34806f1d --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py @@ -0,0 +1,434 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
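+# Unit tests for S3LineageEntityHandler. Artifact.list, load, create, and save are patched
+# to verify that raw data, transformation code, pipeline schedule, and pipeline trigger
+# artifacts are loaded when they already exist, created when they do not, and saved with
+# refreshed properties when the schedule or trigger has changed.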
+from __future__ import absolute_import + +import copy + +from mock import patch +from test_constants import ( + ARTIFACT_RESULT, + ARTIFACT_SUMMARY, + PIPELINE_SCHEDULE, + PIPELINE_SCHEDULE_2, + PIPELINE_TRIGGER, + PIPELINE_TRIGGER_2, + PIPELINE_TRIGGER_ARTIFACT, + PIPELINE_TRIGGER_ARTIFACT_SUMMARY, + SCHEDULE_ARTIFACT_RESULT, + TRANSFORMATION_CODE_ARTIFACT_1, + TRANSFORMATION_CODE_INPUT_1, + LAST_UPDATE_TIME, + MockDataSource, +) +from test_pipeline_lineage_entity_handler import SAGEMAKER_SESSION_MOCK + +from sagemaker.mlops.feature_store.feature_processor import CSVDataSource +from sagemaker.mlops.feature_store.feature_processor.lineage._s3_lineage_entity_handler import ( + S3LineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.core.lineage.artifact import Artifact + +raw_data = CSVDataSource( + s3_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz" +) + + +def test_retrieve_raw_data_artifact_when_artifact_already_exist(): + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=raw_data, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=raw_data.s3_uri, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_create_method.assert_not_called() + + +def test_retrieve_raw_data_artifact_when_artifact_does_not_exist(): + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=raw_data, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=raw_data.s3_uri, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=raw_data.s3_uri, + artifact_type="DataSet", + artifact_name="sm-fs-fe-raw-data", + properties=None, + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_retrieve_user_defined_raw_data_artifact_when_artifact_already_exist(): + data_source = MockDataSource() + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=data_source, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + 
artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_create_method.assert_not_called() + + +def test_retrieve_user_defined_raw_data_artifact_when_artifact_does_not_exist(): + data_source = MockDataSource() + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=data_source, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, + artifact_type="DataSet", + artifact_name=data_source.data_source_name, + properties=None, + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_transformation_code_artifact(): + with patch.object( + Artifact, "create", return_value=TRANSFORMATION_CODE_ARTIFACT_1 + ) as artifact_create_method: + + result = S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == TRANSFORMATION_CODE_ARTIFACT_1 + + artifact_create_method.assert_called_once_with( + source_uri=TRANSFORMATION_CODE_INPUT_1.s3_uri, + source_types=[dict(SourceIdType="Custom", Value=LAST_UPDATE_TIME)], + artifact_type="TransformationCode", + artifact_name=f"sm-fs-fe-transformation-code-{LAST_UPDATE_TIME}", + properties=dict( + name=TRANSFORMATION_CODE_INPUT_1.name, + author=TRANSFORMATION_CODE_INPUT_1.author, + state="Active", + inclusive_start_date=LAST_UPDATE_TIME, + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_transformation_code_artifact_when_no_author_or_name(): + transformation_code_input = TransformationCode(s3_uri=TRANSFORMATION_CODE_INPUT_1.s3_uri) + with patch.object( + Artifact, "create", return_value=TRANSFORMATION_CODE_ARTIFACT_1 + ) as artifact_create_method: + + result = S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=transformation_code_input, + pipeline_last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == TRANSFORMATION_CODE_ARTIFACT_1 + + artifact_create_method.assert_called_once_with( + source_uri=TRANSFORMATION_CODE_INPUT_1.s3_uri, + source_types=[dict(SourceIdType="Custom", Value=LAST_UPDATE_TIME)], + artifact_type="TransformationCode", + artifact_name=f"sm-fs-fe-transformation-code-{LAST_UPDATE_TIME}", + properties=dict( + state="Active", + inclusive_start_date=LAST_UPDATE_TIME, + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_transformation_code_artifact_when_no_code_provided(): + with patch.object( + Artifact, "create", return_value=TRANSFORMATION_CODE_ARTIFACT_1 + ) as artifact_create_method: + + result = S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=None, + pipeline_last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result is None + + 
artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_schedule_artifact_when_artifact_does_not_exist(): + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=PIPELINE_SCHEDULE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + artifact_type="PipelineSchedule", + artifact_name=f"sm-fs-fe-{PIPELINE_SCHEDULE.schedule_name}", + properties=dict( + pipeline_name=PIPELINE_SCHEDULE.pipeline_name, + schedule_expression=PIPELINE_SCHEDULE.schedule_expression, + state=PIPELINE_SCHEDULE.state, + start_date=PIPELINE_SCHEDULE.start_date, + ), + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_retrieve_pipeline_schedule_artifact_when_artifact_exists(): + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=SCHEDULE_ARTIFACT_RESULT + ) as artifact_load_method: + with patch.object(SCHEDULE_ARTIFACT_RESULT, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=SCHEDULE_ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=PIPELINE_SCHEDULE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == SCHEDULE_ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_schedule_artifact_when_artifact_updated(): + schedule_artifact_result = copy.deepcopy(SCHEDULE_ARTIFACT_RESULT) + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=schedule_artifact_result + ) as artifact_load_method: + with patch.object(schedule_artifact_result, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=schedule_artifact_result + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=PIPELINE_SCHEDULE_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == schedule_artifact_result + assert schedule_artifact_result != SCHEDULE_ARTIFACT_RESULT + assert result.properties["pipeline_name"] == PIPELINE_SCHEDULE_2.pipeline_name + assert result.properties["schedule_expression"] == PIPELINE_SCHEDULE_2.schedule_expression + assert result.properties["state"] == PIPELINE_SCHEDULE_2.state + assert result.properties["start_date"] == PIPELINE_SCHEDULE_2.start_date + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + 
sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_trigger_artifact_when_artifact_does_not_exist(): + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=PIPELINE_TRIGGER, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == PIPELINE_TRIGGER_ARTIFACT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + artifact_type="PipelineTrigger", + artifact_name=f"sm-fs-fe-trigger-{PIPELINE_TRIGGER.trigger_name}", + properties=dict( + pipeline_name=PIPELINE_TRIGGER.pipeline_name, + event_pattern=PIPELINE_TRIGGER.event_pattern, + state=PIPELINE_TRIGGER.state, + ), + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_retrieve_pipeline_trigger_artifact_when_artifact_exists(): + with patch.object( + Artifact, "list", return_value=[PIPELINE_TRIGGER_ARTIFACT_SUMMARY] + ) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_load_method: + with patch.object(PIPELINE_TRIGGER_ARTIFACT, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=PIPELINE_TRIGGER, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == PIPELINE_TRIGGER_ARTIFACT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=PIPELINE_TRIGGER_ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_trigger_artifact_when_artifact_updated(): + trigger_artifact_result = copy.deepcopy(PIPELINE_TRIGGER_ARTIFACT) + with patch.object( + Artifact, "list", return_value=[PIPELINE_TRIGGER_ARTIFACT_SUMMARY] + ) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=trigger_artifact_result + ) as artifact_load_method: + with patch.object(trigger_artifact_result, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=trigger_artifact_result + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=PIPELINE_TRIGGER_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == trigger_artifact_result + assert trigger_artifact_result != PIPELINE_TRIGGER_ARTIFACT + assert result.properties["pipeline_name"] == PIPELINE_TRIGGER_2.pipeline_name + assert result.properties["event_pattern"] == 
PIPELINE_TRIGGER_2.event_pattern + assert result.properties["state"] == PIPELINE_TRIGGER_2.state + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=PIPELINE_TRIGGER_ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_config_uploader.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_config_uploader.py new file mode 100644 index 0000000000..25ded72266 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_config_uploader.py @@ -0,0 +1,317 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from mock import Mock, patch +from sagemaker.mlops.feature_store.feature_processor._config_uploader import ( + ConfigUploader, +) +from sagemaker.mlops.feature_store.feature_processor._constants import ( + SPARK_JAR_FILES_PATH, + SPARK_FILES_PATH, + SPARK_PY_FILES_PATH, + SAGEMAKER_WHL_FILE_S3_PATH, +) +from sagemaker.core.remote_function.job import ( + _JobSettings, + RUNTIME_SCRIPTS_CHANNEL_NAME, + REMOTE_FUNCTION_WORKSPACE, + SPARK_CONF_CHANNEL_NAME, +) +from sagemaker.core.remote_function.spark_config import SparkConfig +from sagemaker.core.helper.session_helper import Session + + +@pytest.fixture +def sagemaker_session(): + return Mock(Session) + + +@pytest.fixture +def wrapped_func(): + return Mock() + + +@pytest.fixture +def runtime_env_manager(): + mocked_runtime_env_manager = Mock() + mocked_runtime_env_manager.snapshot.return_value = "some_dependency_path" + return mocked_runtime_env_manager + + +def custom_file_filter(): + pass + + +@pytest.fixture +def remote_decorator_config(sagemaker_session): + return Mock( + _JobSettings, + sagemaker_session=sagemaker_session, + s3_root_uri="some_s3_uri", + s3_kms_key="some_kms", + spark_config=SparkConfig(), + dependencies=None, + include_local_workdir=True, + workdir_config=None, + pre_execution_commands="some_commands", + pre_execution_script="some_path", + python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH, + custom_file_filter=None, + ) + + +@pytest.fixture +def config_uploader(remote_decorator_config, runtime_env_manager): + return ConfigUploader(remote_decorator_config, runtime_env_manager) + + +@pytest.fixture +def remote_decorator_config_with_filter(sagemaker_session): + return Mock( + _JobSettings, + sagemaker_session=sagemaker_session, + s3_root_uri="some_s3_uri", + s3_kms_key="some_kms", + spark_config=SparkConfig(), + dependencies=None, + include_local_workdir=True, + pre_execution_commands="some_commands", + pre_execution_script="some_path", + python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH, + custom_file_filter=custom_file_filter, + 
) + + +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.StoredFunction") +def test_prepare_and_upload_callable(mock_stored_function, config_uploader, wrapped_func): + mock_stored_function.save(wrapped_func).return_value = None + config_uploader._prepare_and_upload_callable(wrapped_func, "s3_base_uri", sagemaker_session) + assert mock_stored_function.called_once_with( + s3_base_uri="s3_base_uri", + s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_workspace", + return_value="some_s3_uri", +) +def test_prepare_and_upload_workspace(mock_upload, config_uploader): + remote_decorator_config = config_uploader.remote_decorator_config + s3_path = config_uploader._prepare_and_upload_workspace( + local_dependencies_path="some/path/to/dependency", + include_local_workdir=True, + pre_execution_commands=remote_decorator_config.pre_execution_commands, + pre_execution_script_local_path=remote_decorator_config.pre_execution_script, + s3_base_uri=remote_decorator_config.s3_root_uri, + s3_kms_key=remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + assert s3_path == mock_upload.return_value + mock_upload.assert_called_once_with( + local_dependencies_path="some/path/to/dependency", + include_local_workdir=True, + pre_execution_commands=remote_decorator_config.pre_execution_commands, + pre_execution_script_local_path=remote_decorator_config.pre_execution_script, + s3_base_uri=remote_decorator_config.s3_root_uri, + s3_kms_key=remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + custom_file_filter=None, + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_workspace", + return_value="some_s3_uri", +) +def test_prepare_and_upload_workspace_with_filter( + mock_job_upload, remote_decorator_config_with_filter, runtime_env_manager +): + config_uploader_with_filter = ConfigUploader( + remote_decorator_config=remote_decorator_config_with_filter, + runtime_env_manager=runtime_env_manager, + ) + remote_decorator_config = config_uploader_with_filter.remote_decorator_config + config_uploader_with_filter._prepare_and_upload_workspace( + local_dependencies_path="some/path/to/dependency", + include_local_workdir=True, + pre_execution_commands=remote_decorator_config.pre_execution_commands, + pre_execution_script_local_path=remote_decorator_config.pre_execution_script, + s3_base_uri=remote_decorator_config.s3_root_uri, + s3_kms_key=remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + custom_file_filter=remote_decorator_config_with_filter.custom_file_filter, + ) + + mock_job_upload.assert_called_once_with( + local_dependencies_path="some/path/to/dependency", + include_local_workdir=True, + pre_execution_commands=remote_decorator_config.pre_execution_commands, + pre_execution_script_local_path=remote_decorator_config.pre_execution_script, + s3_base_uri=remote_decorator_config.s3_root_uri, + s3_kms_key=remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + custom_file_filter=custom_file_filter, + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_runtime_scripts", + return_value="some_s3_uri", +) +def test_prepare_and_upload_runtime_scripts(mock_upload, config_uploader): + s3_path = config_uploader._prepare_and_upload_runtime_scripts( + 
spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + assert s3_path == mock_upload.return_value + mock_upload.assert_called_once_with( + spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_spark_dependent_files", + return_value=("path_a", "path_b", "path_c", "path_d"), +) +def test_prepare_and_upload_spark_dependent_files(mock_upload, config_uploader): + s3_paths = config_uploader._prepare_and_upload_spark_dependent_files( + spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + assert s3_paths == mock_upload.return_value + mock_upload.assert_called_once_with( + spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + + +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.Channel") +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.DataSource") +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.S3DataSource") +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_spark_dependent_files", + return_value=("path_a", "path_b", "path_c", "path_d"), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_workspace", + return_value="some_s3_uri", +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_runtime_scripts", + return_value="some_s3_uri", +) +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.StoredFunction") +def test_prepare_step_input_channel( + mock_upload_callable, + mock_script_upload, + mock_dependency_upload, + mock_spark_dependency_upload, + mock_s3_data_source, + mock_data_source, + mock_channel, + config_uploader, + wrapped_func, +): + ( + input_data_config, + spark_dependency_paths, + ) = config_uploader.prepare_step_input_channel_for_spark_mode( + wrapped_func, + config_uploader.remote_decorator_config.s3_root_uri, + sagemaker_session, + ) + remote_decorator_config = config_uploader.remote_decorator_config + + assert mock_upload_callable.called_once_with(wrapped_func) + + mock_script_upload.assert_called_once_with( + spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key="some_kms", + sagemaker_session=sagemaker_session, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path="some_dependency_path", + include_local_workdir=True, + pre_execution_commands=remote_decorator_config.pre_execution_commands, + pre_execution_script_local_path=remote_decorator_config.pre_execution_script, + s3_base_uri=remote_decorator_config.s3_root_uri, + s3_kms_key="some_kms", + sagemaker_session=sagemaker_session, + 
custom_file_filter=None, + ) + + mock_spark_dependency_upload.assert_called_once_with( + spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key="some_kms", + sagemaker_session=sagemaker_session, + ) + + # Verify input_data_config is a list of Channel objects + assert isinstance(input_data_config, list) + # 3 channels: runtime scripts, workspace, spark conf + assert len(input_data_config) == 3 + + # Verify each channel was constructed with the correct data + # Channel 1: runtime scripts + mock_s3_data_source.assert_any_call( + s3_uri="some_s3_uri", + s3_data_type="S3Prefix", + s3_data_distribution_type="FullyReplicated", + ) + # Channel 2: workspace + mock_s3_data_source.assert_any_call( + s3_uri=f"{config_uploader.remote_decorator_config.s3_root_uri}/sm_rf_user_ws", + s3_data_type="S3Prefix", + s3_data_distribution_type="FullyReplicated", + ) + # Channel 3: spark conf + mock_s3_data_source.assert_any_call( + s3_uri="path_d", + s3_data_type="S3Prefix", + s3_data_distribution_type="FullyReplicated", + ) + + assert mock_s3_data_source.call_count == 3 + assert mock_data_source.call_count == 3 + assert mock_channel.call_count == 3 + + # Verify channel names and input_mode + channel_call_kwargs = [call.kwargs for call in mock_channel.call_args_list] + channel_names = [kw["channel_name"] for kw in channel_call_kwargs] + assert RUNTIME_SCRIPTS_CHANNEL_NAME in channel_names + assert REMOTE_FUNCTION_WORKSPACE in channel_names + assert SPARK_CONF_CHANNEL_NAME in channel_names + for kw in channel_call_kwargs: + assert kw["input_mode"] == "File" + + assert spark_dependency_paths == { + SPARK_JAR_FILES_PATH: "path_a", + SPARK_PY_FILES_PATH: "path_b", + SPARK_FILES_PATH: "path_c", + } diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_helpers.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_helpers.py new file mode 100644 index 0000000000..e69499c5b1 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_helpers.py @@ -0,0 +1,166 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
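+# Shared test data for the feature_processor unit tests: sample data sources, a
+# DescribeFeatureGroup response, a pipeline definition fixture, and the create_fp_config()
+# helper for building a FeatureProcessorConfig with sensible defaults.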
+from __future__ import absolute_import + +import datetime +import json + +from dateutil.tz import tzlocal +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + +INPUT_S3_URI = "s3://bucket/prefix/" +INPUT_FEATURE_GROUP_NAME = "input-fg" +INPUT_FEATURE_GROUP_ARN = "arn:aws:sagemaker:us-west-2:12345789012:feature-group/input-fg" +INPUT_FEATURE_GROUP_S3_URI = "s3://bucket/input-fg/" +INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI = ( + "s3://bucket/input-fg/feature-store/12345789012/" + "sagemaker/us-west-2/offline-store/input-fg-12345/data" +) + +FEATURE_GROUP_DATA_SOURCE = FeatureGroupDataSource(name=INPUT_FEATURE_GROUP_ARN) +S3_DATA_SOURCE = CSVDataSource(s3_uri=INPUT_S3_URI) +FEATURE_PROCESSOR_INPUTS = [FEATURE_GROUP_DATA_SOURCE, S3_DATA_SOURCE] +OUTPUT_FEATURE_GROUP_ARN = "arn:aws:sagemaker:us-west-2:12345789012:feature-group/output-fg" + +FEATURE_GROUP_SYSTEM_PARAMS = { + "feature_group_name": "input-fg", + "online_store_enabled": True, + "offline_store_enabled": False, + "offline_store_resolved_s3_uri": None, +} +SYSTEM_PARAMS = {"system": {"scheduled_time": "2023-03-25T02:01:26Z"}} +USER_INPUT_PARAMS = { + "some-key": "some-value", + "some-other-key": {"some-key": "some-value"}, +} + +DATA_SOURCE_UNIQUE_ID_TOO_LONG = """ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +""" + +DESCRIBE_FEATURE_GROUP_RESPONSE = { + "FeatureGroupArn": INPUT_FEATURE_GROUP_ARN, + "FeatureGroupName": INPUT_FEATURE_GROUP_NAME, + "RecordIdentifierFeatureName": "id", + "EventTimeFeatureName": "ingest_time", + "FeatureDefinitions": [ + {"FeatureName": "id", "FeatureType": "String"}, + {"FeatureName": "model", "FeatureType": "String"}, + {"FeatureName": "model_year", "FeatureType": "String"}, + {"FeatureName": "status", "FeatureType": "String"}, + {"FeatureName": "mileage", "FeatureType": "String"}, + {"FeatureName": "price", "FeatureType": "String"}, + {"FeatureName": "msrp", "FeatureType": "String"}, + {"FeatureName": "ingest_time", "FeatureType": "Fractional"}, + ], + "CreationTime": datetime.datetime(2023, 3, 29, 19, 15, 47, 20000, tzinfo=tzlocal()), + "OnlineStoreConfig": {"EnableOnlineStore": True}, + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": INPUT_FEATURE_GROUP_S3_URI, + "ResolvedOutputS3Uri": INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI, + }, + "DisableGlueTableCreation": False, + "DataCatalogConfig": { + "TableName": "input_fg_1680142547", + "Catalog": "AwsDataCatalog", + "Database": "sagemaker_featurestore", + }, + }, + "RoleArn": "arn:aws:iam::12345789012:role/role-name", + "FeatureGroupStatus": "Created", + "OnlineStoreTotalSizeBytes": 12345, + "ResponseMetadata": { + 
"RequestId": "d36d3647-1632-4f4e-9f7c-2a4e38e4c6f8", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "d36d3647-1632-4f4e-9f7c-2a4e38e4c6f8", + "content-type": "application/x-amz-json-1.1", + "content-length": "1311", + "date": "Fri, 31 Mar 2023 01:05:49 GMT", + }, + "RetryAttempts": 0, + }, +} + +PIPELINE = { + "PipelineArn": "some_pipeline_arn", + "RoleArn": "some_execution_role_arn", + "CreationTime": datetime.datetime(2023, 3, 29, 19, 15, 47, 20000, tzinfo=tzlocal()), + "PipelineDefinition": json.dumps( + { + "Steps": [ + { + "RetryPolicies": [ + { + "BackoffRate": 2.0, + "IntervalSeconds": 1, + "MaxAttempts": 5, + "ExceptionType": ["Step.SERVICE_FAULT", "Step.THROTTLING"], + }, + { + "BackoffRate": 2.0, + "IntervalSeconds": 1, + "MaxAttempts": 5, + "ExceptionType": [ + "SageMaker.JOB_INTERNAL_ERROR", + "SageMaker.CAPACITY_ERROR", + "SageMaker.RESOURCE_LIMIT", + ], + }, + ] + } + ] + } + ), +} + + +def create_fp_config( + inputs=None, + output=OUTPUT_FEATURE_GROUP_ARN, + mode=FeatureProcessorMode.PYSPARK, + target_stores=None, + enable_ingestion=True, + parameters=None, + spark_config=None, +): + """Helper method to create a FeatureProcessorConfig with fewer arguments.""" + + return FeatureProcessorConfig.create( + inputs=inputs or FEATURE_PROCESSOR_INPUTS, + output=output, + mode=mode, + target_stores=target_stores, + enable_ingestion=enable_ingestion, + parameters=parameters, + spark_config=spark_config, + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_source.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_source.py new file mode 100644 index 0000000000..51637fa979 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_source.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from pyspark.sql import DataFrame + +from sagemaker.mlops.feature_store.feature_processor._data_source import PySparkDataSource + + +def test_pyspark_data_source(): + class TestDataSource(PySparkDataSource): + + data_source_unique_id = "test_unique_id" + data_source_name = "test_source_name" + + def read_data(self, spark, params) -> DataFrame: + return None + + test_data_source = TestDataSource() + + assert test_data_source.data_source_name == "test_source_name" + assert test_data_source.data_source_unique_id == "test_unique_id" + assert test_data_source.read_data(spark=None, params=None) is None diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_env.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_env.py new file mode 100644 index 0000000000..4aff330087 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_env.py @@ -0,0 +1,122 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json + +from mock import mock_open, patch +import pytest +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper + +SINGLE_NODE_RESOURCE_CONFIG = { + "current_host": "algo-1", + "current_instance_type": "ml.m5.xlarge", + "current_group_name": "homogeneousCluster", + "hosts": ["algo-1"], + "instance_groups": [ + { + "instance_group_name": "homogeneousCluster", + "instance_type": "ml.m5.xlarge", + "hosts": ["algo-1"], + } + ], + "network_interface_name": "eth0", +} +MULTI_NODE_COUNT = 3 +MULTI_NODE_RESOURCE_CONFIG = { + "current_host": "algo-1", + "current_instance_type": "ml.m5.xlarge", + "current_group_name": "homogeneousCluster", + "hosts": ["algo-1", "algo-2", "algo-3"], + "instance_groups": [ + { + "instance_group_name": "homogeneousCluster", + "instance_type": "ml.m5.xlarge", + "hosts": ["algo-1"], + }, + { + "instance_group_name": "homogeneousCluster", + "instance_type": "ml.m5.xlarge", + "hosts": ["algo-2"], + }, + { + "instance_group_name": "homogeneousCluster", + "instance_type": "ml.m5.xlarge", + "hosts": ["algo-3"], + }, + ], + "network_interface_name": "eth0", +} + + +@patch("builtins.open") +def test_is_training_job(mocked_open): + mocked_open.side_effect = mock_open(read_data=json.dumps(SINGLE_NODE_RESOURCE_CONFIG)) + + assert EnvironmentHelper().is_training_job() is True + + mocked_open.assert_called_once_with("/opt/ml/input/config/resourceconfig.json", "r") + + +@patch("builtins.open") +def test_is_not_training_job(mocked_open): + mocked_open.side_effect = FileNotFoundError() + + assert EnvironmentHelper().is_training_job() is False + + +@patch("builtins.open") +def test_get_instance_count_single_node(mocked_open): + mocked_open.side_effect = mock_open(read_data=json.dumps(SINGLE_NODE_RESOURCE_CONFIG)) + + assert EnvironmentHelper().get_instance_count() == 1 + + +@patch("builtins.open") +def test_get_instance_count_multi_node(mocked_open): + mocked_open.side_effect = mock_open(read_data=json.dumps(MULTI_NODE_RESOURCE_CONFIG)) + + assert EnvironmentHelper().get_instance_count() == MULTI_NODE_COUNT + + +@patch("builtins.open") +def test_load_training_resource_config(mocked_open): + mocked_open.side_effect = mock_open(read_data=json.dumps(SINGLE_NODE_RESOURCE_CONFIG)) + + assert EnvironmentHelper().load_training_resource_config() == SINGLE_NODE_RESOURCE_CONFIG + + +@patch("builtins.open") +def test_load_training_resource_config_none(mocked_open): + mocked_open.side_effect = FileNotFoundError() + + assert EnvironmentHelper().load_training_resource_config() is None + + +@pytest.mark.parametrize( + "is_training_result", + [(True), (False)], +) +@patch("datetime.now.strftime", return_value="test_current_time") +@patch("sagemaker.mlops.feature_store.feature_processor._env.EnvironmentHelper.is_training_job") +@patch("os.environ", return_value={"scheduled_time": "test_time"}) +def get_job_scheduled_time(mock_env, mock_is_training, 
mock_datetime, is_training_result): + + mock_is_training.return_value = is_training_result + output_time = EnvironmentHelper().get_job_scheduled_time + + if is_training_result: + assert output_time == "test_scheduled_time" + else: + assert output_time == "test_current_time" diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_rule_helper.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_rule_helper.py new file mode 100644 index 0000000000..f44472d519 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_rule_helper.py @@ -0,0 +1,301 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock import Mock +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import ( + EventBridgeRuleHelper, +) +from botocore.exceptions import ClientError +from sagemaker.mlops.feature_store.feature_processor._feature_processor_pipeline_events import ( + FeatureProcessorPipelineEvents, + FeatureProcessorPipelineExecutionStatus, +) +import pytest + + +@pytest.fixture +def sagemaker_session(): + boto_session = Mock() + boto_session.client("events").return_value = Mock() + return Mock(Session, boto_session=boto_session, sagemaker_client=Mock()) + + +@pytest.fixture +def event_bridge_rule_helper(sagemaker_session): + return EventBridgeRuleHelper(sagemaker_session, sagemaker_session.boto_session.client("events")) + + +def test_put_rule_without_event_pattern(event_bridge_rule_helper): + source_pipeline_events = [ + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED], + ) + ] + + event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock( + return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name") + ) + event_bridge_rule_helper.event_bridge_rule_client.put_rule = Mock( + return_value=dict(RuleArn="rule_arn") + ) + event_bridge_rule_helper.put_rule( + source_pipeline_events=source_pipeline_events, + target_pipeline="target_pipeline", + event_pattern=None, + state="Disabled", + ) + + event_bridge_rule_helper.event_bridge_rule_client.put_rule.assert_called_with( + Name="target_pipeline", + EventPattern=( + '{"detail-type": ["SageMaker Model Building Pipeline Execution Status Change"], ' + '"source": ["aws.sagemaker"], "detail": {"currentPipelineExecutionStatus": ' + '["Succeeded"], "pipelineArn": ["pipeline_arn"]}}' + ), + State="Disabled", + ) + + +def test_put_rule_with_event_pattern(event_bridge_rule_helper): + source_pipeline_events = [ + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED], + ) + ] + + 
event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock( + return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name") + ) + event_bridge_rule_helper.event_bridge_rule_client.put_rule = Mock( + return_value=dict(RuleArn="rule_arn") + ) + event_bridge_rule_helper.put_rule( + source_pipeline_events=source_pipeline_events, + target_pipeline="target_pipeline", + event_pattern="event_pattern", + state="Disabled", + ) + + event_bridge_rule_helper.event_bridge_rule_client.put_rule.assert_called_with( + Name="target_pipeline", + EventPattern="event_pattern", + State="Disabled", + ) + + +def test_put_targets_success(event_bridge_rule_helper): + event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock( + return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name") + ) + event_bridge_rule_helper.event_bridge_rule_client.put_targets = Mock( + return_value=dict(FailedEntryCount=0) + ) + event_bridge_rule_helper.put_target( + rule_name="rule_name", + target_pipeline="target_pipeline", + target_pipeline_parameters={"param": "value"}, + role_arn="role_arn", + ) + + event_bridge_rule_helper.event_bridge_rule_client.put_targets.assert_called_with( + Rule="rule_name", + Targets=[ + { + "Id": "pipeline_name", + "Arn": "pipeline_arn", + "RoleArn": "role_arn", + "SageMakerPipelineParameters": {"PipelineParameterList": {"param": "value"}}, + } + ], + ) + + +def test_put_targets_failure(event_bridge_rule_helper): + event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock( + return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name") + ) + event_bridge_rule_helper.event_bridge_rule_client.put_targets = Mock( + return_value=dict( + FailedEntryCount=1, + FailedEntries=[dict(ErrorMessage="test_error_message")], + ) + ) + with pytest.raises( + Exception, match="Failed to add target pipeline to rule. 
Failure reason: test_error_message" + ): + event_bridge_rule_helper.put_target( + rule_name="rule_name", + target_pipeline="target_pipeline", + target_pipeline_parameters={"param": "value"}, + role_arn="role_arn", + ) + + +def test_delete_rule(event_bridge_rule_helper): + event_bridge_rule_helper.event_bridge_rule_client.delete_rule = Mock() + event_bridge_rule_helper.delete_rule("rule_name") + + event_bridge_rule_helper.event_bridge_rule_client.delete_rule.assert_called_with( + Name="rule_name" + ) + + +def test_describe_rule_success(event_bridge_rule_helper): + mock_describe_response = dict(State="ENABLED", RuleName="rule_name") + event_bridge_rule_helper.event_bridge_rule_client.describe_rule = Mock( + return_value=mock_describe_response + ) + assert event_bridge_rule_helper.describe_rule("rule_name") == mock_describe_response + + +def test_describe_rule_non_existent(event_bridge_rule_helper): + mock_describe_response = dict(State="ENABLED", RuleName="rule_name") + event_bridge_rule_helper.event_bridge_rule_client.describe_rule = Mock( + return_value=mock_describe_response, + side_effect=ClientError( + error_response={"Error": {"Code": "ResourceNotFoundException"}}, + operation_name="describe_rule", + ), + ) + assert event_bridge_rule_helper.describe_rule("rule_name") is None + + +def test_remove_targets(event_bridge_rule_helper): + event_bridge_rule_helper.event_bridge_rule_client.remove_targets = Mock() + event_bridge_rule_helper.remove_targets(rule_name="rule_name", ids=["target_pipeline"]) + event_bridge_rule_helper.event_bridge_rule_client.remove_targets.assert_called_with( + Rule="rule_name", + Ids=["target_pipeline"], + ) + + +def test_enable_rule(event_bridge_rule_helper): + event_bridge_rule_helper.event_bridge_rule_client.enable_rule = Mock() + event_bridge_rule_helper.enable_rule("rule_name") + + event_bridge_rule_helper.event_bridge_rule_client.enable_rule.assert_called_with( + Name="rule_name" + ) + + +def test_disable_rule(event_bridge_rule_helper): + event_bridge_rule_helper.event_bridge_rule_client.disable_rule = Mock() + event_bridge_rule_helper.disable_rule("rule_name") + + event_bridge_rule_helper.event_bridge_rule_client.disable_rule.assert_called_with( + Name="rule_name" + ) + + +def test_add_tags(event_bridge_rule_helper): + event_bridge_rule_helper.event_bridge_rule_client.tag_resource = Mock() + event_bridge_rule_helper.add_tags("rule_arn", [{"key": "value"}]) + + event_bridge_rule_helper.event_bridge_rule_client.tag_resource.assert_called_with( + ResourceARN="rule_arn", Tags=[{"key": "value"}] + ) + + +def test_generate_event_pattern_from_feature_processor_pipeline_events(event_bridge_rule_helper): + event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock( + return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name") + ) + event_pattern = ( + event_bridge_rule_helper._generate_event_pattern_from_feature_processor_pipeline_events( + [ + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_1", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED], + ), + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_2", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED], + ), + ] + ) + ) + + assert ( + event_pattern + == '{"detail-type": ["SageMaker Model Building Pipeline Execution Status Change"], ' + '"$or": [{"source": ["aws.sagemaker"], "detail": {"currentPipelineExecutionStatus": ' + '["Failed"], "pipelineArn": ["pipeline_arn"]}}, {"source": ["aws.sagemaker"], "detail": ' 
+ '{"currentPipelineExecutionStatus": ["Failed"], "pipelineArn": ["pipeline_arn"]}}]}' + ) + + +def test_validate_feature_processor_pipeline_events(event_bridge_rule_helper): + events = [ + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_1", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED], + ), + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_1", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED], + ), + ] + + with pytest.raises(ValueError, match="Pipeline names in pipeline_events must be unique."): + event_bridge_rule_helper._validate_feature_processor_pipeline_events(events) + + +def test_aggregate_pipeline_events_with_same_desired_status(event_bridge_rule_helper): + events = [ + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_1", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED], + ), + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_2", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED], + ), + ] + + assert event_bridge_rule_helper._aggregate_pipeline_events_with_same_desired_status(events) == { + (FeatureProcessorPipelineExecutionStatus.FAILED,): [ + "test_pipeline_1", + "test_pipeline_2", + ] + } + + +@pytest.mark.parametrize( + "pipeline_uri,expected_result", + [ + ( + "arn:aws:sagemaker:us-west-2:123456789012:pipeline/test-pipeline", + dict( + pipeline_arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test-pipeline", + pipeline_name="test-pipeline", + ), + ), + ( + "test-pipeline", + dict( + pipeline_arn="test-pipeline-arn", + pipeline_name="test-pipeline", + ), + ), + ], +) +def test_generate_pipeline_arn_and_name(event_bridge_rule_helper, pipeline_uri, expected_result): + event_bridge_rule_helper.sagemaker_session.sagemaker_client.describe_pipeline = Mock( + return_value=dict(PipelineArn="test-pipeline-arn") + ) + assert event_bridge_rule_helper._generate_pipeline_arn_and_name(pipeline_uri) == expected_result diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_scheduler_helper.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_scheduler_helper.py new file mode 100644 index 0000000000..baca05e84d --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_scheduler_helper.py @@ -0,0 +1,96 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from datetime import datetime +import pytest +from botocore.exceptions import ClientError + +from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import ( + EventBridgeSchedulerHelper, +) +from mock import Mock + +from sagemaker.core.helper.session_helper import Session + +SCHEDULE_NAME = "test_schedule" +SCHEDULE_ARN = "test_schedule_arn" +NEW_SCHEDULE_ARN = "test_new_schedule_arn" +TARGET_ARN = "test_arn" +CRON_SCHEDULE = "test_cron" +STATE = "ENABLED" +ROLE = "test_role" +START_DATE = datetime.now() + + +@pytest.fixture +def sagemaker_session(): + boto_session = Mock() + boto_session.client("scheduler").return_value = Mock() + return Mock(Session, boto_session=boto_session) + + +@pytest.fixture +def event_bridge_scheduler_helper(sagemaker_session): + return EventBridgeSchedulerHelper( + sagemaker_session, sagemaker_session.boto_session.client("scheduler") + ) + + +def test_upsert_schedule_already_exists(event_bridge_scheduler_helper): + event_bridge_scheduler_helper.event_bridge_scheduler_client.update_schedule.return_value = ( + SCHEDULE_ARN + ) + schedule_arn = event_bridge_scheduler_helper.upsert_schedule( + schedule_name=SCHEDULE_NAME, + pipeline_arn=TARGET_ARN, + schedule_expression=CRON_SCHEDULE, + state=STATE, + start_date=START_DATE, + role=ROLE, + ) + assert schedule_arn == SCHEDULE_ARN + event_bridge_scheduler_helper.event_bridge_scheduler_client.create_schedule.assert_not_called() + + +def test_upsert_schedule_not_exists(event_bridge_scheduler_helper): + event_bridge_scheduler_helper.event_bridge_scheduler_client.update_schedule.side_effect = ( + ClientError( + error_response={"Error": {"Code": "ResourceNotFoundException"}}, + operation_name="update_schedule", + ) + ) + event_bridge_scheduler_helper.event_bridge_scheduler_client.create_schedule.return_value = ( + NEW_SCHEDULE_ARN + ) + + schedule_arn = event_bridge_scheduler_helper.upsert_schedule( + schedule_name=SCHEDULE_NAME, + pipeline_arn=TARGET_ARN, + schedule_expression=CRON_SCHEDULE, + state=STATE, + start_date=START_DATE, + role=ROLE, + ) + assert schedule_arn == NEW_SCHEDULE_ARN + event_bridge_scheduler_helper.event_bridge_scheduler_client.create_schedule.assert_called_once() + + +def test_delete_schedule(event_bridge_scheduler_helper): + event_bridge_scheduler_helper.sagemaker_session.boto_session = Mock() + event_bridge_scheduler_helper.sagemaker_session.sagemaker_client = Mock() + event_bridge_scheduler_helper.delete_schedule(schedule_name=TARGET_ARN) + event_bridge_scheduler_helper.event_bridge_scheduler_client.delete_schedule.assert_called_with( + Name=TARGET_ARN + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_factory.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_factory.py new file mode 100644 index 0000000000..27a3b96231 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_factory.py @@ -0,0 +1,75 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import test_data_helpers as tdh +from mock import Mock, patch + +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._factory import ( + UDFWrapperFactory, + ValidatorFactory, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper +from sagemaker.mlops.feature_store.feature_processor._validation import ( + FeatureProcessorArgValidator, + InputValidator, + SparkUDFSignatureValidator, + InputOffsetValidator, + BaseDataSourceValidator, +) +from sagemaker.core.helper.session_helper import Session + + +def test_get_validation_chain(): + fp_config = tdh.create_fp_config(mode=FeatureProcessorMode.PYSPARK) + result = ValidatorFactory.get_validation_chain(fp_config) + + assert result.validators is not None + assert { + InputValidator, + FeatureProcessorArgValidator, + InputOffsetValidator, + BaseDataSourceValidator, + SparkUDFSignatureValidator, + } == {type(instance) for instance in result.validators} + + +def test_get_udf_wrapper(): + fp_config = tdh.create_fp_config(mode=FeatureProcessorMode.PYSPARK) + udf_wrapper = Mock(UDFWrapper) + + with patch.object( + UDFWrapperFactory, "_get_spark_udf_wrapper", return_value=udf_wrapper + ) as get_udf_wrapper_method: + result = UDFWrapperFactory.get_udf_wrapper(fp_config) + + assert result == udf_wrapper + get_udf_wrapper_method.assert_called_with(fp_config) + + +def test_get_udf_wrapper_invalid_mode(): + fp_config = Mock(FeatureProcessorConfig) + fp_config.mode = FeatureProcessorMode.PYTHON + fp_config.sagemaker_session = Mock(Session) + + with pytest.raises( + ValueError, + match=r"FeatureProcessorMode FeatureProcessorMode.PYTHON is not supported\.", + ): + UDFWrapperFactory.get_udf_wrapper(fp_config) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor.py new file mode 100644 index 0000000000..6f6c8471aa --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor.py @@ -0,0 +1,122 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from typing import Callable + +import pytest +import test_data_helpers as tdh +from mock import Mock, patch + +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._factory import ( + UDFWrapperFactory, + ValidatorFactory, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper +from sagemaker.mlops.feature_store.feature_processor._validation import ValidatorChain +from sagemaker.mlops.feature_store.feature_processor.feature_processor import ( + feature_processor, +) + + +@pytest.fixture +def udf(): + return Mock(Callable) + + +@pytest.fixture +def wrapped_udf(): + return Mock() + + +@pytest.fixture +def udf_wrapper(wrapped_udf): + mock = Mock(UDFWrapper) + mock.wrap.return_value = wrapped_udf + return mock + + +@pytest.fixture +def validator_chain(): + return Mock(ValidatorChain) + + +@pytest.fixture +def fp_config(): + mock = Mock(FeatureProcessorConfig) + mock.mode = FeatureProcessorMode.PYSPARK + return mock + + +def test_feature_processor(udf, udf_wrapper, validator_chain, fp_config, wrapped_udf): + with patch.object( + FeatureProcessorConfig, "create", return_value=fp_config + ) as fp_config_create_method: + with patch.object( + UDFWrapperFactory, "get_udf_wrapper", return_value=udf_wrapper + ) as get_udf_wrapper: + with patch.object( + ValidatorFactory, + "get_validation_chain", + return_value=validator_chain, + ) as get_validation_chain: + decorated_udf = feature_processor( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE], + output="", + )(udf) + + assert decorated_udf == wrapped_udf + + fp_config_create_method.assert_called() + get_udf_wrapper.assert_called_with(fp_config) + get_validation_chain.assert_called() + + validator_chain.validate.assert_called_with(fp_config=fp_config, udf=udf) + udf_wrapper.wrap.assert_called_with(fp_config=fp_config, udf=udf) + + assert decorated_udf.feature_processor_config == fp_config + + +def test_feature_processor_validation_fails(udf, udf_wrapper, validator_chain, fp_config): + with patch.object( + FeatureProcessorConfig, "create", return_value=fp_config + ) as fp_config_create_method: + with patch.object( + UDFWrapperFactory, "get_udf_wrapper", return_value=udf_wrapper + ) as get_udf_wrapper: + with patch.object( + ValidatorFactory, + "get_validation_chain", + return_value=validator_chain, + ) as get_validation_chain: + validator_chain.validate.side_effect = ValueError() + + # Verify validation error is raised to user. + with pytest.raises(ValueError): + feature_processor( + inputs=tdh.FEATURE_PROCESSOR_INPUTS, + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + )(udf) + + # Verify validation failure causes execution to terminate early. + # Verify FeatureProcessorConfig interactions. 
+ fp_config_create_method.assert_called() + get_udf_wrapper.assert_called_once() + get_validation_chain.assert_called_once() + validator_chain.validate.assert_called_with(fp_config=fp_config, udf=udf) + udf_wrapper.wrap.assert_not_called() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_config.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_config.py new file mode 100644 index 0000000000..2914bbb63a --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_config.py @@ -0,0 +1,46 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import attr +import pytest +import test_data_helpers as tdh + +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + + +def test_feature_processor_config_is_immutable(): + fp_config = FeatureProcessorConfig.create( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + mode=FeatureProcessorMode.PYSPARK, + target_stores=None, + enable_ingestion=True, + parameters=None, + spark_config=None, + ) + + with pytest.raises(attr.exceptions.FrozenInstanceError): + # Only attempting one field, as FrozenInstanceError indicates all fields are frozen + # (as opposed to FrozenAttributeError). + fp_config.inputs = [] + + with pytest.raises( + TypeError, + match="'FeatureProcessorConfig' object does not support item assignment", + ): + fp_config["inputs"] = [] diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_pipeline_events.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_pipeline_events.py new file mode 100644 index 0000000000..78c506313a --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_pipeline_events.py @@ -0,0 +1,30 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
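+# Unit tests for the FeatureProcessorPipelineEvents configuration object used to define pipeline-triggered EventBridge rules.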
+from __future__ import absolute_import + +from sagemaker.mlops.feature_store.feature_processor import ( + FeatureProcessorPipelineEvents, + FeatureProcessorPipelineExecutionStatus, +) + + +def test_feature_processor_pipeline_events(): + fe_pipeline_events = FeatureProcessorPipelineEvents( + pipeline_name="pipeline_name", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.EXECUTING], + ) + assert fe_pipeline_events.pipeline_name == "pipeline_name" + assert fe_pipeline_events.pipeline_execution_status == [ + FeatureProcessorPipelineExecutionStatus.EXECUTING + ] diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py new file mode 100644 index 0000000000..1cd7e381e0 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py @@ -0,0 +1,1057 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from datetime import datetime +from typing import Callable + +import pytest +import json +from botocore.exceptions import ClientError +from mock import Mock, patch, call + +from sagemaker.mlops.feature_store.feature_processor.feature_scheduler import ( + FeatureProcessorLineageHandler, +) +from sagemaker.mlops.feature_store.feature_processor import ( + FeatureProcessorPipelineEvents, + FeatureProcessorPipelineExecutionStatus, +) +from sagemaker.core.lineage.context import Context +from sagemaker.core.remote_function.spark_config import SparkConfig + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._constants import ( + FEATURE_PROCESSOR_TAG_KEY, + FEATURE_PROCESSOR_TAG_VALUE, + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, + PIPELINE_NAME_MAXIMUM_LENGTH, +) +from sagemaker.mlops.feature_store.feature_processor.feature_scheduler import ( + schedule, + to_pipeline, + execute, + delete_schedule, + describe, + list_pipelines, + put_trigger, + enable_trigger, + disable_trigger, + delete_trigger, + _validate_fg_lineage_resources, + _validate_pipeline_lineage_resources, +) +from sagemaker.core.remote_function.job import ( + _JobSettings, + SPARK_APP_SCRIPT_PATH, + RUNTIME_SCRIPTS_CHANNEL_NAME, + REMOTE_FUNCTION_WORKSPACE, + ENTRYPOINT_SCRIPT_NAME, + SPARK_CONF_CHANNEL_NAME, +) +from sagemaker.core.workflow.parameters import Parameter, ParameterTypeEnum +from sagemaker.mlops.workflow.retry import ( + StepRetryPolicy, + StepExceptionTypeEnum, + SageMakerJobStepRetryPolicy, + SageMakerJobExceptionTypeEnum, +) +import test_data_helpers as tdh + +REGION = "us-west-2" +IMAGE = "image_uri" +BUCKET = "my-s3-bucket" +DEFAULT_BUCKET_PREFIX = "default_bucket_prefix" +S3_URI = f"s3://{BUCKET}/keyprefix" +DEFAULT_IMAGE = ( + 
"153931337802.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark-processing:3.2-cpu-py39-v1.1" +) +PIPELINE_ARN = "pipeline_arn" +SCHEDULE_ARN = "schedule_arn" +SCHEDULE_ROLE_ARN = "my_schedule_role_arn" +EXECUTION_ROLE_ARN = "my_execution_role_arn" +EVENT_BRIDGE_RULE_ARN = "arn:aws:events:us-west-2:123456789012:rule/test-rule" +VALID_SCHEDULE_STATE = "ENABLED" +INVALID_SCHEDULE_STATE = "invalid" +TEST_REGION = "us-west-2" +PIPELINE_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-context-name" +PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-version-context-name" +NOW = datetime.now() +SAGEMAKER_SESSION_MOCK = Mock(Session) +CONTEXT_MOCK_01 = Mock(Context) +CONTEXT_MOCK_02 = Mock(Context) +CONTEXT_MOCK_03 = Mock(Context) +FEATURE_GROUP = tdh.DESCRIBE_FEATURE_GROUP_RESPONSE.copy() +PIPELINE = tdh.PIPELINE.copy() +TAGS = [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + + +def mock_session(): + session = Mock() + session.default_bucket.return_value = BUCKET + session.default_bucket_prefix = DEFAULT_BUCKET_PREFIX + session.expand_role.return_value = EXECUTION_ROLE_ARN + session.boto_region_name = TEST_REGION + session.sagemaker_config = None + session._append_sagemaker_config_tags.return_value = [] + session.default_bucket_prefix = None + session.sagemaker_client = Mock() + return session + + +def mock_pipeline(): + pipeline = Mock() + pipeline.describe.return_value = {"PipelineArn": PIPELINE_ARN} + pipeline.upsert.return_value = None + return pipeline + + +def mock_event_bridge_scheduler_helper(): + helper = Mock() + helper.upsert_schedule.return_value = dict(ScheduleArn=SCHEDULE_ARN) + helper.delete_schedule.return_value = None + helper.describe_schedule.return_value = { + "Arn": "some_schedule_arn", + "ScheduleExpression": "some_schedule_expression", + "StartDate": NOW, + "State": VALID_SCHEDULE_STATE, + "Target": {"Arn": "some_pipeline_arn", "RoleArn": "some_schedule_role_arn"}, + } + return helper + + +def mock_event_bridge_rule_helper(): + helper = Mock() + helper.describe_rule.return_value = { + "Arn": "some_rule_arn", + "EventPattern": "some_event_pattern", + "State": "ENABLED", + } + return helper + + +def mock_feature_processor_lineage(): + return Mock(FeatureProcessorLineageHandler) + + +@pytest.fixture +def job_function(): + return Mock(Callable) + + +@pytest.fixture +def config_uploader(): + uploader = Mock() + uploader.return_value = "some_s3_uri" + uploader.prepare_and_upload_runtime_scripts.return_value = "some_s3_uri" + return uploader + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler._validate_fg_lineage_resources", + return_value=None, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.Pipeline", + return_value=mock_pipeline(), +) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.TrainingInput") +@patch("sagemaker.mlops.feature_store.feature_processor.feature_scheduler.TrainingStep") +@patch("sagemaker.mlops.feature_store.feature_processor.feature_scheduler.ModelTrainer") +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader" + "._prepare_and_upload_spark_dependent_files", + return_value=("path_a", "path_b", "path_c", "path_d"), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader._prepare_and_upload_workspace", + 
return_value="some_s3_uri", +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader._prepare_and_upload_runtime_scripts", + return_value="some_s3_uri", +) +@patch("sagemaker.mlops.feature_store.feature_processor.feature_scheduler.RuntimeEnvironmentManager") +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader._prepare_and_upload_callable" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.lineage." + "_feature_processor_lineage.FeatureProcessorLineageHandler.create_lineage" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.lineage." + "_feature_processor_lineage.FeatureProcessorLineageHandler.get_pipeline_lineage_names", + return_value=dict( + pipeline_context_name="pipeline-context-name", + pipeline_version_context_name="pipeline-version-context-name", + ), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.PipelineSession" +) +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.expand_role", side_effect=lambda session, role: role) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline( + get_execution_role, + expand_role, + session, + mock_pipeline_session, + mock_get_pipeline_lineage_names, + mock_create_lineage, + mock_upload_callable, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + mock_spark_dependency_upload, + mock_model_trainer, + mock_training_step, + mock_training_input, + mock_spark_image, + pipeline, + lineage_validator, +): + session.sagemaker_config = None + session.boto_region_name = TEST_REGION + session.expand_role.return_value = EXECUTION_ROLE_ARN + + # Configure RuntimeEnvironmentManager mock to return proper string values + mock_runtime_manager_instance = mock_runtime_manager.return_value + mock_runtime_manager_instance._current_python_version.return_value = "3.10" + mock_runtime_manager_instance.snapshot.return_value = "/tmp/snapshot" + + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + sagemaker_session=session, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", session.boto_region_name]) + + mock_feature_processor_config = Mock( + mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg" + ) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK + + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + pipeline_arn = to_pipeline( + pipeline_name="pipeline_name", + 
step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + tags=[("tag_key_1", "tag_value_1"), ("tag_key_2", "tag_value_2")], + sagemaker_session=session, + ) + assert pipeline_arn == PIPELINE_ARN + + assert mock_upload_callable.called_once_with(job_function) + + mock_script_upload.assert_called_once_with( + spark_config, + f"{S3_URI}/pipeline_name", + None, + session, + ) + + mock_dependency_upload.assert_called_once_with( + "/tmp/snapshot", + True, + None, + None, + f"{S3_URI}/pipeline_name", + None, + session, + None, + ) + + mock_spark_dependency_upload.assert_called_once_with( + spark_config, + f"{S3_URI}/pipeline_name", + None, + session, + ) + + mock_model_trainer.assert_called_once() + # Verify ModelTrainer was configured correctly + model_trainer_call_kwargs = mock_model_trainer.call_args[1] + assert model_trainer_call_kwargs["training_image"] == "some_image_uri" + assert model_trainer_call_kwargs["role"] == EXECUTION_ROLE_ARN + assert model_trainer_call_kwargs["training_input_mode"] == "File" + + # Verify PipelineSession was passed to ModelTrainer + assert model_trainer_call_kwargs["sagemaker_session"] == mock_pipeline_session.return_value + mock_pipeline_session.assert_called_once_with( + boto_session=session.boto_session, + default_bucket=session.default_bucket(), + default_bucket_prefix=session.default_bucket_prefix, + ) + + # Verify Compute config + compute_arg = model_trainer_call_kwargs["compute"] + assert compute_arg.instance_type == "ml.m5.large" + assert compute_arg.instance_count == 1 + assert compute_arg.volume_size_in_gb == 30 + + # No VPC config was provided, so networking should be None + assert model_trainer_call_kwargs["networking"] is None + + # max_runtime_in_seconds defaults to 86400 in _JobSettings + stopping_condition_arg = model_trainer_call_kwargs["stopping_condition"] + assert stopping_condition_arg.max_runtime_in_seconds == 86400 + + # Verify OutputDataConfig + output_data_config_arg = model_trainer_call_kwargs["output_data_config"] + assert output_data_config_arg.s3_output_path == f"{S3_URI}/pipeline_name" + + # Verify SourceCode has a command string + source_code_arg = model_trainer_call_kwargs["source_code"] + assert source_code_arg.command is not None + assert len(source_code_arg.command) > 0 + assert "--client_python_version 3.10" in source_code_arg.command + + # No tags on _JobSettings, so tags should be None + assert model_trainer_call_kwargs["tags"] is None + + # Verify train() was called with input_data_config + mock_model_trainer.return_value.train.assert_called_once() + train_call_kwargs = mock_model_trainer.return_value.train.call_args[1] + assert "input_data_config" in train_call_kwargs + + mock_training_step.assert_called_once_with( + name="-".join(["pipeline_name", "feature-processor"]), + step_args=mock_model_trainer.return_value.train.return_value, + retry_policies=[ + StepRetryPolicy( + exception_types=[ + StepExceptionTypeEnum.SERVICE_FAULT, + StepExceptionTypeEnum.THROTTLING, + ], + max_attempts=1, + ), + SageMakerJobStepRetryPolicy( + exception_types=[ + SageMakerJobExceptionTypeEnum.INTERNAL_ERROR, + SageMakerJobExceptionTypeEnum.CAPACITY_ERROR, + SageMakerJobExceptionTypeEnum.RESOURCE_LIMIT, + ], + max_attempts=1, + ), + ], + ) + + pipeline.assert_called_once_with( + name="pipeline_name", + steps=[mock_training_step()], + sagemaker_session=session, + parameters=[Parameter(name="scheduled_time", parameter_type=ParameterTypeEnum.STRING)], + ) + + pipeline().upsert.assert_has_calls( + [ + call( + role_arn=EXECUTION_ROLE_ARN, + 
tags=[ + dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE), + dict(Key="tag_key_1", Value="tag_value_1"), + dict(Key="tag_key_2", Value="tag_value_2"), + ], + ), + call( + role_arn=EXECUTION_ROLE_ARN, + tags=[ + { + "Key": PIPELINE_CONTEXT_NAME_TAG_KEY, + "Value": "pipeline-context-name", + }, + { + "Key": PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY, + "Value": "pipeline-version-context-name", + }, + ], + ), + ] + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_not_wrapped_by_feature_processor(get_execution_role, session): + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + wrapped_func = Mock( + Callable, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Please wrap step parameter with feature_processor decorator in order to use to_pipeline API.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_not_wrapped_by_remote(get_execution_role, session): + mock_feature_processor_config = Mock(mode=FeatureProcessorMode.PYTHON) + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=None, + wrapped_func=job_function, + ) + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Please wrap step parameter with remote decorator in order to use to_pipeline API.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_wrong_mode(get_execution_role, mock_spark_image, session): + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", TEST_REGION]) + + mock_feature_processor_config = Mock(mode=FeatureProcessorMode.PYTHON) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYTHON + + wrapped_func = Mock( + Callable, + 
feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Mode FeatureProcessorMode.PYTHON is not supported by to_pipeline API.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_pipeline_name_length_limit_exceeds( + get_execution_role, mock_spark_image, session +): + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", TEST_REGION]) + + mock_feature_processor_config = Mock(mode=FeatureProcessorMode.PYSPARK) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK + + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Pipeline name used by feature processor should be less than 80 " + "characters. 
Please choose another pipeline name.", + ): + to_pipeline( + pipeline_name="".join(["a" for _ in range(PIPELINE_NAME_MAXIMUM_LENGTH + 1)]), + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_used_reserved_tags(get_execution_role, mock_spark_image, session): + session.sagemaker_config = None + session.boto_region_name = TEST_REGION + session.expand_role.return_value = EXECUTION_ROLE_ARN + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + sagemaker_session=session, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", session.boto_region_name]) + + mock_feature_processor_config = Mock( + mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg" + ) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK + + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="sm-fs-fe:created-from is a reserved tag key for to_pipeline API. 
Please choose another tag.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + tags=[("sm-fs-fe:created-from", "random")], + sagemaker_session=session, + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler" + "._get_tags_from_pipeline_to_propagate_to_lineage_resources", + return_value=TAGS, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler._validate_pipeline_lineage_resources", + return_value=None, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper", + return_value=mock_event_bridge_scheduler_helper(), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.FeatureProcessorLineageHandler", + return_value=mock_feature_processor_lineage(), +) +def test_schedule(lineage, helper, validation, get_tags): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.describe_pipeline = Mock( + return_value={"PipelineArn": "my:arn", "CreationTime": NOW} + ) + + schedule_arn = schedule( + schedule_expression="some_schedule", + state=VALID_SCHEDULE_STATE, + start_date=NOW, + pipeline_name=PIPELINE_ARN, + role_arn=SCHEDULE_ROLE_ARN, + sagemaker_session=session, + ) + + assert schedule_arn == SCHEDULE_ARN + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeRuleHelper", + return_value=mock_event_bridge_rule_helper(), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper", + return_value=mock_event_bridge_scheduler_helper(), +) +def test_describe_both_exist(mock_scheduler_helper, mock_rule_helper): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.describe_pipeline.return_value = PIPELINE + describe_schedule_response = describe( + pipeline_name="some_pipeline_arn", sagemaker_session=session + ) + assert describe_schedule_response == dict( + pipeline_arn="some_pipeline_arn", + pipeline_execution_role_arn="some_execution_role_arn", + max_retries=5, + schedule_arn="some_schedule_arn", + schedule_expression="some_schedule_expression", + schedule_state=VALID_SCHEDULE_STATE, + schedule_start_date=NOW.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT), + schedule_role="some_schedule_role_arn", + trigger="some_rule_arn", + event_pattern="some_event_pattern", + trigger_state="ENABLED", + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeRuleHelper.describe_rule", + return_value=None, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper.describe_schedule", + return_value=None, +) +def test_describe_only_pipeline_exist(helper, mock_describe_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.describe_pipeline.return_value = { + "PipelineArn": "some_pipeline_arn", + "RoleArn": "some_execution_role_arn", + "PipelineDefinition": json.dumps({"Steps": [{"Arguments": {}}]}), + } + helper.describe_schedule().return_value = None + describe_schedule_response = describe( + pipeline_name="some_pipeline_arn", sagemaker_session=session + ) + assert describe_schedule_response == dict( + pipeline_arn="some_pipeline_arn", + pipeline_execution_role_arn="some_execution_role_arn", + ) + + +def test_list_pipelines(): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + 
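# list_pipelines resolves feature processor pipelines from their lineage contexts, extracting the pipeline name from each context's SourceUri ARN. +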
session.sagemaker_client.list_contexts.return_value = { + "ContextSummaries": [ + { + "Source": { + "SourceUri": "arn:aws:sagemaker:us-west-2:12345789012:pipeline/some_pipeline_name" + } + } + ] + } + list_response = list_pipelines(session) + assert list_response == [dict(pipeline_name="some_pipeline_name")] + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper", + return_value=mock_event_bridge_scheduler_helper(), +) +def test_delete_schedule_both_exist(helper): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.delete_pipeline = Mock() + delete_schedule(pipeline_name=PIPELINE_ARN, sagemaker_session=session) + helper().delete_schedule.assert_called_once_with(PIPELINE_ARN) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper", + return_value=mock_event_bridge_scheduler_helper(), +) +def test_delete_schedule_not_exist(helper): + helper.delete_schedule.side_effect = ClientError( + error_response={"Error": {"Code": "ResourceNotFoundException"}}, + operation_name="update_schedule", + ) + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.delete_pipeline = Mock() + delete_schedule(pipeline_name=PIPELINE_ARN, sagemaker_session=session) + helper().delete_schedule.assert_called_once_with(PIPELINE_ARN) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler._validate_pipeline_lineage_resources", + return_value=None, +) +def test_execute(validation): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.describe_pipeline = Mock( + return_value={"PipelineArn": "my:arn", "CreationTime": NOW} + ) + session.sagemaker_client.start_pipeline_execution = Mock( + return_value={"PipelineExecutionArn": "my:arn"} + ) + execution_arn = execute( + pipeline_name="some_pipeline", execution_time=NOW, sagemaker_session=session + ) + assert execution_arn == "my:arn" + + +def test_validate_fg_lineage_resources_happy_case(): + with patch.object( + SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP + ) as fg_describe_method: + with patch.object( + Context, "load", side_effect=[CONTEXT_MOCK_01, CONTEXT_MOCK_02, CONTEXT_MOCK_03] + ) as context_load: + type(CONTEXT_MOCK_01).context_arn = "context-arn" + type(CONTEXT_MOCK_02).context_arn = "context-arn-fep" + type(CONTEXT_MOCK_03).context_arn = "context-arn-fep-ver" + _validate_fg_lineage_resources( + feature_group_name="some_fg", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + fg_describe_method.assert_called_once_with(feature_group_name="some_fg") + context_load.assert_has_calls( + [ + call( + context_name=f'{"some_fg"}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + f"-feature-group-pipeline", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + context_name=f'{"some_fg"}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + f"-feature-group-pipeline-version", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == context_load.call_count + + +def test_validete_fg_lineage_resources_rnf(): + with patch.object(SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP): + with patch.object( + Context, + "load", + side_effect=ClientError( + error_response={"Error": {"Code": "ResourceNotFound"}}, + operation_name="describe_context", + ), + ): + feature_group_name = "some_fg" + feature_group_creation_time = 
FEATURE_GROUP["CreationTime"].strftime("%s") + context_name = f"{feature_group_name}-{feature_group_creation_time}" + with pytest.raises( + ValueError, + match=f"Lineage resource {context_name} has not yet been created for feature group" + f" {feature_group_name} or has already been deleted. Please try again later.", + ): + _validate_fg_lineage_resources( + feature_group_name="some_fg", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_validate_pipeline_lineage_resources_happy_case(): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.return_value = Mock() + pipeline_name = "some_pipeline" + with patch.object( + session.sagemaker_client, "describe_pipeline", return_value=PIPELINE + ) as pipeline_describe_method: + with patch.object( + Context, "load", side_effect=[CONTEXT_MOCK_01, CONTEXT_MOCK_02] + ) as context_load: + type(CONTEXT_MOCK_01).context_arn = "context-arn" + type(CONTEXT_MOCK_01).properties = {"LastUpdateTime": NOW.strftime("%s")} + type(CONTEXT_MOCK_02).context_arn = "context-arn-fep" + _validate_pipeline_lineage_resources( + pipeline_name=pipeline_name, + sagemaker_session=session, + ) + pipeline_describe_method.assert_called_once_with(PipelineName=pipeline_name) + pipeline_creation_time = PIPELINE["CreationTime"].strftime("%s") + last_updated_time = NOW.strftime("%s") + context_load.assert_has_calls( + [ + call( + context_name=f"sm-fs-fe-{pipeline_name}-{pipeline_creation_time}-fep", + sagemaker_session=session, + ), + call( + context_name=f"sm-fs-fe-{pipeline_name}-{last_updated_time}-fep-ver", + sagemaker_session=session, + ), + ] + ) + assert 2 == context_load.call_count + + +def test_validate_pipeline_lineage_resources_rnf(): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.return_value = Mock() + pipeline_name = "some_pipeline" + with patch.object(session.sagemaker_client, "describe_pipeline", return_value=PIPELINE): + with patch.object( + Context, + "load", + side_effect=ClientError( + error_response={"Error": {"Code": "ResourceNotFound"}}, + operation_name="describe_context", + ), + ): + with pytest.raises( + ValueError, + match="Pipeline lineage resources have not been created yet or have" + " already been deleted. 
Please try again later.", + ): + _validate_pipeline_lineage_resources( + pipeline_name=pipeline_name, + sagemaker_session=session, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_remote_decorator_fields_consistency(get_execution_role, session): + expected_remote_decorator_attributes = { + "sagemaker_session", + "environment_variables", + "image_uri", + "dependencies", + "pre_execution_commands", + "pre_execution_script", + "include_local_workdir", + "instance_type", + "instance_count", + "volume_size", + "max_runtime_in_seconds", + "max_retry_attempts", + "keep_alive_period_in_seconds", + "spark_config", + "job_conda_env", + "job_name_prefix", + "encrypt_inter_container_traffic", + "enable_network_isolation", + "role", + "s3_root_uri", + "s3_kms_key", + "volume_kms_key", + "vpc_config", + "tags", + "use_spot_instances", + "max_wait_time_in_seconds", + "custom_file_filter", + "disable_output_compression", + "use_torchrun", + "use_mpirun", + "nproc_per_node", + } + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + actual_attributes = {attribute for attribute, _ in job_settings.__dict__.items()} + + assert expected_remote_decorator_attributes == actual_attributes + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.lineage." + "_feature_processor_lineage.FeatureProcessorLineageHandler.create_trigger_lineage" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.describe_rule", + return_value={"EventPattern": "test-pattern"}, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.add_tags" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler." 
+ "_get_tags_from_pipeline_to_propagate_to_lineage_resources", + return_value=TAGS, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.put_target" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.put_rule", + return_value="arn:aws:events:us-west-2:123456789012:rule/test-rule", +) +def test_put_trigger( + mock_put_rule, + mock_put_target, + mock_get_tags, + mock_add_tags, + mock_describe_rule, + mock_create_trigger_lineage, +): + session = Mock( + Session, + sagemaker_client=Mock( + describe_pipeline=Mock(return_value={"PipelineArn": "test-pipeline-arn"}) + ), + boto_session=Mock(), + ) + source_pipeline_events = [ + FeatureProcessorPipelineEvents( + pipeline_name="test-pipeline", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED], + ) + ] + put_trigger( + source_pipeline_events=source_pipeline_events, + target_pipeline="test-target-pipeline", + state="Enabled", + event_pattern="test-pattern", + role_arn=SCHEDULE_ROLE_ARN, + sagemaker_session=session, + ) + + mock_put_rule.assert_called_once_with( + source_pipeline_events=source_pipeline_events, + target_pipeline="test-target-pipeline", + state="Enabled", + event_pattern="test-pattern", + ) + mock_put_target.assert_called_once_with( + rule_name="test-rule", + target_pipeline="test-target-pipeline", + target_pipeline_parameters=None, + role_arn=SCHEDULE_ROLE_ARN, + ) + mock_add_tags.assert_called_once_with(rule_arn=EVENT_BRIDGE_RULE_ARN, tags=TAGS) + mock_create_trigger_lineage.assert_called_once_with( + pipeline_name="test-target-pipeline", + trigger_arn=EVENT_BRIDGE_RULE_ARN, + state="Enabled", + tags=TAGS, + event_pattern="test-pattern", + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.enable_rule" +) +def test_enable_trigger(mock_enable_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + enable_trigger(pipeline_name="test-pipeline", sagemaker_session=session) + mock_enable_rule.assert_called_once_with(rule_name="test-pipeline") + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.disable_rule" +) +def test_disable_trigger(mock_disable_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + disable_trigger(pipeline_name="test-pipeline", sagemaker_session=session) + mock_disable_rule.assert_called_once_with(rule_name="test-pipeline") + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.list_targets_by_rule", + return_value=[{"Targets": [{"Id": "target_pipeline"}]}], +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.remove_targets" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.delete_rule" +) +def test_delete_trigger(mock_delete_rule, mock_remove_targets, mock_list_targets_by_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + delete_trigger(pipeline_name="test-pipeline", sagemaker_session=session) + mock_delete_rule.assert_called_once_with("test-pipeline") + mock_list_targets_by_rule.assert_called_once_with("test-pipeline") + mock_remove_targets.assert_called_once_with(rule_name="test-pipeline", ids=["target_pipeline"]) diff --git 
a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py new file mode 100644 index 0000000000..24d20f96c5 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py @@ -0,0 +1,320 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import test_data_helpers as tdh +from mock import Mock, patch, call +from pyspark.sql import SparkSession, DataFrame +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + IcebergTableDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._input_loader import ( + SparkDataFrameInputLoader, +) +from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper +from sagemaker.core.helper.session_helper import Session + + +@pytest.fixture +def describe_fg_response(): + return tdh.DESCRIBE_FEATURE_GROUP_RESPONSE.copy() + + +@pytest.fixture +def sagemaker_session(describe_fg_response): + return Mock(Session, describe_feature_group=Mock(return_value=describe_fg_response)) + + +@pytest.fixture +def spark_session(mock_data_frame): + return Mock( + SparkSession, + read=Mock( + csv=Mock(return_value=Mock()), + parquet=Mock(return_value=mock_data_frame), + conf=Mock(set=Mock()), + ), + table=Mock(return_value=mock_data_frame), + ) + + +@pytest.fixture +def environment_helper(): + return Mock( + EnvironmentHelper, + get_job_scheduled_time=Mock(return_value="2023-05-05T15:22:57Z"), + ) + + +@pytest.fixture +def mock_data_frame(): + return Mock(DataFrame, filter=Mock()) + + +@pytest.fixture +def spark_session_factory(spark_session): + factory = Mock(SparkSessionFactory) + factory.spark_session = spark_session + factory.get_spark_session_with_iceberg_config = Mock(return_value=spark_session) + return factory + + +@pytest.fixture +def fp_config(): + return tdh.create_fp_config() + + +@pytest.fixture +def input_loader(spark_session_factory, sagemaker_session, environment_helper): + return SparkDataFrameInputLoader( + spark_session_factory, + environment_helper, + sagemaker_session, + ) + + +def test_load_from_s3_with_csv_object(input_loader: SparkDataFrameInputLoader, spark_session): + s3_data_source = CSVDataSource( + s3_uri="s3://bucket/prefix/file.csv", + csv_header=True, + csv_infer_schema=True, + ) + + input_loader.load_from_s3(s3_data_source) + + spark_session.read.csv.assert_called_with( + "s3a://bucket/prefix/file.csv", header=True, inferSchema=True + ) + + +def test_load_from_s3_with_parquet_object(input_loader, spark_session): + s3_data_source = ParquetDataSource(s3_uri="s3://bucket/prefix/file.parquet") + + 
input_loader.load_from_s3(s3_data_source) + + spark_session.read.parquet.assert_called_with("s3a://bucket/prefix/file.parquet") + + +@pytest.mark.parametrize( + "condition", + [(None), ("condition")], +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._input_loader." + "SparkDataFrameInputLoader._get_iceberg_offset_filter_condition" +) +def test_load_from_iceberg_table( + mock_get_filter_condition, + condition, + input_loader, + spark_session, + spark_session_factory, + mock_data_frame, +): + iceberg_table_data_source = IcebergTableDataSource( + warehouse_s3_uri="s3://bucket/prefix/", + catalog="Catalog", + database="Database", + table="Table", + ) + mock_get_filter_condition.return_value = condition + + input_loader.load_from_iceberg_table(iceberg_table_data_source, "event_time", "start", "end") + spark_session_factory.get_spark_session_with_iceberg_config.assert_called_with( + "s3://bucket/prefix/", "catalog" + ) + spark_session.table.assert_called_with("catalog.database.table") + mock_get_filter_condition.assert_called_with("event_time", "start", "end") + + if condition: + mock_data_frame.filter.assert_called_with(condition) + else: + mock_data_frame.filter.assert_not_called() + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._input_loader.SparkDataFrameInputLoader.load_from_date_partitioned_s3" +) +def test_load_from_feature_group_with_arn( + mock_load_from_date_partitioned_s3, sagemaker_session, input_loader +): + fg_arn = tdh.INPUT_FEATURE_GROUP_ARN + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource( + name=fg_arn, input_start_offset="start", input_end_offset="end" + ) + + input_loader.load_from_feature_group(fg_data_source) + + sagemaker_session.describe_feature_group.assert_called_with(fg_name) + mock_load_from_date_partitioned_s3.assert_called_with( + ParquetDataSource(tdh.INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI), + "start", + "end", + ) + + +def test_load_from_feature_group_offline_store_not_enabled(input_loader, describe_fg_response): + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource(name=fg_name) + with pytest.raises( + ValueError, + match=( + f"Input Feature Groups must have an enabled Offline Store." + f" Feature Group: {fg_name} does not have an Offline Store enabled." 
+ ), + ): + del describe_fg_response["OfflineStoreConfig"] + input_loader.load_from_feature_group(fg_data_source) + + +def test_load_from_feature_group_with_default_table_format( + sagemaker_session, input_loader, spark_session +): + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource(name=fg_name) + input_loader.load_from_feature_group(fg_data_source) + + sagemaker_session.describe_feature_group.assert_called_with(fg_name) + spark_session.read.parquet.assert_called_with( + tdh.INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI.replace("s3:", "s3a:") + ) + + +def test_load_from_feature_group_with_iceberg_table_format( + describe_fg_response, spark_session_factory, spark_session, environment_helper +): + describe_iceberg_fg_response = describe_fg_response.copy() + describe_iceberg_fg_response["OfflineStoreConfig"]["TableFormat"] = "Iceberg" + mocked_session = Mock( + Session, describe_feature_group=Mock(return_value=describe_iceberg_fg_response) + ) + mock_input_loader = SparkDataFrameInputLoader( + spark_session_factory, environment_helper, mocked_session + ) + + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource(name=fg_name) + mock_input_loader.load_from_feature_group(fg_data_source) + + mocked_session.describe_feature_group.assert_called_with(fg_name) + spark_session.table.assert_called_with( + "awsdatacatalog.sagemaker_featurestore.input_fg_1680142547" + ) + + +@pytest.mark.parametrize( + "param", + [ + (None, None, None), + ("start", None, "event_time >= 'start_time'"), + (None, "end", "event_time < 'end_time'"), + ("start", "end", "event_time >= 'start_time' AND event_time < 'end_time'"), + ], +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._input_offset_parser.InputOffsetParser.get_iso_format_offset_date", + side_effect=[ + "start_time", + "end_time", + ], +) +def test_get_iceberg_offset_filter_condition(mock_get_iso_date, param, input_loader): + start_offset, end_offset, expected_condition_str = param + + condition = input_loader._get_iceberg_offset_filter_condition( + "event_time", start_offset, end_offset + ) + + if start_offset or end_offset: + mock_get_iso_date.assert_has_calls([call(start_offset), call(end_offset)]) + else: + mock_get_iso_date.assert_not_called() + + assert condition == expected_condition_str + + +@pytest.mark.parametrize( + "param", + [ + (None, None, None), + ( + "start", + None, + "(year >= 'year_start') AND NOT ((year = 'year_start' AND month < 'month_start') OR " + "(year = 'year_start' AND month = 'month_start' AND day < 'day_start') OR " + "(year = 'year_start' AND month = 'month_start' AND day = 'day_start' AND hour < 'hour_start'))", + ), + ( + None, + "end", + "(year <= 'year_end') AND NOT ((year = 'year_end' AND month > 'month_end') OR " + "(year = 'year_end' AND month = 'month_end' AND day > 'day_end') OR (year = 'year_end' " + "AND month = 'month_end' AND day = 'day_end' AND hour >= 'hour_end'))", + ), + ( + "start", + "end", + "(year >= 'year_start' AND year <= 'year_end') AND NOT ((year = 'year_start' AND " + "month < 'month_start') OR (year = 'year_start' AND month = 'month_start' AND day < 'day_start') OR " + "(year = 'year_start' AND month = 'month_start' AND day = 'day_start' AND hour < 'hour_start') OR " + "(year = 'year_end' AND month > 'month_end') OR " + "(year = 'year_end' AND month = 'month_end' AND day > 'day_end') OR " + "(year = 'year_end' AND month = 'month_end' AND day = 'day_end' AND hour >= 'hour_end'))", + ), + ], +) +@patch( + 
"sagemaker.mlops.feature_store.feature_processor._input_offset_parser.InputOffsetParser." + "get_offset_date_year_month_day_hour", + side_effect=[ + ("year_start", "month_start", "day_start", "hour_start"), + ("year_end", "month_end", "day_end", "hour_end"), + ], +) +def test_get_s3_partitions_offset_filter_condition(mock_get_ymdh, param, input_loader): + start_offset, end_offset, expected_condition_str = param + + condition = input_loader._get_s3_partitions_offset_filter_condition(start_offset, end_offset) + + if start_offset or end_offset: + mock_get_ymdh.assert_has_calls([call(start_offset), call(end_offset)]) + else: + mock_get_ymdh.assert_not_called() + + assert condition == expected_condition_str + + +@pytest.mark.parametrize( + "condition", + [(None), ("condition")], +) +def test_load_from_date_partitioned_s3(input_loader, spark_session, mock_data_frame, condition): + input_loader._get_s3_partitions_offset_filter_condition = Mock(return_value=condition) + + input_loader.load_from_date_partitioned_s3( + ParquetDataSource("s3://path/to/file"), "start", "end" + ) + df = spark_session.read.parquet + df.assert_called_with("s3a://path/to/file") + + if condition: + mock_data_frame.filter.assert_called_with(condition) + else: + mock_data_frame.filter.assert_not_called() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_offset_parser.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_offset_parser.py new file mode 100644 index 0000000000..b244a4c2da --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_offset_parser.py @@ -0,0 +1,143 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from sagemaker.mlops.feature_store.feature_processor._input_offset_parser import ( + InputOffsetParser, +) +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, +) +from datetime import datetime +from dateutil.relativedelta import relativedelta +import pytest + + +@pytest.fixture +def input_offset_parser(): + time_spec = dict(year=2023, month=5, day=10, hour=17, minute=30, second=20) + return InputOffsetParser(now=datetime(**time_spec)) + + +@pytest.mark.parametrize( + "param", + [ + (None, None), + ("1 hour", "2023-05-10T16:30:20Z"), + ("1 day", "2023-05-09T17:30:20Z"), + ("1 month", "2023-04-10T17:30:20Z"), + ("1 year", "2022-05-10T17:30:20Z"), + ], +) +def test_get_iso_format_offset_date(param, input_offset_parser): + input_offset, expected_offset_date = param + output_offset_date = input_offset_parser.get_iso_format_offset_date(input_offset) + + assert output_offset_date == expected_offset_date + + +@pytest.mark.parametrize( + "param", + [ + (None, None), + ( + "1 hour", + datetime.strptime("2023-05-10T16:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT), + ), + ( + "1 day", + datetime.strptime("2023-05-09T17:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT), + ), + ( + "1 month", + datetime.strptime("2023-04-10T17:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT), + ), + ( + "1 year", + datetime.strptime("2022-05-10T17:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT), + ), + ], +) +def test_get_offset_datetime(param, input_offset_parser): + input_offset, expected_offset_datetime = param + output_offset_datetime = input_offset_parser.get_offset_datetime(input_offset) + + assert output_offset_datetime == expected_offset_datetime + + +@pytest.mark.parametrize( + "param", + [ + (None, (None, None, None, None)), + ("1 hour", ("2023", "05", "10", "16")), + ("1 day", ("2023", "05", "09", "17")), + ("1 month", ("2023", "04", "10", "17")), + ("1 year", ("2022", "05", "10", "17")), + ], +) +def test_get_offset_date_year_month_day_hour(param, input_offset_parser): + input_offset, expected_date_tuple = param + output_date_tuple = input_offset_parser.get_offset_date_year_month_day_hour(input_offset) + + assert output_date_tuple == expected_date_tuple + + +@pytest.mark.parametrize( + "param", + [ + (None, None), + ("1 hour", relativedelta(hours=-1)), + ("20 hours", relativedelta(hours=-20)), + ("1 day", relativedelta(days=-1)), + ("20 days", relativedelta(days=-20)), + ("1 month", relativedelta(months=-1)), + ("20 months", relativedelta(months=-20)), + ("1 year", relativedelta(years=-1)), + ("20 years", relativedelta(years=-20)), + ], +) +def test_parse_offset_to_timedelta(param, input_offset_parser): + input_offset, expected_deltatime = param + output_deltatime = input_offset_parser.parse_offset_to_timedelta(input_offset) + + assert output_deltatime == expected_deltatime + + +@pytest.mark.parametrize( + "param", + [ + ( + "random invalid string", + "[random invalid string] is not in a valid offset format. Please pass a valid offset e.g '1 day'.", + ), + ( + "1 invalid string", + "[1 invalid string] is not in a valid offset format. Please pass a valid offset e.g '1 day'.", + ), + ( + "2 days invalid string", + "[2 days invalid string] is not in a valid offset format. Please pass a valid offset e.g '1 day'.", + ), + ( + "1 second", + "[second] is not a valid offset unit. 
Supported units: ['hour', 'day', 'week', 'month', 'year']", + ), + ], +) +def test_parse_offset_to_timedelta_negative(param, input_offset_parser): + input_offset, expected_error_message = param + + with pytest.raises(ValueError) as e: + input_offset_parser.parse_offset_to_timedelta(input_offset) + + assert str(e.value) == expected_error_message diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_params_loader.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_params_loader.py new file mode 100644 index 0000000000..1ac2042a55 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_params_loader.py @@ -0,0 +1,86 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + + +import pytest +import test_data_helpers as tdh +from mock import Mock + +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper +from sagemaker.mlops.feature_store.feature_processor._params_loader import ( + ParamsLoader, + SystemParamsLoader, +) + + +@pytest.fixture +def system_params_loader_mock(): + system_params_loader = Mock(SystemParamsLoader) + system_params_loader.get_system_args.return_value = tdh.SYSTEM_PARAMS + return system_params_loader + + +@pytest.fixture +def environment_checker(): + environment_checker = Mock(EnvironmentHelper) + environment_checker.is_training_job.return_value = False + environment_checker.get_job_scheduled_time = Mock(return_value="2023-05-05T15:22:57Z") + return environment_checker + + +@pytest.fixture +def params_loader(system_params_loader_mock): + return ParamsLoader(system_params_loader_mock) + + +@pytest.fixture +def system_params_loader(environment_checker): + return SystemParamsLoader(environment_checker) + + +def test_get_parameter_args(params_loader, system_params_loader_mock): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + parameters=tdh.USER_INPUT_PARAMS, + ) + + params = params_loader.get_parameter_args(fp_config) + + system_params_loader_mock.get_system_args.assert_called_once() + assert params == {"params": {**tdh.USER_INPUT_PARAMS, **tdh.SYSTEM_PARAMS}} + + +def test_get_parameter_args_with_no_user_params(params_loader, system_params_loader_mock): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + parameters=None, + ) + + params = params_loader.get_parameter_args(fp_config) + + system_params_loader_mock.get_system_args.assert_called_once() + assert params == {"params": {**tdh.SYSTEM_PARAMS}} + + +def test_get_system_arg_from_pipeline_execution(system_params_loader): + system_params = system_params_loader.get_system_args() + + assert system_params == { + "system": { + "scheduled_time": "2023-05-05T15:22:57Z", + } + } diff --git 
a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py new file mode 100644 index 0000000000..12ac116f53 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py @@ -0,0 +1,175 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import feature_store_pyspark +import pytest +from mock import Mock, patch, call + +from sagemaker.mlops.feature_store.feature_processor._spark_factory import ( + FeatureStoreManagerFactory, + SparkSessionFactory, +) + + +@pytest.fixture +def env_helper(): + return Mock( + is_training_job=Mock(return_value=False), + load_training_resource_config=Mock(return_value=None), + ) + + +def test_spark_session_factory_configuration(): + env_helper = Mock() + spark_config = {"spark.test.key": "spark.test.value"} + spark_session_factory = SparkSessionFactory(env_helper, spark_config) + spark_configs = dict(spark_session_factory._get_spark_configs(is_training_job=False)) + jsc_hadoop_configs = dict(spark_session_factory._get_jsc_hadoop_configs()) + + # General optimizations + assert spark_configs.get("spark.hadoop.fs.s3a.aws.credentials.provider") == ",".join( + [ + "com.amazonaws.auth.ContainerCredentialsProvider", + "com.amazonaws.auth.profile.ProfileCredentialsProvider", + "com.amazonaws.auth.DefaultAWSCredentialsProviderChain", + ] + ) + + assert spark_configs.get("spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version") == "2" + assert ( + spark_configs.get("spark.hadoop.mapreduce.fileoutputcommitter.cleanup-failures.ignored") + == "true" + ) + assert spark_configs.get("spark.hadoop.parquet.enable.summary-metadata") == "false" + + assert spark_configs.get("spark.sql.parquet.mergeSchema") == "false" + assert spark_configs.get("spark.sql.parquet.filterPushdown") == "true" + assert spark_configs.get("spark.sql.hive.metastorePartitionPruning") == "true" + + assert spark_configs.get("spark.hadoop.fs.s3a.threads.max") == "500" + assert spark_configs.get("spark.hadoop.fs.s3a.connection.maximum") == "500" + assert spark_configs.get("spark.hadoop.fs.s3a.experimental.input.fadvise") == "normal" + assert spark_configs.get("spark.hadoop.fs.s3a.block.size") == "128M" + assert spark_configs.get("spark.hadoop.fs.s3a.fast.upload.buffer") == "disk" + assert spark_configs.get("spark.hadoop.fs.trash.interval") == "0" + assert spark_configs.get("spark.port.maxRetries") == "50" + + assert spark_configs.get("spark.test.key") == "spark.test.value" + + assert jsc_hadoop_configs.get("mapreduce.fileoutputcommitter.marksuccessfuljobs") == "false" + + # Verify configurations when not running on a training job + assert ",".join(feature_store_pyspark.classpath_jars()) in spark_configs.get("spark.jars") + assert ",".join( + [ + 
"org.apache.hadoop:hadoop-aws:3.3.1", + "org.apache.hadoop:hadoop-common:3.3.1", + ] + ) in spark_configs.get("spark.jars.packages") + + +def test_spark_session_factory_configuration_on_training_job(): + env_helper = Mock() + spark_config = {"spark.test.key": "spark.test.value"} + spark_session_factory = SparkSessionFactory(env_helper, spark_config) + + spark_config = spark_session_factory._get_spark_configs(is_training_job=True) + assert dict(spark_config).get("spark.test.key") == "spark.test.value" + + assert all(tup[0] != "spark.jars" for tup in spark_config) + assert all(tup[0] != "spark.jars.packages" for tup in spark_config) + + +@patch("pyspark.context.SparkContext.getOrCreate") +def test_spark_session_factory(mock_spark_context): + env_helper = Mock() + env_helper.get_instance_count.return_value = 1 + spark_session_factory = SparkSessionFactory(env_helper) + + spark_session_factory.spark_session + + _, _, kw_args = mock_spark_context.mock_calls[0] + spark_conf = kw_args["conf"] + + mock_spark_context.assert_called_once() + assert spark_conf.get("spark.master") == "local[*]" + for cfg in spark_session_factory._get_spark_configs(True): + assert spark_conf.get(cfg[0]) == cfg[1] + + +@patch("pyspark.context.SparkContext.getOrCreate") +def test_spark_session_factory_with_iceberg_config(mock_spark_context): + mock_env_helper = Mock() + mock_spark_context.side_effect = [Mock(), Mock()] + + spark_session_factory = SparkSessionFactory(mock_env_helper) + + spark_session = spark_session_factory.spark_session + mock_conf = Mock() + spark_session.conf = mock_conf + + spark_session_with_iceberg_config = spark_session_factory.get_spark_session_with_iceberg_config( + "warehouse", "catalog" + ) + + assert spark_session is spark_session_with_iceberg_config + expected_calls = [ + call.set(cfg[0], cfg[1]) + for cfg in spark_session_factory._get_iceberg_configs("warehouse", "catalog") + ] + + mock_conf.assert_has_calls(expected_calls, any_order=False) + + +@patch("pyspark.context.SparkContext.getOrCreate") +def test_spark_session_factory_same_instance(mock_spark_context): + mock_env_helper = Mock() + mock_spark_context.side_effect = [Mock(), Mock()] + + spark_session_factory = SparkSessionFactory(mock_env_helper) + + a_reference = spark_session_factory.spark_session + another_reference = spark_session_factory.spark_session + + assert a_reference is another_reference + + +@patch("feature_store_pyspark.FeatureStoreManager.FeatureStoreManager") +def test_feature_store_manager_same_instance(mock_feature_store_manager): + mock_feature_store_manager.side_effect = [Mock(), Mock()] + + factory = FeatureStoreManagerFactory() + + assert factory.feature_store_manager is factory.feature_store_manager + + +def test_spark_session_factory_get_spark_session_with_iceberg_config(env_helper): + spark_session_factory = SparkSessionFactory(env_helper) + iceberg_configs = dict(spark_session_factory._get_iceberg_configs("s3://test/path", "Catalog")) + + assert ( + iceberg_configs.get("spark.sql.catalog.catalog") + == "smfs.shaded.org.apache.iceberg.spark.SparkCatalog" + ) + assert iceberg_configs.get("spark.sql.catalog.catalog.warehouse") == "s3://test/path" + assert ( + iceberg_configs.get("spark.sql.catalog.catalog.catalog-impl") + == "smfs.shaded.org.apache.iceberg.aws.glue.GlueCatalog" + ) + assert ( + iceberg_configs.get("spark.sql.catalog.catalog.io-impl") + == "smfs.shaded.org.apache.iceberg.aws.s3.S3FileIO" + ) + assert iceberg_configs.get("spark.sql.catalog.catalog.glue.skip-name-validation") == "true" diff --git 
a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_arg_provider.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_arg_provider.py new file mode 100644 index 0000000000..561cc09e76 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_arg_provider.py @@ -0,0 +1,280 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import test_data_helpers as tdh +from mock import Mock, patch +from pyspark.sql import DataFrame, SparkSession + +from sagemaker.mlops.feature_store.feature_processor._input_loader import InputLoader +from sagemaker.mlops.feature_store.feature_processor._params_loader import ParamsLoader +from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory +from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import SparkArgProvider +from sagemaker.mlops.feature_store.feature_processor._data_source import PySparkDataSource + + +@pytest.fixture +def params_loader(): + params_loader = Mock(ParamsLoader) + params_loader.get_parameter_args = Mock(return_value={"params": {"key": "value"}}) + return params_loader + + +@pytest.fixture +def feature_group_as_spark_df(): + return Mock(DataFrame) + + +@pytest.fixture +def s3_uri_as_spark_df(): + return Mock(DataFrame) + + +@pytest.fixture +def base_data_source_as_spark_df(): + return Mock(DataFrame) + + +@pytest.fixture +def input_loader(feature_group_as_spark_df, s3_uri_as_spark_df): + input_loader = Mock(InputLoader) + input_loader.load_from_s3.return_value = s3_uri_as_spark_df + input_loader.load_from_feature_group.return_value = feature_group_as_spark_df + + return input_loader + + +@pytest.fixture +def spark_session(): + return Mock(SparkSession) + + +@pytest.fixture +def spark_session_factory(spark_session): + return Mock(SparkSessionFactory, spark_session=spark_session) + + +@pytest.fixture +def spark_arg_provider(params_loader, input_loader, spark_session_factory): + return SparkArgProvider(params_loader, input_loader, spark_session_factory) + + +class MockDataSource(PySparkDataSource): + + data_source_unique_id = "test_id" + data_source_name = "test_source" + + def read_data(self, spark, params) -> DataFrame: + return Mock(DataFrame) + + +def test_provide_additional_kw_args(spark_arg_provider, spark_session): + def udf(fg_input, s3_input, params, spark): + return None + + additional_kw_args = spark_arg_provider.provide_additional_kwargs(udf) + + assert additional_kw_args.keys() == {"spark"} + assert additional_kw_args["spark"] == spark_session + + +def test_not_provide_additional_kw_args(spark_arg_provider): + def udf(input, params): + return None + + additional_kw_args = spark_arg_provider.provide_additional_kwargs(udf) + + assert additional_kw_args == {} + + +def test_provide_params(spark_arg_provider, params_loader): + fp_config 
= tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(fg_input, s3_input, params, spark): + return None + + params = spark_arg_provider.provide_params_arg(udf, fp_config) + + params_loader.get_parameter_args.assert_called_with(fp_config) + assert params == params_loader.get_parameter_args.return_value + + +def test_not_provide_params(spark_arg_provider, params_loader): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(fg_input, s3_input, spark): + return None + + params = spark_arg_provider.provide_params_arg(udf, fp_config) + + assert params == {} + + +def test_provide_input_args_with_no_input(spark_arg_provider): + fp_config = tdh.create_fp_config(inputs=[], output=tdh.OUTPUT_FEATURE_GROUP_ARN) + + def udf() -> DataFrame: + return Mock(DataFrame) + + with pytest.raises( + ValueError, match="Expected at least one input to the user defined function." + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args_with_extra_udf_parameters(spark_arg_provider): + fp_config = tdh.create_fp_config( + inputs=[tdh.INPUT_S3_URI], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + with pytest.raises( + ValueError, + match=r"The signature of the user defined function does not match the list of inputs requested." + r" Expected 1 parameter\(s\).", + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args_with_extra_fp_config_inputs(spark_arg_provider): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None) -> DataFrame: + return Mock(DataFrame) + + with pytest.raises( + ValueError, + match=r"The signature of the user defined function does not match the list of inputs requested." 
+ r" Expected 2 parameter\(s\).", + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args( + spark_arg_provider, + feature_group_as_spark_df, + s3_uri_as_spark_df, +): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + inputs = spark_arg_provider.provide_input_args(udf, fp_config) + + assert inputs.keys() == {"input_fg", "input_s3_uri"} + assert inputs["input_fg"] == feature_group_as_spark_df + assert inputs["input_s3_uri"] == s3_uri_as_spark_df + + +def test_provide_input_args_with_reversed_inputs( + spark_arg_provider, + feature_group_as_spark_df, + s3_uri_as_spark_df, +): + fp_config = tdh.create_fp_config( + inputs=[tdh.S3_DATA_SOURCE, tdh.FEATURE_GROUP_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + inputs = spark_arg_provider.provide_input_args(udf, fp_config) + + assert inputs.keys() == {"input_fg", "input_s3_uri"} + assert inputs["input_fg"] == s3_uri_as_spark_df + assert inputs["input_s3_uri"] == feature_group_as_spark_df + + +def test_provide_input_args_with_optional_args_out_of_order(spark_arg_provider): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf_spark_params(spark=None, params=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_params_spark(params=None, spark=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_spark(spark=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_params(params=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + for udf in [udf_spark_params, udf_params_spark, udf_spark, udf_params]: + with pytest.raises( + ValueError, + match="Expected at least one input to the user defined function.", + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args_with_optional_args( + spark_arg_provider, feature_group_as_spark_df, s3_uri_as_spark_df +): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf_all_optional(input_fg=None, input_s3_uri=None, params=None, spark=None) -> DataFrame: + return Mock(DataFrame) + + def udf_no_optional(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_only_params(input_fg=None, input_s3_uri=None, params=None) -> DataFrame: + return Mock(DataFrame) + + def udf_only_spark(input_fg=None, input_s3_uri=None, spark=None) -> DataFrame: + return Mock(DataFrame) + + for udf in [udf_all_optional, udf_no_optional, udf_only_params, udf_only_spark]: + inputs = spark_arg_provider.provide_input_args(udf, fp_config) + + assert inputs.keys() == {"input_fg", "input_s3_uri"} + assert inputs["input_fg"] == feature_group_as_spark_df + assert inputs["input_s3_uri"] == s3_uri_as_spark_df + + +def test_provide_input_arg_for_base_data_source(spark_arg_provider, params_loader, spark_session): + fp_config = tdh.create_fp_config(inputs=[MockDataSource()], output=tdh.OUTPUT_FEATURE_GROUP_ARN) + + def udf(input_df) -> DataFrame: + return input_df + + with patch.object(MockDataSource, "read_data", return_value=Mock(DataFrame)) as 
mock_read: + spark_arg_provider.provide_input_args(udf, fp_config) + mock_read.assert_called_with(spark=spark_session, params={"key": "value"}) + params_loader.get_parameter_args.assert_called_with(fp_config) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_output_receiver.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_output_receiver.py new file mode 100644 index 0000000000..34770a011a --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_output_receiver.py @@ -0,0 +1,106 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import test_data_helpers as tdh +from feature_store_pyspark.FeatureStoreManager import FeatureStoreManager +from mock import Mock +from py4j.protocol import Py4JJavaError +from pyspark.sql import DataFrame + +from sagemaker.mlops.feature_store.feature_processor import IngestionError +from sagemaker.mlops.feature_store.feature_processor._spark_factory import ( + FeatureStoreManagerFactory, +) +from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import ( + SparkOutputReceiver, +) + + +@pytest.fixture +def df() -> Mock: + return Mock(DataFrame) + + +@pytest.fixture +def feature_store_manager(): + return Mock(FeatureStoreManager) + + +@pytest.fixture +def feature_store_manager_factory(feature_store_manager): + return Mock(FeatureStoreManagerFactory, feature_store_manager=feature_store_manager) + + +@pytest.fixture +def spark_output_receiver(feature_store_manager_factory): + return SparkOutputReceiver(feature_store_manager_factory) + + +def test_ingest_udf_output_enable_ingestion_false(df, feature_store_manager, spark_output_receiver): + fp_config = tdh.create_fp_config(enable_ingestion=False) + spark_output_receiver.ingest_udf_output(df, fp_config) + + feature_store_manager.ingest_data.assert_not_called() + + +def test_ingest_udf_output(df, feature_store_manager, spark_output_receiver): + fp_config = tdh.create_fp_config() + spark_output_receiver.ingest_udf_output(df, fp_config) + + feature_store_manager.ingest_data.assert_called_with( + input_data_frame=df, + feature_group_arn=fp_config.output, + target_stores=fp_config.target_stores, + ) + + +def test_ingest_udf_output_failed_records(df, feature_store_manager, spark_output_receiver): + fp_config = tdh.create_fp_config() + + # Simulate streaming ingestion failure. 
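+ # The mocked Java exception's simple class name marks a stream ingestion failure, so the receiver is expected to surface the failed-records DataFrame and raise IngestionError.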
+ mock_failed_records_df = Mock() + mock_java_exception = Mock(_target_id="") + mock_java_exception.getClass = Mock( + return_value=Mock(getSimpleName=Mock(return_value="StreamIngestionFailureException")) + ) + + feature_store_manager.ingest_data.side_effect = Py4JJavaError( + msg="", java_exception=mock_java_exception + ) + feature_store_manager.get_failed_stream_ingestion_data_frame.return_value = ( + mock_failed_records_df + ) + + with pytest.raises(IngestionError): + spark_output_receiver.ingest_udf_output(df, fp_config) + + mock_failed_records_df.show.assert_called_with(n=20, truncate=False) + + +def test_ingest_udf_output_all_py4j_error_raised(df, feature_store_manager, spark_output_receiver): + fp_config = tdh.create_fp_config() + + # Simulate ingestion failure. + mock_java_exception = Mock(_target_id="") + mock_java_exception.getClass = Mock( + return_value=Mock(getSimpleName=Mock(return_value="ValidationError")) + ) + feature_store_manager.ingest_data.side_effect = Py4JJavaError( + msg="", java_exception=mock_java_exception + ) + + with pytest.raises(Py4JJavaError): + spark_output_receiver.ingest_udf_output(df, fp_config) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_wrapper.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_wrapper.py new file mode 100644 index 0000000000..51747d78f9 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_wrapper.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License.
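+# Unit tests for UDFWrapper, which wraps a feature_processor UDF so that its inputs, params, and extra kwargs are auto-loaded and its output is handed to the output receiver for ingestion.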
+from __future__ import absolute_import + +from typing import Callable + +import pytest +from mock import Mock + +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import UDFArgProvider +from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import ( + UDFOutputReceiver, +) +from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper + + +@pytest.fixture +def udf_arg_provider(): + udf_arg_provider = Mock(UDFArgProvider) + udf_arg_provider.provide_input_args.return_value = {"input": Mock()} + udf_arg_provider.provide_params_arg.return_value = {"params": Mock()} + udf_arg_provider.provide_additional_kwargs.return_value = {"kwarg": Mock()} + + return udf_arg_provider + + +@pytest.fixture +def udf_output_receiver(): + udf_output_receiver = Mock(UDFOutputReceiver) + udf_output_receiver.ingest_udf_output.return_value = Mock() + return udf_output_receiver + + +@pytest.fixture +def udf_output(): + udf_output = Mock(Callable) + return udf_output + + +@pytest.fixture +def udf(udf_output): + udf = Mock(Callable) + udf.return_value = udf_output + return udf + + +@pytest.fixture +def fp_config(): + fp_config = Mock(FeatureProcessorConfig) + return fp_config + + +def test_wrap(fp_config, udf_output, udf_arg_provider, udf_output_receiver): + def test_udf(input, params, kwarg): + # Verify wrapped function is called with auto-loaded arguments. + assert input is udf_arg_provider.provide_input_args.return_value["input"] + assert params is udf_arg_provider.provide_params_arg.return_value["params"] + assert kwarg is udf_arg_provider.provide_additional_kwargs.return_value["kwarg"] + return udf_output + + udf_wrapper = UDFWrapper(udf_arg_provider, udf_output_receiver) + + # Execute decorator function and the decorated function. + wrapped_udf = udf_wrapper.wrap(test_udf, fp_config) + wrapped_udf() + + # Verify interactions with dependencies. + udf_arg_provider.provide_input_args.assert_called_with(test_udf, fp_config) + udf_arg_provider.provide_params_arg.assert_called_with(test_udf, fp_config) + udf_arg_provider.provide_additional_kwargs.assert_called_with(test_udf) + udf_output_receiver.ingest_udf_output.assert_called_with(udf_output, fp_config) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_validation.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_validation.py new file mode 100644 index 0000000000..16a8784586 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_validation.py @@ -0,0 +1,192 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License.
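+# Unit tests for feature_processor validation: ValidatorChain, SparkUDFSignatureValidator, and BaseDataSourceValidator.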
+from __future__ import absolute_import + +from typing import Callable +from pyspark.sql import DataFrame + +import pytest + +import test_data_helpers as tdh +from mock import Mock + +from sagemaker.mlops.feature_store.feature_processor._validation import ( + SparkUDFSignatureValidator, + Validator, + ValidatorChain, + BaseDataSourceValidator, +) +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + BaseDataSource, +) + + +def test_validator_chain(): + fp_config = tdh.create_fp_config() + udf = Mock(Callable) + + first_validator = Mock(Validator) + second_validator = Mock(Validator) + validator_chain = ValidatorChain([first_validator, second_validator]) + + validator_chain.validate(udf, fp_config) + + first_validator.validate.assert_called_with(udf, fp_config) + second_validator.validate.assert_called_with(udf, fp_config) + + +def test_validator_chain_validation_fails(): + fp_config = tdh.create_fp_config() + udf = Mock(Callable) + + first_validator = Mock(validate=Mock(side_effect=ValueError())) + second_validator = Mock(validate=Mock()) + validator_chain = ValidatorChain([first_validator, second_validator]) + + with pytest.raises(ValueError): + validator_chain.validate(udf, fp_config) + + +def test_spark_udf_signature_validator_valid(): + # One Input + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE]) + + def one_data_source(fg_data_source, params, spark): + return None + + SparkUDFSignatureValidator().validate(one_data_source, fp_config) + + # Two Inputs + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def two_data_sources(fg_data_source, s3_data_source, params, spark): + return None + + SparkUDFSignatureValidator().validate(two_data_sources, fp_config) + + # No Optional Args (params and spark) + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def no_optional_args(fg_data_source, s3_data_source): + return None + + SparkUDFSignatureValidator().validate(no_optional_args, fp_config) + + # Optional Args (no params) + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def no_optional_params_arg(fg_data_source, s3_data_source, spark): + return None + + SparkUDFSignatureValidator().validate(no_optional_params_arg, fp_config) + + # No Optional Args (no spark) + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def no_optional_spark_arg(fg_data_source, s3_data_source, params): + return None + + SparkUDFSignatureValidator().validate(no_optional_spark_arg, fp_config) + + +def test_spark_udf_signature_validator_udf_input_mismatch(): + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def one_input(one, params, spark): + return None + + def three_inputs(one, two, three, params, spark): + return None + + exception_string = ( + r"feature_processor expected a function with \(2\) parameter\(s\) before any" + r" optional 'params' or 'spark' parameters for the \(2\) requested data source\(s\)\." 
+ ) + + with pytest.raises(ValueError, match=exception_string): + SparkUDFSignatureValidator().validate(one_input, fp_config) + + with pytest.raises(ValueError, match=exception_string): + SparkUDFSignatureValidator().validate(three_inputs, fp_config) + + +def test_spark_udf_signature_validator_zero_input_params(): + def zero_inputs(params, spark): + return None + + with pytest.raises(ValueError, match="feature_processor expects at least 1 input parameter."): + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + SparkUDFSignatureValidator().validate(zero_inputs, fp_config) + + +def test_spark_udf_signature_validator_udf_invalid_non_input_position(): + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + with pytest.raises( + ValueError, + match="feature_processor expected the 'params' parameter to be the last or second last" + " parameter after input parameters.", + ): + + def invalid_params_position(params, fg_data_source, s3_data_source): + return None + + SparkUDFSignatureValidator().validate(invalid_params_position, fp_config) + + with pytest.raises( + ValueError, + match="feature_processor expected the 'spark' parameter to be the last or second last" + " parameter after input parameters.", + ): + + def invalid_spark_position(spark, fg_data_source, s3_data_source): + return None + + SparkUDFSignatureValidator().validate(invalid_spark_position, fp_config) + + +@pytest.mark.parametrize( + "data_source_name, data_source_unique_id, error_pattern", + [ + ("$_invalid_source", "unique_id", "data_source_name of input does not match pattern '.*'."), + ("", "unique_id", "data_source_name of input does not match pattern '.*'."), + ( + "source", + tdh.DATA_SOURCE_UNIQUE_ID_TOO_LONG, + "data_source_unique_id of input does not match pattern '.*'.", + ), + ("source", "", "data_source_unique_id of input does not match pattern '.*'."), + ], +) +def test_spark_udf_signature_validator_udf_invalid_base_data_source( + data_source_name, data_source_unique_id, error_pattern +): + class TestInValidCustomDataSource(BaseDataSource): + + data_source_name = None + data_source_unique_id = None + + def read_data(self, spark, params) -> DataFrame: + return None + + test_data_source = TestInValidCustomDataSource() + test_data_source.data_source_name = data_source_name + test_data_source.data_source_unique_id = data_source_unique_id + + fp_config = tdh.create_fp_config(inputs=[test_data_source]) + + def udf(input_data_source, params, spark): + return None + + with pytest.raises(ValueError, match=error_pattern): + BaseDataSourceValidator().validate(udf, fp_config) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_athena_query.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_athena_query.py new file mode 100644 index 0000000000..2fed784208 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_athena_query.py @@ -0,0 +1,113 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 +"""Unit tests for athena_query.py""" +import os +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd + +from sagemaker.mlops.feature_store.athena_query import AthenaQuery + + +class TestAthenaQuery: + @pytest.fixture + def mock_session(self): + session = Mock() + session.boto_session.client.return_value = Mock() + session.boto_region_name = "us-west-2" + return session + + @pytest.fixture + def athena_query(self, mock_session): + return AthenaQuery( + catalog="AwsDataCatalog", + database="sagemaker_featurestore", + table_name="my_feature_group", + sagemaker_session=mock_session, + ) + + def test_initialization(self, athena_query): + assert athena_query.catalog == "AwsDataCatalog" + assert athena_query.database == "sagemaker_featurestore" + assert athena_query.table_name == "my_feature_group" + assert athena_query._current_query_execution_id is None + + @patch("sagemaker.mlops.feature_store.athena_query.start_query_execution") + def test_run_starts_query(self, mock_start, athena_query): + mock_start.return_value = {"QueryExecutionId": "query-123"} + + result = athena_query.run( + query_string="SELECT * FROM table", + output_location="s3://bucket/output", + ) + + assert result == "query-123" + assert athena_query._current_query_execution_id == "query-123" + assert athena_query._result_bucket == "bucket" + assert athena_query._result_file_prefix == "output" + + @patch("sagemaker.mlops.feature_store.athena_query.start_query_execution") + def test_run_with_kms_key(self, mock_start, athena_query): + mock_start.return_value = {"QueryExecutionId": "query-123"} + + athena_query.run( + query_string="SELECT * FROM table", + output_location="s3://bucket/output", + kms_key="arn:aws:kms:us-west-2:123:key/abc", + ) + + mock_start.assert_called_once() + call_kwargs = mock_start.call_args[1] + assert call_kwargs["kms_key"] == "arn:aws:kms:us-west-2:123:key/abc" + + @patch("sagemaker.mlops.feature_store.athena_query.wait_for_athena_query") + def test_wait_calls_helper(self, mock_wait, athena_query): + athena_query._current_query_execution_id = "query-123" + + athena_query.wait() + + mock_wait.assert_called_once_with(athena_query.sagemaker_session, "query-123") + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + def test_get_query_execution(self, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + mock_get.return_value = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}} + + result = athena_query.get_query_execution() + + assert result["QueryExecution"]["Status"]["State"] == "SUCCEEDED" + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + @patch("sagemaker.mlops.feature_store.athena_query.download_athena_query_result") + @patch("pandas.read_csv") + @patch("os.path.join") + def test_as_dataframe_success(self, mock_join, mock_read_csv, mock_download, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + athena_query._result_bucket = "bucket" + athena_query._result_file_prefix = "prefix" + + mock_get.return_value = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}} + mock_join.return_value = "/tmp/query-123.csv" + mock_read_csv.return_value = pd.DataFrame({"col": [1, 2, 3]}) + + with patch("tempfile.gettempdir", return_value="/tmp"): + with patch("os.remove"): + df = athena_query.as_dataframe() + + assert len(df) == 3 + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + def 
test_as_dataframe_raises_when_running(self, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + mock_get.return_value = {"QueryExecution": {"Status": {"State": "RUNNING"}}} + + with pytest.raises(RuntimeError, match="still executing"): + athena_query.as_dataframe() + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + def test_as_dataframe_raises_when_failed(self, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + mock_get.return_value = {"QueryExecution": {"Status": {"State": "FAILED"}}} + + with pytest.raises(RuntimeError, match="failed"): + athena_query.as_dataframe() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_dataset_builder.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_dataset_builder.py new file mode 100644 index 0000000000..254fb0e196 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_dataset_builder.py @@ -0,0 +1,345 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for dataset_builder.py""" +import datetime +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd + +from sagemaker.mlops.feature_store import FeatureGroup +from sagemaker.mlops.feature_store.dataset_builder import ( + DatasetBuilder, + FeatureGroupToBeMerged, + TableType, + JoinTypeEnum, + JoinComparatorEnum, + construct_feature_group_to_be_merged, +) +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FeatureTypeEnum, +) + + +class TestTableType: + def test_feature_group_value(self): + assert TableType.FEATURE_GROUP.value == "FeatureGroup" + + def test_data_frame_value(self): + assert TableType.DATA_FRAME.value == "DataFrame" + + +class TestJoinTypeEnum: + def test_inner_join(self): + assert JoinTypeEnum.INNER_JOIN.value == "JOIN" + + def test_left_join(self): + assert JoinTypeEnum.LEFT_JOIN.value == "LEFT JOIN" + + def test_right_join(self): + assert JoinTypeEnum.RIGHT_JOIN.value == "RIGHT JOIN" + + def test_full_join(self): + assert JoinTypeEnum.FULL_JOIN.value == "FULL JOIN" + + def test_cross_join(self): + assert JoinTypeEnum.CROSS_JOIN.value == "CROSS JOIN" + + +class TestJoinComparatorEnum: + def test_equals(self): + assert JoinComparatorEnum.EQUALS.value == "=" + + def test_greater_than(self): + assert JoinComparatorEnum.GREATER_THAN.value == ">" + + def test_less_than(self): + assert JoinComparatorEnum.LESS_THAN.value == "<" + + +class TestFeatureGroupToBeMerged: + def test_initialization(self): + fg = FeatureGroupToBeMerged( + features=["id", "value"], + included_feature_names=["id", "value"], + projected_feature_names=["id", "value"], + catalog="AwsDataCatalog", + database="sagemaker_featurestore", + table_name="my_table", + record_identifier_feature_name="id", + event_time_identifier_feature=FeatureDefinition( + feature_name="event_time", + feature_type="String", + ), + ) + + assert fg.features == ["id", "value"] + assert fg.catalog == "AwsDataCatalog" + assert fg.table_name == "my_table" + assert fg.join_type == JoinTypeEnum.INNER_JOIN + assert fg.join_comparator == JoinComparatorEnum.EQUALS + + def test_custom_join_settings(self): + fg = FeatureGroupToBeMerged( + features=["id"], + included_feature_names=["id"], + projected_feature_names=["id"], + catalog="AwsDataCatalog", + database="db", + table_name="table", + record_identifier_feature_name="id", + 
event_time_identifier_feature=FeatureDefinition( + feature_name="ts", + feature_type="String", + ), + join_type=JoinTypeEnum.LEFT_JOIN, + join_comparator=JoinComparatorEnum.GREATER_THAN, + ) + + assert fg.join_type == JoinTypeEnum.LEFT_JOIN + assert fg.join_comparator == JoinComparatorEnum.GREATER_THAN + + +class TestConstructFeatureGroupToBeMerged: + @patch("sagemaker.mlops.feature_store.dataset_builder.FeatureGroup") + def test_constructs_from_feature_group(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.feature_group_name = "test-fg" + mock_fg.record_identifier_feature_name = "id" + mock_fg.event_time_feature_name = "event_time" + mock_fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + MagicMock(feature_name="value", feature_type="Fractional"), + MagicMock(feature_name="event_time", feature_type="String"), + ] + mock_fg.offline_store_config.data_catalog_config.catalog = "MyCatalog" + mock_fg.offline_store_config.data_catalog_config.database = "MyDatabase" + mock_fg.offline_store_config.data_catalog_config.table_name = "MyTable" + mock_fg.offline_store_config.data_catalog_config.disable_glue_table_creation = False + mock_fg_class.get.return_value = mock_fg + + target_fg = MagicMock() + target_fg.feature_group_name = "test-fg" + + result = construct_feature_group_to_be_merged( + target_feature_group=target_fg, + included_feature_names=["id", "value"], + ) + + assert result.table_name == "MyTable" + assert result.database == "MyDatabase" + assert result.record_identifier_feature_name == "id" + assert result.table_type == TableType.FEATURE_GROUP + + @patch("sagemaker.mlops.feature_store.dataset_builder.FeatureGroup") + def test_raises_when_no_metastore(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.feature_group_name = "test-fg" + mock_fg.offline_store_config = None + mock_fg_class.get.return_value = mock_fg + + target_fg = MagicMock() + target_fg.feature_group_name = "test-fg" + + with pytest.raises(RuntimeError, match="No metastore"): + construct_feature_group_to_be_merged(target_fg, None) + + +class TestDatasetBuilder: + @pytest.fixture + def mock_session(self): + return Mock() + + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": [1, 2, 3], + "value": [1.1, 2.2, 3.3], + "event_time": ["2024-01-01", "2024-01-02", "2024-01-03"], + }) + + def test_initialization_with_dataframe(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + assert builder._output_path == "s3://bucket/output" + assert builder._record_identifier_feature_name == "id" + + def test_fluent_api_point_in_time(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.point_in_time_accurate_join() + + assert result is builder + assert builder._point_in_time_accurate_join is True + + def test_fluent_api_include_duplicated(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = 
builder.include_duplicated_records() + + assert result is builder + assert builder._include_duplicated_records is True + + def test_fluent_api_include_deleted(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.include_deleted_records() + + assert result is builder + assert builder._include_deleted_records is True + + def test_fluent_api_number_of_recent_records(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.with_number_of_recent_records_by_record_identifier(5) + + assert result is builder + assert builder._number_of_recent_records == 5 + + def test_fluent_api_number_of_records(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.with_number_of_records_from_query_results(100) + + assert result is builder + assert builder._number_of_records == 100 + + def test_fluent_api_as_of(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + timestamp = datetime.datetime(2024, 1, 15, 12, 0, 0) + result = builder.as_of(timestamp) + + assert result is builder + assert builder._write_time_ending_timestamp == timestamp + + def test_fluent_api_event_time_range(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + start = datetime.datetime(2024, 1, 1) + end = datetime.datetime(2024, 1, 31) + result = builder.with_event_time_range(start, end) + + assert result is builder + assert builder._event_time_starting_timestamp == start + assert builder._event_time_ending_timestamp == end + + @patch.object(DatasetBuilder, "_run_query") + @patch("sagemaker.mlops.feature_store.dataset_builder.construct_feature_group_to_be_merged") + def test_with_feature_group(self, mock_construct, mock_run, mock_session, sample_dataframe): + mock_fg_to_merge = MagicMock() + mock_construct.return_value = mock_fg_to_merge + + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + mock_fg = MagicMock() + result = builder.with_feature_group(mock_fg, target_feature_name_in_base="id") + + assert result is builder + assert len(builder._feature_groups_to_be_merged) == 1 + + +class TestDatasetBuilderCreate: + @pytest.fixture + def mock_session(self): + return Mock() + + def test_create_with_feature_group(self, mock_session): + mock_fg = MagicMock(spec=FeatureGroup) + builder = DatasetBuilder.create( + base=mock_fg, + output_path="s3://bucket/output", + session=mock_session, + ) + 
assert builder._base == mock_fg + assert builder._output_path == "s3://bucket/output" + + def test_create_with_dataframe(self, mock_session): + df = pd.DataFrame({"id": [1], "value": [10]}) + builder = DatasetBuilder.create( + base=df, + output_path="s3://bucket/output", + session=mock_session, + record_identifier_feature_name="id", + event_time_identifier_feature_name="event_time", + ) + assert builder._record_identifier_feature_name == "id" + + def test_create_with_dataframe_requires_identifiers(self, mock_session): + df = pd.DataFrame({"id": [1], "value": [10]}) + with pytest.raises(ValueError, match="record_identifier_feature_name"): + DatasetBuilder.create( + base=df, + output_path="s3://bucket/output", + session=mock_session, + ) + + +class TestDatasetBuilderValidation: + @pytest.fixture + def mock_session(self): + return Mock() + + def test_to_csv_raises_for_invalid_base(self, mock_session): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base="invalid", # Not DataFrame or FeatureGroup + _output_path="s3://bucket/output", + ) + + with pytest.raises(ValueError, match="must be either"): + builder.to_csv_file() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_definition.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_definition.py new file mode 100644 index 0000000000..299868b5d2 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_definition.py @@ -0,0 +1,126 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for feature_definition.py""" +import pytest + +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FeatureTypeEnum, + CollectionTypeEnum, + IntegralFeatureDefinition, + FractionalFeatureDefinition, + StringFeatureDefinition, + VectorCollectionType, + ListCollectionType, + SetCollectionType, +) + + +class TestFeatureTypeEnum: + def test_fractional_value(self): + assert FeatureTypeEnum.FRACTIONAL.value == "Fractional" + + def test_integral_value(self): + assert FeatureTypeEnum.INTEGRAL.value == "Integral" + + def test_string_value(self): + assert FeatureTypeEnum.STRING.value == "String" + + +class TestCollectionTypeEnum: + def test_list_value(self): + assert CollectionTypeEnum.LIST.value == "List" + + def test_set_value(self): + assert CollectionTypeEnum.SET.value == "Set" + + def test_vector_value(self): + assert CollectionTypeEnum.VECTOR.value == "Vector" + + +class TestCollectionTypes: + def test_list_collection_type(self): + collection = ListCollectionType() + assert collection.collection_type == "List" + assert collection.collection_config is None + + def test_set_collection_type(self): + collection = SetCollectionType() + assert collection.collection_type == "Set" + assert collection.collection_config is None + + def test_vector_collection_type(self): + collection = VectorCollectionType(dimension=128) + assert collection.collection_type == "Vector" + assert collection.collection_config is not None + assert collection.collection_config.vector_config.dimension == 128 + + +class TestFeatureDefinitionFactories: + def test_integral_feature_definition(self): + definition = IntegralFeatureDefinition(feature_name="my_int_feature") + assert definition.feature_name == "my_int_feature" + assert definition.feature_type == "Integral" + assert definition.collection_type is None + + def test_fractional_feature_definition(self): + definition = 
FractionalFeatureDefinition(feature_name="my_float_feature") + assert definition.feature_name == "my_float_feature" + assert definition.feature_type == "Fractional" + assert definition.collection_type is None + + def test_string_feature_definition(self): + definition = StringFeatureDefinition(feature_name="my_string_feature") + assert definition.feature_name == "my_string_feature" + assert definition.feature_type == "String" + assert definition.collection_type is None + + def test_integral_with_list_collection(self): + definition = IntegralFeatureDefinition( + feature_name="my_int_list", + collection_type=ListCollectionType(), + ) + assert definition.feature_name == "my_int_list" + assert definition.feature_type == "Integral" + assert definition.collection_type == "List" + + def test_string_with_set_collection(self): + definition = StringFeatureDefinition( + feature_name="my_string_set", + collection_type=SetCollectionType(), + ) + assert definition.feature_name == "my_string_set" + assert definition.feature_type == "String" + assert definition.collection_type == "Set" + + def test_fractional_with_vector_collection(self): + definition = FractionalFeatureDefinition( + feature_name="my_embedding", + collection_type=VectorCollectionType(dimension=256), + ) + assert definition.feature_name == "my_embedding" + assert definition.feature_type == "Fractional" + assert definition.collection_type == "Vector" + assert definition.collection_config.vector_config.dimension == 256 + + +class TestFeatureDefinitionSerialization: + """Test that FeatureDefinition can be serialized (Pydantic model_dump).""" + + def test_simple_definition_serialization(self): + definition = IntegralFeatureDefinition(feature_name="id") + # Pydantic model - use model_dump + data = definition.model_dump(exclude_none=True) + assert data["feature_name"] == "id" + assert data["feature_type"] == "Integral" + + def test_collection_definition_serialization(self): + definition = FractionalFeatureDefinition( + feature_name="vector", + collection_type=VectorCollectionType(dimension=10), + ) + data = definition.model_dump(exclude_none=True) + assert data["feature_name"] == "vector" + assert data["feature_type"] == "Fractional" + assert data["collection_type"] == "Vector" + assert data["collection_config"]["vector_config"]["dimension"] == 10 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py new file mode 100644 index 0000000000..a9d5408bf6 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py @@ -0,0 +1,202 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 +"""Unit tests for feature_utils.py""" +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd +import numpy as np + +from sagemaker.mlops.feature_store.feature_utils import ( + load_feature_definitions_from_dataframe, + as_hive_ddl, + create_athena_query, + ingest_dataframe, + get_session_from_role, + _is_collection_column, + _generate_feature_definition, +) +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + ListCollectionType, +) + + +class TestLoadFeatureDefinitionsFromDataframe: + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": pd.Series([1, 2, 3], dtype="int64"), + "value": pd.Series([1.1, 2.2, 3.3], dtype="float64"), + "name": pd.Series(["a", "b", "c"], dtype="string"), + }) + + def test_infers_integral_type(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + id_def = next(d for d in defs if d.feature_name == "id") + assert id_def.feature_type == "Integral" + + def test_infers_fractional_type(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + value_def = next(d for d in defs if d.feature_name == "value") + assert value_def.feature_type == "Fractional" + + def test_infers_string_type(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + name_def = next(d for d in defs if d.feature_name == "name") + assert name_def.feature_type == "String" + + def test_returns_correct_count(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + assert len(defs) == 3 + + def test_collection_type_with_in_memory_storage(self): + df = pd.DataFrame({ + "id": pd.Series([1, 2], dtype="int64"), + "tags": pd.Series([["a", "b"], ["c"]], dtype="object"), + }) + defs = load_feature_definitions_from_dataframe(df, online_storage_type="InMemory") + tags_def = next(d for d in defs if d.feature_name == "tags") + assert tags_def.collection_type == "List" + + +class TestIsCollectionColumn: + def test_list_column_returns_true(self): + series = pd.Series([[1, 2], [3, 4], [5]]) + assert _is_collection_column(series) == True + + def test_scalar_column_returns_false(self): + series = pd.Series([1, 2, 3]) + assert _is_collection_column(series) == False + + def test_empty_series(self): + series = pd.Series([], dtype="object") + assert _is_collection_column(series) == False + + +class TestAsHiveDdl: + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_generates_ddl_string(self, mock_fg_class): + # Setup mock + mock_fg = MagicMock() + mock_fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + MagicMock(feature_name="value", feature_type="Fractional"), + MagicMock(feature_name="name", feature_type="String"), + ] + mock_fg.offline_store_config.s3_storage_config.resolved_output_s3_uri = "s3://bucket/prefix" + mock_fg_class.get.return_value = mock_fg + + ddl = as_hive_ddl("my-feature-group") + + assert "CREATE EXTERNAL TABLE" in ddl + assert "my-feature-group" in ddl + assert "id INT" in ddl + assert "value FLOAT" in ddl + assert "name STRING" in ddl + assert "write_time TIMESTAMP" in ddl + assert "event_time TIMESTAMP" in ddl + assert "is_deleted BOOLEAN" in ddl + assert "s3://bucket/prefix" in ddl + + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_custom_database_and_table(self, mock_fg_class): + mock_fg = MagicMock() + 
mock_fg.feature_definitions = [] + mock_fg.offline_store_config.s3_storage_config.resolved_output_s3_uri = "s3://bucket/prefix" + mock_fg_class.get.return_value = mock_fg + + ddl = as_hive_ddl("my-fg", database="custom_db", table_name="custom_table") + + assert "custom_db.custom_table" in ddl + + +class TestCreateAthenaQuery: + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_creates_athena_query(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.offline_store_config.data_catalog_config.catalog = "MyCatalog" + mock_fg.offline_store_config.data_catalog_config.database = "MyDatabase" + mock_fg.offline_store_config.data_catalog_config.table_name = "MyTable" + mock_fg.offline_store_config.data_catalog_config.disable_glue_table_creation = False + mock_fg_class.get.return_value = mock_fg + + session = Mock() + query = create_athena_query("my-fg", session) + + assert query.catalog == "AwsDataCatalog" # disable_glue=False uses default + assert query.database == "MyDatabase" + assert query.table_name == "MyTable" + + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_raises_when_no_metastore(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.offline_store_config = None + mock_fg_class.get.return_value = mock_fg + + session = Mock() + with pytest.raises(RuntimeError, match="No metastore"): + create_athena_query("my-fg", session) + + +class TestIngestDataframe: + @patch("sagemaker.mlops.feature_store.feature_utils.IngestionManagerPandas") + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_creates_manager_and_runs(self, mock_fg_class, mock_manager_class): + mock_fg = MagicMock() + mock_fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + ] + mock_fg_class.get.return_value = mock_fg + + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + df = pd.DataFrame({"id": [1, 2, 3]}) + result = ingest_dataframe("my-fg", df, max_workers=2, max_processes=1) + + mock_manager_class.assert_called_once() + mock_manager.run.assert_called_once() + assert result == mock_manager + + def test_raises_on_invalid_max_workers(self): + df = pd.DataFrame({"id": [1, 2, 3]}) + with pytest.raises(ValueError, match="max_workers"): + ingest_dataframe("my-fg", df, max_workers=0) + + def test_raises_on_invalid_max_processes(self): + df = pd.DataFrame({"id": [1, 2, 3]}) + with pytest.raises(ValueError, match="max_processes"): + ingest_dataframe("my-fg", df, max_processes=-1) + + +class TestGetSessionFromRole: + @patch("sagemaker.mlops.feature_store.feature_utils.boto3") + @patch("sagemaker.mlops.feature_store.feature_utils.Session") + def test_creates_session_without_role(self, mock_session_class, mock_boto3): + mock_boto_session = MagicMock() + mock_boto3.Session.return_value = mock_boto_session + + get_session_from_role(region="us-west-2") + + mock_boto3.Session.assert_called_with(region_name="us-west-2") + mock_session_class.assert_called_once() + + @patch("sagemaker.mlops.feature_store.feature_utils.boto3") + @patch("sagemaker.mlops.feature_store.feature_utils.Session") + def test_assumes_role_when_provided(self, mock_session_class, mock_boto3): + mock_boto_session = MagicMock() + mock_sts = MagicMock() + mock_sts.assume_role.return_value = { + "Credentials": { + "AccessKeyId": "key", + "SecretAccessKey": "secret", + "SessionToken": "token", + } + } + mock_boto_session.client.return_value = mock_sts + mock_boto3.Session.return_value = mock_boto_session + + 
get_session_from_role(region="us-west-2", assume_role="arn:aws:iam::123:role/MyRole") + + mock_sts.assume_role.assert_called_once() + assert mock_boto3.Session.call_count == 2 # Initial + after assume diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_ingestion_manager_pandas.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_ingestion_manager_pandas.py new file mode 100644 index 0000000000..2ecf495967 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_ingestion_manager_pandas.py @@ -0,0 +1,256 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for ingestion_manager_pandas.py""" +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd +import numpy as np + +from sagemaker.mlops.feature_store.ingestion_manager_pandas import ( + IngestionManagerPandas, + IngestionError, +) + + +class TestIngestionError: + def test_stores_failed_rows(self): + error = IngestionError([1, 5, 10], "Some rows failed") + assert error.failed_rows == [1, 5, 10] + assert "Some rows failed" in str(error) + + +class TestIngestionManagerPandas: + @pytest.fixture + def feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "value": {"FeatureName": "value", "FeatureType": "Fractional"}, + "name": {"FeatureName": "name", "FeatureType": "String"}, + } + + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": [1, 2, 3], + "value": [1.1, 2.2, 3.3], + "name": ["a", "b", "c"], + }) + + @pytest.fixture + def manager(self, feature_definitions): + return IngestionManagerPandas( + feature_group_name="test-fg", + feature_definitions=feature_definitions, + max_workers=1, + max_processes=1, + ) + + def test_initialization(self, manager): + assert manager.feature_group_name == "test-fg" + assert manager.max_workers == 1 + assert manager.max_processes == 1 + assert manager.failed_rows == [] + + def test_failed_rows_property(self, manager): + manager._failed_indices = [1, 2, 3] + assert manager.failed_rows == [1, 2, 3] + + +class TestIngestionManagerHelpers: + def test_is_feature_collection_type_true(self): + feature_defs = { + "tags": {"FeatureName": "tags", "FeatureType": "String", "CollectionType": "List"}, + } + assert IngestionManagerPandas._is_feature_collection_type("tags", feature_defs) is True + + def test_is_feature_collection_type_false(self): + feature_defs = { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + } + assert IngestionManagerPandas._is_feature_collection_type("id", feature_defs) is False + + def test_is_feature_collection_type_missing(self): + feature_defs = {} + assert IngestionManagerPandas._is_feature_collection_type("unknown", feature_defs) is False + + def test_feature_value_is_not_none_scalar(self): + assert IngestionManagerPandas._feature_value_is_not_none(5) is True + assert IngestionManagerPandas._feature_value_is_not_none(None) is False + assert IngestionManagerPandas._feature_value_is_not_none(np.nan) is False + + def test_feature_value_is_not_none_list(self): + assert IngestionManagerPandas._feature_value_is_not_none([1, 2, 3]) is True + assert IngestionManagerPandas._feature_value_is_not_none([]) is True + assert IngestionManagerPandas._feature_value_is_not_none(None) is False + + def test_convert_to_string_list(self): + result = IngestionManagerPandas._convert_to_string_list([1, 2, 3]) + assert result == ["1", "2", "3"] + + def 
test_convert_to_string_list_with_none(self): + result = IngestionManagerPandas._convert_to_string_list([1, None, 3]) + assert result == ["1", None, "3"] + + def test_convert_to_string_list_raises_for_non_list(self): + with pytest.raises(ValueError, match="must be an Array"): + IngestionManagerPandas._convert_to_string_list("not a list") + + +class TestIngestionManagerRun: + @pytest.fixture + def feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "value": {"FeatureName": "value", "FeatureType": "Fractional"}, + } + + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": [1, 2, 3], + "value": [1.1, 2.2, 3.3], + }) + + @patch.object(IngestionManagerPandas, "_run_single_process_single_thread") + def test_run_single_thread_mode(self, mock_single, feature_definitions, sample_dataframe): + manager = IngestionManagerPandas( + feature_group_name="test-fg", + feature_definitions=feature_definitions, + max_workers=1, + max_processes=1, + ) + + manager.run(sample_dataframe) + + mock_single.assert_called_once() + + @patch.object(IngestionManagerPandas, "_run_multi_process") + def test_run_multi_process_mode(self, mock_multi, feature_definitions, sample_dataframe): + manager = IngestionManagerPandas( + feature_group_name="test-fg", + feature_definitions=feature_definitions, + max_workers=2, + max_processes=2, + ) + + manager.run(sample_dataframe) + + mock_multi.assert_called_once() + + +class TestIngestionManagerIngestRow: + @pytest.fixture + def feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "name": {"FeatureName": "name", "FeatureType": "String"}, + } + + @pytest.fixture + def collection_feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "tags": {"FeatureName": "tags", "FeatureType": "String", "CollectionType": "List"}, + } + + def test_ingest_row_success(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": ["test"]}) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + mock_fg.put_record.assert_called_once() + assert len(failed_rows) == 0 + + def test_ingest_row_with_collection_type(self, collection_feature_definitions): + df = pd.DataFrame({ + "id": [1], + "tags": [["tag1", "tag2"]], + }) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=collection_feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + mock_fg.put_record.assert_called_once() + call_args = mock_fg.put_record.call_args + record = call_args[1]["record"] + + # Find the tags feature value + tags_value = next(v for v in record if v.feature_name == "tags") + assert tags_value.value_as_string_list == ["tag1", "tag2"] + + def test_ingest_row_failure_appends_to_failed(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": ["test"]}) + mock_fg = MagicMock() + mock_fg.put_record.side_effect = Exception("API Error") + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + assert len(failed_rows) == 1 + 
assert failed_rows[0] == 0 # Index of failed row + + def test_ingest_row_with_target_stores(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": ["test"]}) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=["OnlineStore"], + ) + + call_args = mock_fg.put_record.call_args + assert call_args[1]["target_stores"] == ["OnlineStore"] + + def test_ingest_row_skips_none_values(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": [None]}) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + call_args = mock_fg.put_record.call_args + record = call_args[1]["record"] + # Only id should be in record, name is None + assert len(record) == 1 + assert record[0].feature_name == "id" diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_inputs.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_inputs.py new file mode 100644 index 0000000000..44e3ec6085 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_inputs.py @@ -0,0 +1,109 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for inputs.py (enums).""" +import pytest + +from sagemaker.mlops.feature_store.inputs import ( + TargetStoreEnum, + OnlineStoreStorageTypeEnum, + TableFormatEnum, + ResourceEnum, + SearchOperatorEnum, + SortOrderEnum, + FilterOperatorEnum, + DeletionModeEnum, + ExpirationTimeResponseEnum, + ThroughputModeEnum, +) + + +class TestTargetStoreEnum: + def test_online_store(self): + assert TargetStoreEnum.ONLINE_STORE.value == "OnlineStore" + + def test_offline_store(self): + assert TargetStoreEnum.OFFLINE_STORE.value == "OfflineStore" + + +class TestOnlineStoreStorageTypeEnum: + def test_standard(self): + assert OnlineStoreStorageTypeEnum.STANDARD.value == "Standard" + + def test_in_memory(self): + assert OnlineStoreStorageTypeEnum.IN_MEMORY.value == "InMemory" + + +class TestTableFormatEnum: + def test_glue(self): + assert TableFormatEnum.GLUE.value == "Glue" + + def test_iceberg(self): + assert TableFormatEnum.ICEBERG.value == "Iceberg" + + +class TestResourceEnum: + def test_feature_group(self): + assert ResourceEnum.FEATURE_GROUP.value == "FeatureGroup" + + def test_feature_metadata(self): + assert ResourceEnum.FEATURE_METADATA.value == "FeatureMetadata" + + +class TestSearchOperatorEnum: + def test_and(self): + assert SearchOperatorEnum.AND.value == "And" + + def test_or(self): + assert SearchOperatorEnum.OR.value == "Or" + + +class TestSortOrderEnum: + def test_ascending(self): + assert SortOrderEnum.ASCENDING.value == "Ascending" + + def test_descending(self): + assert SortOrderEnum.DESCENDING.value == "Descending" + + +class TestFilterOperatorEnum: + def test_equals(self): + assert FilterOperatorEnum.EQUALS.value == "Equals" + + def test_not_equals(self): + assert FilterOperatorEnum.NOT_EQUALS.value == "NotEquals" + + def test_greater_than(self): + assert FilterOperatorEnum.GREATER_THAN.value == "GreaterThan" + + def test_contains(self): + assert FilterOperatorEnum.CONTAINS.value == "Contains" + + def test_exists(self): + assert 
FilterOperatorEnum.EXISTS.value == "Exists" + + def test_in(self): + assert FilterOperatorEnum.IN.value == "In" + + +class TestDeletionModeEnum: + def test_soft_delete(self): + assert DeletionModeEnum.SOFT_DELETE.value == "SoftDelete" + + def test_hard_delete(self): + assert DeletionModeEnum.HARD_DELETE.value == "HardDelete" + + +class TestExpirationTimeResponseEnum: + def test_disabled(self): + assert ExpirationTimeResponseEnum.DISABLED.value == "Disabled" + + def test_enabled(self): + assert ExpirationTimeResponseEnum.ENABLED.value == "Enabled" + + +class TestThroughputModeEnum: + def test_on_demand(self): + assert ThroughputModeEnum.ON_DEMAND.value == "OnDemand" + + def test_provisioned(self): + assert ThroughputModeEnum.PROVISIONED.value == "Provisioned"
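
The tests above pin down the public call signatures of the feature-store modules without showing them end to end; the following standalone sketch (not part of this diff) strings those same calls together purely for orientation. It uses only signatures that appear in the unit tests; the region, bucket path, feature names, and the assumption that get_session_from_role returns a usable Session are illustrative placeholders.

import datetime
import pandas as pd

from sagemaker.mlops.feature_store.dataset_builder import DatasetBuilder
from sagemaker.mlops.feature_store.feature_definition import (
    FractionalFeatureDefinition,
    VectorCollectionType,
)
from sagemaker.mlops.feature_store.feature_utils import (
    get_session_from_role,
    load_feature_definitions_from_dataframe,
)

# Infer scalar feature definitions from a DataFrame (see
# TestLoadFeatureDefinitionsFromDataframe).
df = pd.DataFrame({
    "id": pd.Series([1, 2], dtype="int64"),
    "value": pd.Series([0.5, 0.7], dtype="float64"),
    "event_time": pd.Series(["2024-01-01", "2024-01-02"], dtype="string"),
})
definitions = load_feature_definitions_from_dataframe(df)

# A vector-valued feature, as exercised in the collection-type and
# serialization tests.
embedding = FractionalFeatureDefinition(
    feature_name="embedding",
    collection_type=VectorCollectionType(dimension=128),
)

# Assumed: get_session_from_role returns the Session it constructs (the unit
# test only verifies the boto3 wiring, not the return value).
session = get_session_from_role(region="us-west-2")

# Build a dataset from the DataFrame base, mirroring TestDatasetBuilderCreate
# and the fluent-API tests; calling to_csv_file() afterwards would run the
# generated Athena query against the offline store.
builder = DatasetBuilder.create(
    base=df,
    output_path="s3://example-bucket/feature-store-output",
    session=session,
    record_identifier_feature_name="id",
    event_time_identifier_feature_name="event_time",
)
builder.point_in_time_accurate_join().with_event_time_range(
    datetime.datetime(2024, 1, 1), datetime.datetime(2024, 1, 31)
)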