From 534df91349710aa7e20a858ddd68a71ae6450995 Mon Sep 17 00:00:00 2001 From: adishaa Date: Fri, 16 Jan 2026 07:00:13 -0800 Subject: [PATCH 1/8] feat: Add Feature Store Support to V3 --- .../sagemaker/mlops/feature_store/__init__.py | 129 ++++ .../mlops/feature_store/athena_query.py | 112 +++ .../mlops/feature_store/dataset_builder.py | 725 ++++++++++++++++++ .../mlops/feature_store/feature_definition.py | 107 +++ .../mlops/feature_store/feature_utils.py | 488 ++++++++++++ .../feature_store/ingestion_manager_pandas.py | 321 ++++++++ .../sagemaker/mlops/feature_store/inputs.py | 60 ++ 7 files changed, 1942 insertions(+) create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_definition.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/inputs.py diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py new file mode 100644 index 0000000000..ee8cd7d1a3 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py @@ -0,0 +1,129 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""SageMaker FeatureStore V3 - powered by sagemaker-core.""" + +# Resources from core +from sagemaker_core.main.resources import FeatureGroup, FeatureMetadata +from sagemaker_core.main.resources import FeatureStore + +# Shapes from core (Pydantic - no to_dict() needed) +from sagemaker_core.main.shapes import ( + DataCatalogConfig, + FeatureParameter, + FeatureValue, + Filter, + OfflineStoreConfig, + OnlineStoreConfig, + OnlineStoreSecurityConfig, + S3StorageConfig, + SearchExpression, + ThroughputConfig, + TtlDuration, +) + +# Enums (local - core uses strings) +from sagemaker.mlops.feature_store.inputs import ( + DeletionModeEnum, + ExpirationTimeResponseEnum, + FilterOperatorEnum, + OnlineStoreStorageTypeEnum, + ResourceEnum, + SearchOperatorEnum, + SortOrderEnum, + TableFormatEnum, + TargetStoreEnum, + ThroughputModeEnum, +) + +# Feature Definition helpers (local) +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FeatureTypeEnum, + CollectionTypeEnum, + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + ListCollectionType, + SetCollectionType, + VectorCollectionType, +) + +# Utility functions (local) +from sagemaker.mlops.feature_store.feature_utils import ( + as_hive_ddl, + create_athena_query, + create_dataset, + get_session_from_role, + ingest_dataframe, + load_feature_definitions_from_dataframe, +) + +# Classes (local) +from sagemaker.mlops.feature_store.athena_query import AthenaQuery +from sagemaker.mlops.feature_store.dataset_builder import ( + DatasetBuilder, + FeatureGroupToBeMerged, + JoinComparatorEnum, + JoinTypeEnum, + TableType, +) +from sagemaker.mlops.feature_store.ingestion_manager_pandas import ( + IngestionError, + IngestionManagerPandas, +) + +__all__ = [ + # Resources + "FeatureGroup", + "FeatureMetadata", + "FeatureStore", + # Shapes + "DataCatalogConfig", + "FeatureParameter", + 
"FeatureValue", + "Filter", + "OfflineStoreConfig", + "OnlineStoreConfig", + "OnlineStoreSecurityConfig", + "S3StorageConfig", + "SearchExpression", + "ThroughputConfig", + "TtlDuration", + # Enums + "DeletionModeEnum", + "ExpirationTimeResponseEnum", + "FilterOperatorEnum", + "OnlineStoreStorageTypeEnum", + "ResourceEnum", + "SearchOperatorEnum", + "SortOrderEnum", + "TableFormatEnum", + "TargetStoreEnum", + "ThroughputModeEnum", + # Feature Definitions + "FeatureDefinition", + "FeatureTypeEnum", + "CollectionTypeEnum", + "FractionalFeatureDefinition", + "IntegralFeatureDefinition", + "StringFeatureDefinition", + "ListCollectionType", + "SetCollectionType", + "VectorCollectionType", + # Utility functions + "as_hive_ddl", + "create_athena_query", + "create_dataset", + "get_session_from_role", + "ingest_dataframe", + "load_feature_definitions_from_dataframe", + # Classes + "AthenaQuery", + "DatasetBuilder", + "FeatureGroupToBeMerged", + "IngestionError", + "IngestionManagerPandas", + "JoinComparatorEnum", + "JoinTypeEnum", + "TableType", +] diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py new file mode 100644 index 0000000000..123b3c4305 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/athena_query.py @@ -0,0 +1,112 @@ +import os +import tempfile +from dataclasses import dataclass, field +from typing import Any, Dict +from urllib.parse import urlparse +import pandas as pd +from pandas import DataFrame + +from sagemaker.mlops.feature_store.feature_utils import ( + start_query_execution, + get_query_execution, + wait_for_athena_query, + download_athena_query_result, +) + +from sagemaker.core.helper.session_helper import Session + +@dataclass +class AthenaQuery: + """Class to manage querying of feature store data with AWS Athena. + + This class instantiates a AthenaQuery object that is used to retrieve data from feature store + via standard SQL queries. + + Attributes: + catalog (str): name of the data catalog. + database (str): name of the database. + table_name (str): name of the table. + sagemaker_session (Session): instance of the Session class to perform boto calls. + """ + + catalog: str + database: str + table_name: str + sagemaker_session: Session + _current_query_execution_id: str = field(default=None, init=False) + _result_bucket: str = field(default=None, init=False) + _result_file_prefix: str = field(default=None, init=False) + + def run( + self, query_string: str, output_location: str, kms_key: str = None, workgroup: str = None + ) -> str: + """Execute a SQL query given a query string, output location and kms key. + + This method executes the SQL query using Athena and outputs the results to output_location + and returns the execution id of the query. + + Args: + query_string: SQL query string. + output_location: S3 URI of the query result. + kms_key: KMS key id. If set, will be used to encrypt the query result file. + workgroup (str): The name of the workgroup in which the query is being started. + + Returns: + Execution id of the query. 
+ """ + response = start_query_execution( + session=self.sagemaker_session, + catalog=self.catalog, + database=self.database, + query_string=query_string, + output_location=output_location, + kms_key=kms_key, + workgroup=workgroup, + ) + + self._current_query_execution_id = response["QueryExecutionId"] + parsed_result = urlparse(output_location, allow_fragments=False) + self._result_bucket = parsed_result.netloc + self._result_file_prefix = parsed_result.path.strip("/") + return self._current_query_execution_id + + def wait(self): + """Wait for the current query to finish.""" + wait_for_athena_query(self.sagemaker_session, self._current_query_execution_id) + + def get_query_execution(self) -> Dict[str, Any]: + """Get execution status of the current query. + + Returns: + Response dict from Athena. + """ + return get_query_execution(self.sagemaker_session, self._current_query_execution_id) + + def as_dataframe(self, **kwargs) -> DataFrame: + """Download the result of the current query and load it into a DataFrame. + + Args: + **kwargs (object): key arguments used for the method pandas.read_csv to be able to + have a better tuning on data. For more info read: + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html + + Returns: + A pandas DataFrame contains the query result. + """ + state = self.get_query_execution()["QueryExecution"]["Status"]["State"] + if state != "SUCCEEDED": + if state in ("QUEUED", "RUNNING"): + raise RuntimeError(f"Query {self._current_query_execution_id} still executing.") + raise RuntimeError(f"Query {self._current_query_execution_id} failed.") + + output_file = os.path.join(tempfile.gettempdir(), f"{self._current_query_execution_id}.csv") + download_athena_query_result( + session=self.sagemaker_session, + bucket=self._result_bucket, + prefix=self._result_file_prefix, + query_execution_id=self._current_query_execution_id, + filename=output_file, + ) + kwargs.pop("delimiter", None) + return pd.read_csv(output_file, delimiter=",", **kwargs) + diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py new file mode 100644 index 0000000000..f5450663a6 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py @@ -0,0 +1,725 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 +"""Dataset Builder for FeatureStore.""" +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Union +import datetime + +import pandas as pd + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store import FeatureGroup +from sagemaker.mlops.feature_store.feature_definition import FeatureDefinition, FeatureTypeEnum +from sagemaker.mlops.feature_store.feature_utils import ( + upload_dataframe_to_s3, + download_csv_from_s3, + run_athena_query, +) + +_DEFAULT_CATALOG = "AwsDataCatalog" +_DEFAULT_DATABASE = "sagemaker_featurestore" + +_DTYPE_TO_FEATURE_TYPE = { + "object": "String", "string": "String", + "int64": "Integral", "int32": "Integral", + "float64": "Fractional", "float32": "Fractional", +} + +_DTYPE_TO_ATHENA_TYPE = { + "object": "STRING", "int64": "INT", "float64": "DOUBLE", + "bool": "BOOLEAN", "datetime64[ns]": "TIMESTAMP", +} + + +class TableType(Enum): + FEATURE_GROUP = "FeatureGroup" + DATA_FRAME = "DataFrame" + + +class JoinTypeEnum(Enum): + INNER_JOIN = "JOIN" + LEFT_JOIN = "LEFT JOIN" + RIGHT_JOIN = "RIGHT JOIN" + FULL_JOIN = "FULL JOIN" + CROSS_JOIN = "CROSS JOIN" + + +class JoinComparatorEnum(Enum): + EQUALS = "=" + GREATER_THAN = ">" + GREATER_THAN_OR_EQUAL_TO = ">=" + LESS_THAN = "<" + LESS_THAN_OR_EQUAL_TO = "<=" + NOT_EQUAL_TO = "<>" + + +@dataclass +class FeatureGroupToBeMerged: + """FeatureGroup metadata which will be used for SQL join. + + This class instantiates a FeatureGroupToBeMerged object that comprises a list of feature names, + a list of feature names which will be included in SQL query, a database, an Athena table name, + a feature name of record identifier, a feature name of event time identifier and a feature name + of base which is the target join key. + + Attributes: + features (List[str]): A list of strings representing feature names of this FeatureGroup. + included_feature_names (List[str]): A list of strings representing features to be + included in the SQL join. + projected_feature_names (List[str]): A list of strings representing features to be + included for final projection in output. + catalog (str): A string representing the catalog. + database (str): A string representing the database. + table_name (str): A string representing the Athena table name of this FeatureGroup. + record_identifier_feature_name (str): A string representing the record identifier feature. + event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the + event time identifier feature. + target_feature_name_in_base (str): A string representing the feature name in base which will + be used as target join key (default: None). + table_type (TableType): A TableType representing the type of table if it is Feature Group or + Panda Data Frame (default: None). + feature_name_in_target (str): A string representing the feature name in the target feature + group that will be compared to the target feature in the base feature group. + If None is provided, the record identifier feature will be used in the + SQL join. (default: None). + join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator + used when joining the target feature in the base feature group and the feature + in the target feature group. (default: JoinComparatorEnum.EQUALS). + join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between + the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN). 
+ """ + features: List[str] + included_feature_names: List[str] + projected_feature_names: List[str] + catalog: str + database: str + table_name: str + record_identifier_feature_name: str + event_time_identifier_feature: FeatureDefinition + target_feature_name_in_base: str = None + table_type: TableType = None + feature_name_in_target: str = None + join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS + join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN + + +def construct_feature_group_to_be_merged( + target_feature_group: FeatureGroup, + included_feature_names: List[str], + target_feature_name_in_base: str = None, + feature_name_in_target: str = None, + join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS, + join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN, +) -> FeatureGroupToBeMerged: + """Construct a FeatureGroupToBeMerged object by provided parameters. + + Args: + target_feature_group (FeatureGroup): A FeatureGroup object. + included_feature_names (List[str]): A list of strings representing features to be + included in the output. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as target join key (default: None). + feature_name_in_target (str): A string representing the feature name in the target feature + group that will be compared to the target feature in the base feature group. + If None is provided, the record identifier feature will be used in the + SQL join. (default: None). + join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator + used when joining the target feature in the base feature group and the feature + in the target feature group. (default: JoinComparatorEnum.EQUALS). + join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between + the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN). + + Returns: + A FeatureGroupToBeMerged object. + + Raises: + RuntimeError: No metastore is configured with the FeatureGroup. + ValueError: Invalid feature name(s) in included_feature_names. 
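+
+    Example (illustrative sketch; ``customers_fg`` is a hypothetical FeatureGroup
+    whose offline store metastore is already configured):
+
+        merged = construct_feature_group_to_be_merged(
+            target_feature_group=customers_fg,
+            included_feature_names=["customer_id", "age"],
+            target_feature_name_in_base="customer_id",
+            join_type=JoinTypeEnum.LEFT_JOIN,
+        )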
+ """ + fg = FeatureGroup.get(feature_group_name=target_feature_group.feature_group_name) + + if not fg.offline_store_config or not fg.offline_store_config.data_catalog_config: + raise RuntimeError(f"No metastore configured for FeatureGroup {fg.feature_group_name}.") + + catalog_config = fg.offline_store_config.data_catalog_config + disable_glue = catalog_config.disable_glue_table_creation or False + + features = [fd.feature_name for fd in fg.feature_definitions] + record_id = fg.record_identifier_feature_name + event_time_name = fg.event_time_feature_name + event_time_type = next( + (fd.feature_type for fd in fg.feature_definitions if fd.feature_name == event_time_name), + None + ) + + if feature_name_in_target and feature_name_in_target not in features: + raise ValueError(f"Feature {feature_name_in_target} not found in {fg.feature_group_name}") + + for feat in included_feature_names or []: + if feat not in features: + raise ValueError(f"Feature {feat} not found in {fg.feature_group_name}") + + if not included_feature_names: + included_feature_names = features.copy() + projected_feature_names = features.copy() + else: + projected_feature_names = included_feature_names.copy() + if record_id not in included_feature_names: + included_feature_names.append(record_id) + if event_time_name not in included_feature_names: + included_feature_names.append(event_time_name) + + return FeatureGroupToBeMerged( + features=features, + included_feature_names=included_feature_names, + projected_feature_names=projected_feature_names, + catalog=catalog_config.catalog if disable_glue else _DEFAULT_CATALOG, + database=catalog_config.database, + table_name=catalog_config.table_name, + record_identifier_feature_name=record_id, + event_time_identifier_feature=FeatureDefinition(event_time_name, FeatureTypeEnum(event_time_type)), + target_feature_name_in_base=target_feature_name_in_base, + table_type=TableType.FEATURE_GROUP, + feature_name_in_target=feature_name_in_target, + join_comparator=join_comparator, + join_type=join_type, + ) + + +@dataclass +class DatasetBuilder: + """DatasetBuilder definition. + + This class instantiates a DatasetBuilder object that comprises a base, a list of feature names, + an output path and a KMS key ID. + + Attributes: + _sagemaker_session (Session): Session instance to perform boto calls. + _base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a + pandas.DataFrame and will be used to merge other FeatureGroups and generate a Dataset. + _output_path (str): An S3 URI which stores the output .csv file. + _record_identifier_feature_name (str): A string representing the record identifier feature + if base is a DataFrame (default: None). + _event_time_identifier_feature_name (str): A string representing the event time identifier + feature if base is a DataFrame (default: None). + _included_feature_names (List[str]): A list of strings representing features to be + included in the output. If not set, all features will be included in the output. + (default: None). + _kms_key_id (str): A KMS key id. If set, will be used to encrypt the result file + (default: None). + _point_in_time_accurate_join (bool): A boolean representing if point-in-time join + is applied to the resulting dataframe when calling "to_dataframe". + When set to True, users can retrieve data using "row-level time travel" + according to the event times provided to the DatasetBuilder. 
This requires that the + entity dataframe with event times is submitted as the base in the constructor + (default: False). + _include_duplicated_records (bool): A boolean representing whether the resulting dataframe + when calling "to_dataframe" should include duplicated records (default: False). + _include_deleted_records (bool): A boolean representing whether the resulting + dataframe when calling "to_dataframe" should include deleted records (default: False). + _number_of_recent_records (int): An integer representing how many records will be + returned for each record identifier (default: 1). + _number_of_records (int): An integer representing the number of records that should be + returned in the resulting dataframe when calling "to_dataframe" (default: None). + _write_time_ending_timestamp (datetime.datetime): A datetime that represents the latest + write time for a record to be included in the resulting dataset. Records with a + newer write time will be omitted from the resulting dataset. (default: None). + _event_time_starting_timestamp (datetime.datetime): A datetime that represents the earliest + event time for a record to be included in the resulting dataset. Records + with an older event time will be omitted from the resulting dataset. (default: None). + _event_time_ending_timestamp (datetime.datetime): A datetime that represents the latest + event time for a record to be included in the resulting dataset. Records + with a newer event time will be omitted from the resulting dataset. (default: None). + _feature_groups_to_be_merged (List[FeatureGroupToBeMerged]): A list of + FeatureGroupToBeMerged which will be joined to base (default: []). + _event_time_identifier_feature_type (FeatureTypeEnum): A FeatureTypeEnum representing the + type of event time identifier feature (default: None). + """ + + _sagemaker_session: Session + _base: Union[FeatureGroup, pd.DataFrame] + _output_path: str + _record_identifier_feature_name: str = None + _event_time_identifier_feature_name: str = None + _included_feature_names: List[str] = None + _kms_key_id: str = None + _event_time_identifier_feature_type: FeatureTypeEnum = None + + _point_in_time_accurate_join: bool = field(default=False, init=False) + _include_duplicated_records: bool = field(default=False, init=False) + _include_deleted_records: bool = field(default=False, init=False) + _number_of_recent_records: int = field(default=None, init=False) + _number_of_records: int = field(default=None, init=False) + _write_time_ending_timestamp: datetime.datetime = field(default=None, init=False) + _event_time_starting_timestamp: datetime.datetime = field(default=None, init=False) + _event_time_ending_timestamp: datetime.datetime = field(default=None, init=False) + _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = field(default_factory=list, init=False) + + def with_feature_group( + self, + feature_group: FeatureGroup, + target_feature_name_in_base: str = None, + included_feature_names: List[str] = None, + feature_name_in_target: str = None, + join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS, + join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN, + ) -> "DatasetBuilder": + """Join FeatureGroup with base. + + Args: + feature_group (FeatureGroup): A target FeatureGroup which will be joined to base. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as a join key (default: None). 
+ included_feature_names (List[str]): A list of strings representing features to be + included in the output (default: None). + feature_name_in_target (str): A string representing the feature name in the target + feature group that will be compared to the target feature in the base feature group. + If None is provided, the record identifier feature will be used in the + SQL join. (default: None). + join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator + used when joining the target feature in the base feature group and the feature + in the target feature group. (default: JoinComparatorEnum.EQUALS). + join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between + the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN). + + Returns: + This DatasetBuilder object. + """ + self._feature_groups_to_be_merged.append( + construct_feature_group_to_be_merged( + feature_group, included_feature_names, target_feature_name_in_base, + feature_name_in_target, join_comparator, join_type, + ) + ) + return self + + def point_in_time_accurate_join(self) -> "DatasetBuilder": + """Enable point-in-time accurate join. + + Returns: + This DatasetBuilder object. + """ + self._point_in_time_accurate_join = True + return self + + def include_duplicated_records(self) -> "DatasetBuilder": + """Include duplicated records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_duplicated_records = True + return self + + def include_deleted_records(self) -> "DatasetBuilder": + """Include deleted records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_deleted_records = True + return self + + def with_number_of_recent_records_by_record_identifier(self, n: int) -> "DatasetBuilder": + """Set number_of_recent_records field with provided input. + + Args: + n (int): An int that how many recent records will be returned for + each record identifier. + + Returns: + This DatasetBuilder object. + """ + self._number_of_recent_records = n + return self + + def with_number_of_records_from_query_results(self, n: int) -> "DatasetBuilder": + """Set number_of_records field with provided input. + + Args: + n (int): An int that how many records will be returned. + + Returns: + This DatasetBuilder object. + """ + self._number_of_records = n + return self + + def as_of(self, timestamp: datetime.datetime) -> "DatasetBuilder": + """Set write_time_ending_timestamp field with provided input. + + Args: + timestamp (datetime.datetime): A datetime that all records' write time in dataset will + be before it. + + Returns: + This DatasetBuilder object. + """ + self._write_time_ending_timestamp = timestamp + return self + + def with_event_time_range( + self, + starting_timestamp: datetime.datetime = None, + ending_timestamp: datetime.datetime = None, + ) -> "DatasetBuilder": + """Set event_time_starting_timestamp and event_time_ending_timestamp with provided inputs. + + Args: + starting_timestamp (datetime.datetime): A datetime that all records' event time in + dataset will be after it (default: None). + ending_timestamp (datetime.datetime): A datetime that all records' event time in dataset + will be before it (default: None). + + Returns: + This DatasetBuilder object. + """ + self._event_time_starting_timestamp = starting_timestamp + self._event_time_ending_timestamp = ending_timestamp + return self + + def to_csv_file(self) -> tuple[str, str]: + """Get query string and result in .csv format file. 
+ + Returns: + The S3 path of the .csv file. + The query string executed. + """ + if isinstance(self._base, pd.DataFrame): + return self._to_csv_from_dataframe() + if isinstance(self._base, FeatureGroup): + return self._to_csv_from_feature_group() + raise ValueError("Base must be either a FeatureGroup or a DataFrame.") + + def to_dataframe(self) -> tuple[pd.DataFrame, str]: + """Get query string and result in pandas.DataFrame. + + Returns: + The pandas.DataFrame object. + The query string executed. + """ + csv_file, query_string = self.to_csv_file() + df = download_csv_from_s3(csv_file, self._sagemaker_session, self._kms_key_id) + if "row_recent" in df.columns: + df = df.drop("row_recent", axis="columns") + return df, query_string + + + def _to_csv_from_dataframe(self) -> tuple[str, str]: + s3_folder, temp_table_name = upload_dataframe_to_s3( + self._base, self._output_path, self._sagemaker_session, self._kms_key_id + ) + self._create_temp_table(temp_table_name, s3_folder) + + base_features = list(self._base.columns) + event_time_dtype = str(self._base[self._event_time_identifier_feature_name].dtypes) + self._event_time_identifier_feature_type = FeatureTypeEnum( + _DTYPE_TO_FEATURE_TYPE.get(event_time_dtype, "String") + ) + + included = self._included_feature_names or base_features + fg_to_merge = FeatureGroupToBeMerged( + features=base_features, + included_feature_names=included, + projected_feature_names=included, + catalog=_DEFAULT_CATALOG, + database=_DEFAULT_DATABASE, + table_name=temp_table_name, + record_identifier_feature_name=self._record_identifier_feature_name, + event_time_identifier_feature=FeatureDefinition( + self._event_time_identifier_feature_name, + self._event_time_identifier_feature_type, + ), + table_type=TableType.DATA_FRAME, + ) + + query_string = self._construct_query_string(fg_to_merge) + result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + return self._extract_result(result) + + def _to_csv_from_feature_group(self) -> tuple[str, str]: + base_fg = construct_feature_group_to_be_merged(self._base, self._included_feature_names) + self._record_identifier_feature_name = base_fg.record_identifier_feature_name + self._event_time_identifier_feature_name = base_fg.event_time_identifier_feature.feature_name + self._event_time_identifier_feature_type = base_fg.event_time_identifier_feature.feature_type + + query_string = self._construct_query_string(base_fg) + result = self._run_query(query_string, base_fg.catalog, base_fg.database) + return self._extract_result(result) + + def _extract_result(self, query_result: dict) -> tuple[str, str]: + execution = query_result.get("QueryExecution", {}) + return ( + execution.get("ResultConfiguration", {}).get("OutputLocation"), + execution.get("Query"), + ) + + def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]: + return run_athena_query( + session=self._sagemaker_session, + catalog=catalog, + database=database, + query_string=query_string, + output_location=self._output_path, + kms_key=self._kms_key_id, + ) + + def _create_temp_table(self, temp_table_name: str, s3_folder: str): + columns = ", ".join( + f"{col} {_DTYPE_TO_ATHENA_TYPE.get(str(self._base[col].dtypes), 'STRING')}" + for col in self._base.columns + ) + serde = '"separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\"' + query = ( + f"CREATE EXTERNAL TABLE {temp_table_name} ({columns}) " + f"ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' " + f"WITH SERDEPROPERTIES ({serde}) LOCATION 
'{s3_folder}';" + ) + self._run_query(query, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + + + def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str: + base_query = self._construct_table_query(base, "base") + query = f"WITH fg_base AS ({base_query})" + + for i, fg in enumerate(self._feature_groups_to_be_merged): + fg_query = self._construct_table_query(fg, str(i)) + query += f",\nfg_{i} AS ({fg_query})" + + selected = ", ".join(f"fg_base.{f}" for f in base.projected_feature_names) + selected_final = ", ".join(base.projected_feature_names) + + for i, fg in enumerate(self._feature_groups_to_be_merged): + selected += ", " + ", ".join( + f'fg_{i}."{f}" as "{f}.{i+1}"' for f in fg.projected_feature_names + ) + selected_final += ", " + ", ".join( + f'"{f}.{i+1}"' for f in fg.projected_feature_names + ) + + query += ( + f"\nSELECT {selected_final}\nFROM (\n" + f"SELECT {selected}, row_number() OVER (\n" + f'PARTITION BY fg_base."{base.record_identifier_feature_name}"\n' + f'ORDER BY fg_base."{base.event_time_identifier_feature.feature_name}" DESC' + ) + + join_strings = [] + for i, fg in enumerate(self._feature_groups_to_be_merged): + if not fg.target_feature_name_in_base: + fg.target_feature_name_in_base = self._record_identifier_feature_name + elif fg.target_feature_name_in_base not in base.features: + raise ValueError(f"Feature {fg.target_feature_name_in_base} not found in base") + query += f', fg_{i}."{fg.event_time_identifier_feature.feature_name}" DESC' + join_strings.append(self._construct_join_condition(fg, str(i))) + + recent_where = "" + if self._number_of_recent_records is not None and self._number_of_recent_records >= 0: + recent_where = f"WHERE row_recent <= {self._number_of_recent_records}" + + query += f"\n) AS row_recent\nFROM fg_base{''.join(join_strings)}\n)\n{recent_where}" + + if self._number_of_records is not None and self._number_of_records >= 0: + query += f"\nLIMIT {self._number_of_records}" + + return query + + def _construct_table_query(self, fg: FeatureGroupToBeMerged, suffix: str) -> str: + included = ", ".join(f'table_{suffix}."{f}"' for f in fg.included_feature_names) + included_with_write = included + if fg.table_type is TableType.FEATURE_GROUP: + included_with_write += f', table_{suffix}."write_time"' + + record_id = fg.record_identifier_feature_name + event_time = fg.event_time_identifier_feature.feature_name + + if self._include_duplicated_records and self._include_deleted_records: + return ( + f"SELECT {included}\n" + f'FROM "{fg.database}"."{fg.table_name}" table_{suffix}\n' + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, ["NOT is_deleted"]) + ) + + if fg.table_type is TableType.FEATURE_GROUP and self._include_deleted_records: + rank = f'ORDER BY origin_{suffix}."api_invocation_time" DESC, origin_{suffix}."write_time" DESC\n' + return ( + f"SELECT {included}\nFROM (\n" + f"SELECT *, row_number() OVER (\n" + f'PARTITION BY origin_{suffix}."{record_id}", origin_{suffix}."{event_time}"\n' + f"{rank}) AS row_{suffix}\n" + f'FROM "{fg.database}"."{fg.table_name}" origin_{suffix}\n' + f"WHERE NOT is_deleted) AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, [f"row_{suffix} = 1"]) + ) + + if fg.table_type is TableType.FEATURE_GROUP: + dedup = self._construct_dedup_query(fg, suffix) + deleted = self._construct_deleted_query(fg, suffix) + rank_cond = ( + f'OR (table_{suffix}."{event_time}" = deleted_{suffix}."{event_time}" ' + f'AND table_{suffix}."api_invocation_time" > 
deleted_{suffix}."api_invocation_time")\n' + f'OR (table_{suffix}."{event_time}" = deleted_{suffix}."{event_time}" ' + f'AND table_{suffix}."api_invocation_time" = deleted_{suffix}."api_invocation_time" ' + f'AND table_{suffix}."write_time" > deleted_{suffix}."write_time")\n' + ) + + if self._include_duplicated_records: + return ( + f"WITH {deleted}\n" + f"SELECT {included}\nFROM (\n" + f"SELECT {included_with_write}\n" + f'FROM "{fg.database}"."{fg.table_name}" table_{suffix}\n' + f"LEFT JOIN deleted_{suffix} ON table_{suffix}.\"{record_id}\" = deleted_{suffix}.\"{record_id}\"\n" + f'WHERE deleted_{suffix}."{record_id}" IS NULL\n' + f"UNION ALL\n" + f"SELECT {included_with_write}\nFROM deleted_{suffix}\n" + f'JOIN "{fg.database}"."{fg.table_name}" table_{suffix}\n' + f'ON table_{suffix}."{record_id}" = deleted_{suffix}."{record_id}"\n' + f'AND (table_{suffix}."{event_time}" > deleted_{suffix}."{event_time}"\n{rank_cond})\n' + f") AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, []) + ) + + return ( + f"WITH {dedup},\n{deleted}\n" + f"SELECT {included}\nFROM (\n" + f"SELECT {included_with_write}\nFROM table_{suffix}\n" + f"LEFT JOIN deleted_{suffix} ON table_{suffix}.\"{record_id}\" = deleted_{suffix}.\"{record_id}\"\n" + f'WHERE deleted_{suffix}."{record_id}" IS NULL\n' + f"UNION ALL\n" + f"SELECT {included_with_write}\nFROM deleted_{suffix}\n" + f"JOIN table_{suffix} ON table_{suffix}.\"{record_id}\" = deleted_{suffix}.\"{record_id}\"\n" + f'AND (table_{suffix}."{event_time}" > deleted_{suffix}."{event_time}"\n{rank_cond})\n' + f") AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, []) + ) + + dedup = self._construct_dedup_query(fg, suffix) + return ( + f"WITH {dedup}\n" + f"SELECT {included}\nFROM (\n" + f"SELECT {included_with_write}\nFROM table_{suffix}\n" + f") AS table_{suffix}\n" + + self._construct_where_query_string(suffix, fg.event_time_identifier_feature, []) + ) + + def _construct_dedup_query(self, fg: FeatureGroupToBeMerged, suffix: str) -> str: + record_id = fg.record_identifier_feature_name + event_time = fg.event_time_identifier_feature.feature_name + rank = "" + is_fg = fg.table_type is TableType.FEATURE_GROUP + + if is_fg: + rank = f'ORDER BY origin_{suffix}."api_invocation_time" DESC, origin_{suffix}."write_time" DESC\n' + + where_conds = [] + if is_fg and self._write_time_ending_timestamp: + where_conds.append(self._construct_write_time_condition(f"origin_{suffix}")) + where_conds.extend(self._construct_event_time_conditions(f"origin_{suffix}", fg.event_time_identifier_feature)) + where_str = f"WHERE {' AND '.join(where_conds)}\n" if where_conds else "" + + dedup_where = f"WHERE dedup_row_{suffix} = 1\n" if is_fg else "" + + return ( + f"table_{suffix} AS (\n" + f"SELECT *\nFROM (\n" + f"SELECT *, row_number() OVER (\n" + f'PARTITION BY origin_{suffix}."{record_id}", origin_{suffix}."{event_time}"\n' + f"{rank}) AS dedup_row_{suffix}\n" + f'FROM "{fg.database}"."{fg.table_name}" origin_{suffix}\n' + f"{where_str})\n{dedup_where})" + ) + + def _construct_deleted_query(self, fg: FeatureGroupToBeMerged, suffix: str) -> str: + record_id = fg.record_identifier_feature_name + event_time = fg.event_time_identifier_feature.feature_name + rank = f'ORDER BY origin_{suffix}."{event_time}" DESC' + + if fg.table_type is TableType.FEATURE_GROUP: + rank += f', origin_{suffix}."api_invocation_time" DESC, origin_{suffix}."write_time" DESC\n' + + write_cond = "" + if fg.table_type is 
TableType.FEATURE_GROUP and self._write_time_ending_timestamp: + write_cond = f" AND {self._construct_write_time_condition(f'origin_{suffix}')}\n" + + event_conds = "" + if self._event_time_starting_timestamp and self._event_time_ending_timestamp: + conds = self._construct_event_time_conditions(f"origin_{suffix}", fg.event_time_identifier_feature) + event_conds = "".join(f"AND {c}\n" for c in conds) + + return ( + f"deleted_{suffix} AS (\n" + f"SELECT *\nFROM (\n" + f"SELECT *, row_number() OVER (\n" + f'PARTITION BY origin_{suffix}."{record_id}"\n' + f"{rank}) AS deleted_row_{suffix}\n" + f'FROM "{fg.database}"."{fg.table_name}" origin_{suffix}\n' + f"WHERE is_deleted{write_cond}{event_conds})\n" + f"WHERE deleted_row_{suffix} = 1\n)" + ) + + def _construct_where_query_string( + self, suffix: str, event_time_feature: FeatureDefinition, conditions: List[str] + ) -> str: + self._validate_options() + + if isinstance(self._base, FeatureGroup) and self._write_time_ending_timestamp: + conditions.append(self._construct_write_time_condition(f"table_{suffix}")) + + conditions.extend(self._construct_event_time_conditions(f"table_{suffix}", event_time_feature)) + return f"WHERE {' AND '.join(conditions)}" if conditions else "" + + def _validate_options(self): + is_df_base = isinstance(self._base, pd.DataFrame) + no_joins = len(self._feature_groups_to_be_merged) == 0 + + if self._number_of_recent_records is not None and self._number_of_recent_records < 0: + raise ValueError("number_of_recent_records must be non-negative.") + if self._number_of_records is not None and self._number_of_records < 0: + raise ValueError("number_of_records must be non-negative.") + if is_df_base and no_joins: + if self._include_deleted_records: + raise ValueError("include_deleted_records() only works for FeatureGroup if no join.") + if self._include_duplicated_records: + raise ValueError("include_duplicated_records() only works for FeatureGroup if no join.") + if self._write_time_ending_timestamp: + raise ValueError("as_of() only works for FeatureGroup if no join.") + if self._point_in_time_accurate_join and no_joins: + raise ValueError("point_in_time_accurate_join() requires at least one join.") + + def _construct_event_time_conditions(self, table: str, event_time_feature: FeatureDefinition) -> List[str]: + cast_fn = "from_iso8601_timestamp" if event_time_feature.feature_type == FeatureTypeEnum.STRING else "from_unixtime" + conditions = [] + if self._event_time_starting_timestamp: + conditions.append( + f'{cast_fn}({table}."{event_time_feature.feature_name}") >= ' + f"from_unixtime({self._event_time_starting_timestamp.timestamp()})" + ) + if self._event_time_ending_timestamp: + conditions.append( + f'{cast_fn}({table}."{event_time_feature.feature_name}") <= ' + f"from_unixtime({self._event_time_ending_timestamp.timestamp()})" + ) + return conditions + + def _construct_write_time_condition(self, table: str) -> str: + ts = self._write_time_ending_timestamp.replace(microsecond=0) + return f'{table}."write_time" <= to_timestamp(\'{ts}\', \'yyyy-mm-dd hh24:mi:ss\')' + + def _construct_join_condition(self, fg: FeatureGroupToBeMerged, suffix: str) -> str: + target_feature = fg.feature_name_in_target or fg.record_identifier_feature_name + join = ( + f"\n{fg.join_type.value} fg_{suffix}\n" + f'ON fg_base."{fg.target_feature_name_in_base}" {fg.join_comparator.value} fg_{suffix}."{target_feature}"' + ) + + if self._point_in_time_accurate_join: + base_cast = "from_iso8601_timestamp" if self._event_time_identifier_feature_type == 
FeatureTypeEnum.STRING else "from_unixtime" + fg_cast = "from_iso8601_timestamp" if fg.event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING else "from_unixtime" + join += ( + f'\nAND {base_cast}(fg_base."{self._event_time_identifier_feature_name}") >= ' + f'{fg_cast}(fg_{suffix}."{fg.event_time_identifier_feature.feature_name}")' + ) + + return join diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_definition.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_definition.py new file mode 100644 index 0000000000..32408e5585 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_definition.py @@ -0,0 +1,107 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Feature Definitions for FeatureStore.""" +from __future__ import absolute_import + +from enum import Enum +from typing import Optional, Union + +from sagemaker.core.shapes import ( + FeatureDefinition, + CollectionConfig, + VectorConfig, +) + +class FeatureTypeEnum(Enum): + """Feature data types: Fractional, Integral, or String.""" + + FRACTIONAL = "Fractional" + INTEGRAL = "Integral" + STRING = "String" + +class CollectionTypeEnum(Enum): + """Collection types: List, Set, or Vector.""" + + LIST = "List" + SET = "Set" + VECTOR = "Vector" + +class ListCollectionType: + """List collection type.""" + + collection_type = CollectionTypeEnum.LIST.value + collection_config = None + +class SetCollectionType: + """Set collection type.""" + + collection_type = CollectionTypeEnum.SET.value + collection_config = None + +class VectorCollectionType: + """Vector collection type with dimension.""" + + collection_type = CollectionTypeEnum.VECTOR.value + + def __init__(self, dimension: int): + self.collection_config = CollectionConfig( + vector_config=VectorConfig(dimension=dimension) + ) + +CollectionType = Union[ListCollectionType, SetCollectionType, VectorCollectionType] + +def _create_feature_definition( + feature_name: str, + feature_type: FeatureTypeEnum, + collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Internal helper to create FeatureDefinition from collection type.""" + return FeatureDefinition( + feature_name=feature_name, + feature_type=feature_type.value, + collection_type=collection_type.collection_type if collection_type else None, + collection_config=collection_type.collection_config if collection_type else None, + ) + +def FractionalFeatureDefinition( + feature_name: str, + collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Create a feature definition with Fractional type.""" + return _create_feature_definition(feature_name, FeatureTypeEnum.FRACTIONAL, collection_type) + +def IntegralFeatureDefinition( + feature_name: str, + collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Create a feature definition with Integral type.""" + return _create_feature_definition(feature_name, FeatureTypeEnum.INTEGRAL, collection_type) + +def StringFeatureDefinition( + feature_name: str, 
+ collection_type: Optional[CollectionType] = None, +) -> FeatureDefinition: + """Create a feature definition with String type.""" + return _create_feature_definition(feature_name, FeatureTypeEnum.STRING, collection_type) + +__all__ = [ + "FeatureDefinition", + "FeatureTypeEnum", + "CollectionTypeEnum", + "ListCollectionType", + "SetCollectionType", + "VectorCollectionType", + "FractionalFeatureDefinition", + "IntegralFeatureDefinition", + "StringFeatureDefinition", +] diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py new file mode 100644 index 0000000000..f7f6523b8d --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py @@ -0,0 +1,488 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Utilities for working with FeatureGroups and FeatureStores.""" +import logging +import os +import time +from typing import Any, Dict, Sequence, Union + +import boto3 +import pandas as pd +from pandas import DataFrame, Series + +from sagemaker.mlops.feature_store import FeatureGroup as CoreFeatureGroup, FeatureGroup +from sagemaker.core.helper.session_helper import Session +from sagemaker.core.s3.client import S3Uploader, S3Downloader +from sagemaker.mlops.feature_store.dataset_builder import DatasetBuilder +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FractionalFeatureDefinition, + IntegralFeatureDefinition, + ListCollectionType, + StringFeatureDefinition, +) +from sagemaker.mlops.feature_store.ingestion_manager_pandas import IngestionManagerPandas + +from sagemaker import utils + + +logger = logging.getLogger(__name__) + +# --- Constants --- + +_FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP = { + "Integral": "INT", + "Fractional": "FLOAT", + "String": "STRING", +} + +_DTYPE_TO_FEATURE_TYPE_MAP = { + "object": "String", + "string": "String", + "int64": "Integral", + "float64": "Fractional", +} + +_INTEGER_TYPES = {"int_", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"} +_FLOAT_TYPES = {"float_", "float16", "float32", "float64"} + + +def _get_athena_client(session: Session): + """Get Athena client from session.""" + return session.boto_session.client("athena", region_name=session.boto_region_name) + + +def _get_s3_client(session: Session): + """Get S3 client from session.""" + return session.boto_session.client("s3", region_name=session.boto_region_name) + + +def start_query_execution( + session: Session, + catalog: str, + database: str, + query_string: str, + output_location: str, + kms_key: str = None, + workgroup: str = None, +) -> Dict[str, str]: + """Start Athena query execution. + + Args: + session: Session instance for boto calls. + catalog: Name of the data catalog. + database: Name of the database. + query_string: SQL query string. + output_location: S3 URI for query results. + kms_key: KMS key for encryption (default: None). + workgroup: Athena workgroup name (default: None). + + Returns: + Response dict with QueryExecutionId. 
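+
+    Example (illustrative sketch; ``session`` is assumed to be an existing Session
+    and the bucket name is hypothetical):
+
+        response = start_query_execution(
+            session=session,
+            catalog="AwsDataCatalog",
+            database="sagemaker_featurestore",
+            query_string="SELECT 1",
+            output_location="s3://my-bucket/athena-results",
+        )
+        query_id = response["QueryExecutionId"]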
+ """ + kwargs = { + "QueryString": query_string, + "QueryExecutionContext": {"Catalog": catalog, "Database": database}, + "ResultConfiguration": {"OutputLocation": output_location}, + } + if kms_key: + kwargs["ResultConfiguration"]["EncryptionConfiguration"] = { + "EncryptionOption": "SSE_KMS", + "KmsKey": kms_key, + } + if workgroup: + kwargs["WorkGroup"] = workgroup + return _get_athena_client(session).start_query_execution(**kwargs) + + +def get_query_execution(session: Session, query_execution_id: str) -> Dict[str, Any]: + """Get execution status of an Athena query. + + Args: + session: Session instance for boto calls. + query_execution_id: The query execution ID. + + Returns: + Response dict from Athena. + """ + return _get_athena_client(session).get_query_execution(QueryExecutionId=query_execution_id) + + +def wait_for_athena_query(session: Session, query_execution_id: str, poll: int = 5): + """Wait for Athena query to finish. + + Args: + session: Session instance for boto calls. + query_execution_id: The query execution ID. + poll: Polling interval in seconds (default: 5). + """ + while True: + state = get_query_execution(session, query_execution_id)["QueryExecution"]["Status"]["State"] + if state in ("SUCCEEDED", "FAILED"): + logger.info("Query %s %s.", query_execution_id, state.lower()) + break + logger.info("Query %s is being executed.", query_execution_id) + time.sleep(poll) + + +def run_athena_query( + session: Session, + catalog: str, + database: str, + query_string: str, + output_location: str, + kms_key: str = None, +) -> Dict[str, Any]: + """Execute Athena query, wait for completion, and return result. + + Args: + session: Session instance for boto calls. + catalog: Name of the data catalog. + database: Name of the database. + query_string: SQL query string. + output_location: S3 URI for query results. + kms_key: KMS key for encryption (default: None). + + Returns: + Query execution result dict. + + Raises: + RuntimeError: If query fails. + """ + response = start_query_execution( + session=session, + catalog=catalog, + database=database, + query_string=query_string, + output_location=output_location, + kms_key=kms_key, + ) + query_id = response["QueryExecutionId"] + wait_for_athena_query(session, query_id) + + result = get_query_execution(session, query_id) + if result["QueryExecution"]["Status"]["State"] != "SUCCEEDED": + raise RuntimeError(f"Athena query {query_id} failed.") + return result + + +def download_athena_query_result( + session: Session, + bucket: str, + prefix: str, + query_execution_id: str, + filename: str, +): + """Download query result file from S3. + + Args: + session: Session instance for boto calls. + bucket: S3 bucket name. + prefix: S3 key prefix. + query_execution_id: The query execution ID. + filename: Local filename to save to. + """ + _get_s3_client(session).download_file( + Bucket=bucket, + Key=f"{prefix}/{query_execution_id}.csv", + Filename=filename, + ) + + +def upload_dataframe_to_s3( + data_frame: DataFrame, + output_path: str, + session: Session, + kms_key: str = None, +) -> tuple[str, str]: + """Upload DataFrame to S3 as CSV. + + Args: + data_frame: DataFrame to upload. + output_path: S3 URI base path. + session: Session instance for boto calls. + kms_key: KMS key for encryption (default: None). + + Returns: + Tuple of (s3_folder, temp_table_name). 
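+
+    Example (illustrative sketch; ``df`` and ``session`` are assumed to exist and
+    the bucket name is hypothetical):
+
+        s3_folder, temp_table_name = upload_dataframe_to_s3(
+            data_frame=df,
+            output_path="s3://my-bucket/feature-store-output",
+            session=session,
+        )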
+ """ + + temp_id = utils.unique_name_from_base("dataframe-base") + local_file = f"{temp_id}.csv" + s3_folder = os.path.join(output_path, temp_id) + + data_frame.to_csv(local_file, index=False, header=False) + S3Uploader.upload( + local_path=local_file, + desired_s3_uri=s3_folder, + sagemaker_session=session, + kms_key=kms_key, + ) + os.remove(local_file) + + table_name = f'dataframe_{temp_id.replace("-", "_")}' + return s3_folder, table_name + + +def download_csv_from_s3( + s3_uri: str, + session: Session, + kms_key: str = None, +) -> DataFrame: + """Download CSV from S3 and return as DataFrame. + + Args: + s3_uri: S3 URI of the CSV file. + session: Session instance for boto calls. + kms_key: KMS key for decryption (default: None). + + Returns: + DataFrame with CSV contents. + """ + + S3Downloader.download( + s3_uri=s3_uri, + local_path="./", + kms_key=kms_key, + sagemaker_session=session, + ) + + local_file = s3_uri.split("/")[-1] + df = pd.read_csv(local_file) + os.remove(local_file) + + metadata_file = f"{local_file}.metadata" + if os.path.exists(metadata_file): + os.remove(metadata_file) + + return df + + +def get_session_from_role(region: str, assume_role: str = None) -> Session: + """Get a Session from a region and optional IAM role. + + Args: + region: AWS region name. + assume_role: IAM role ARN to assume (default: None). + + Returns: + Session instance. + """ + boto_session = boto3.Session(region_name=region) + + if assume_role: + sts = boto_session.client("sts", region_name=region) + credentials = sts.assume_role( + RoleArn=assume_role, + RoleSessionName="SagemakerExecution", + )["Credentials"] + + boto_session = boto3.Session( + region_name=region, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + return Session( + boto_session=boto_session, + sagemaker_client=boto_session.client("sagemaker"), + sagemaker_runtime_client=boto_session.client("sagemaker-runtime"), + sagemaker_featurestore_runtime_client=boto_session.client("sagemaker-featurestore-runtime"), + ) + + +# --- FeatureDefinition Functions --- + +def _is_collection_column(series: Series, sample_size: int = 1000) -> bool: + """Check if column contains list/set values.""" + sample = series.head(sample_size).dropna() + return sample.apply(lambda x: isinstance(x, (list, set))).any() + + +def _generate_feature_definition( + series: Series, + online_storage_type: str = None, +) -> FeatureDefinition: + """Generate a FeatureDefinition from a pandas Series.""" + dtype = str(series.dtype) + collection_type = None + + if online_storage_type == "InMemory" and _is_collection_column(series): + collection_type = ListCollectionType() + + if dtype in _INTEGER_TYPES: + return IntegralFeatureDefinition(series.name, collection_type) + if dtype in _FLOAT_TYPES: + return FractionalFeatureDefinition(series.name, collection_type) + return StringFeatureDefinition(series.name, collection_type) + + +def load_feature_definitions_from_dataframe( + data_frame: DataFrame, + online_storage_type: str = None, +) -> Sequence[FeatureDefinition]: + """Infer FeatureDefinitions from DataFrame dtypes. + + Column name is used as feature name. Feature type is inferred from the dtype + of the column. Integer dtypes are mapped to Integral feature type. Float dtypes + are mapped to Fractional feature type. All other dtypes are mapped to String. 
+ + For IN_MEMORY online_storage_type, collection type columns within DataFrame + will be inferred as List instead of String. + + Args: + data_frame: DataFrame to infer features from. + online_storage_type: "Standard" or "InMemory" (default: None). + + Returns: + List of FeatureDefinition objects. + """ + return [ + _generate_feature_definition(data_frame[col], online_storage_type) + for col in data_frame.columns + ] + + +# --- FeatureGroup Functions --- + +def create_athena_query(feature_group_name: str, session: Session): + """Create an AthenaQuery for a FeatureGroup. + + Args: + feature_group_name: Name of the FeatureGroup. + session: Session instance for Athena boto calls. + + Returns: + AthenaQuery initialized with data catalog config. + + Raises: + RuntimeError: If no metastore is configured. + """ + from sagemaker.mlops.feature_store.athena_query import AthenaQuery + + fg = CoreFeatureGroup.get(feature_group_name=feature_group_name) + + if not fg.offline_store_config or not fg.offline_store_config.data_catalog_config: + raise RuntimeError("No metastore is configured with this feature group.") + + catalog_config = fg.offline_store_config.data_catalog_config + disable_glue = catalog_config.disable_glue_table_creation or False + + return AthenaQuery( + catalog=catalog_config.catalog if disable_glue else "AwsDataCatalog", + database=catalog_config.database, + table_name=catalog_config.table_name, + sagemaker_session=session, + ) + + +def as_hive_ddl( + feature_group_name: str, + database: str = "sagemaker_featurestore", + table_name: str = None, +) -> str: + """Generate Hive DDL for a FeatureGroup's offline store table. + + Schema of the table is generated based on the feature definitions. Columns are named + after feature name and data-type are inferred based on feature type. Integral feature + type is mapped to INT data-type. Fractional feature type is mapped to FLOAT data-type. + String feature type is mapped to STRING data-type. + + Args: + feature_group_name: Name of the FeatureGroup. + database: Hive database name (default: "sagemaker_featurestore"). + table_name: Hive table name (default: feature_group_name). + + Returns: + CREATE EXTERNAL TABLE DDL string. + """ + fg = CoreFeatureGroup.get(feature_group_name=feature_group_name) + table_name = table_name or feature_group_name + resolved_output_s3_uri = fg.offline_store_config.s3_storage_config.resolved_output_s3_uri + + ddl = f"CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n" + for fd in fg.feature_definitions: + ddl += f" {fd.feature_name} {_FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP.get(fd.feature_type)}\n" + ddl += " write_time TIMESTAMP\n" + ddl += " event_time TIMESTAMP\n" + ddl += " is_deleted BOOLEAN\n" + ddl += ")\n" + ddl += ( + "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n" + " STORED AS\n" + " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n" + " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n" + f"LOCATION '{resolved_output_s3_uri}'" + ) + return ddl + + +def ingest_dataframe( + feature_group_name: str, + data_frame: DataFrame, + max_workers: int = 1, + max_processes: int = 1, + wait: bool = True, + timeout: Union[int, float] = None, +): + """Ingest a pandas DataFrame to a FeatureGroup. + + Args: + feature_group_name: Name of the FeatureGroup. + data_frame: DataFrame to ingest. + max_workers: Threads per process (default: 1). + max_processes: Number of processes (default: 1). + wait: Wait for ingestion to complete (default: True). 
+ timeout: Timeout in seconds (default: None). + + Returns: + IngestionManagerPandas instance. + + Raises: + ValueError: If max_workers or max_processes <= 0. + """ + + if max_processes <= 0: + raise ValueError("max_processes must be greater than 0.") + if max_workers <= 0: + raise ValueError("max_workers must be greater than 0.") + + fg = CoreFeatureGroup.get(feature_group_name=feature_group_name) + feature_definitions = {fd.feature_name: fd.feature_type for fd in fg.feature_definitions} + + manager = IngestionManagerPandas( + feature_group_name=feature_group_name, + feature_definitions=feature_definitions, + max_workers=max_workers, + max_processes=max_processes, + ) + manager.run(data_frame=data_frame, wait=wait, timeout=timeout) + return manager + +def create_dataset( + base: Union[FeatureGroup, pd.DataFrame], + output_path: str, + session: Session, + record_identifier_feature_name: str = None, + event_time_identifier_feature_name: str = None, + included_feature_names: Sequence[str] = None, + kms_key_id: str = None, +) -> DatasetBuilder: + """Create a DatasetBuilder for generating a Dataset.""" + if isinstance(base, pd.DataFrame): + if not record_identifier_feature_name or not event_time_identifier_feature_name: + raise ValueError( + "record_identifier_feature_name and event_time_identifier_feature_name " + "are required when base is a DataFrame." + ) + return DatasetBuilder( + _sagemaker_session=session, + _base=base, + _output_path=output_path, + _record_identifier_feature_name=record_identifier_feature_name, + _event_time_identifier_feature_name=event_time_identifier_feature_name, + _included_feature_names=included_feature_names, + _kms_key_id=kms_key_id, + ) + diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py new file mode 100644 index 0000000000..60d022dab1 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py @@ -0,0 +1,321 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Multi-threaded data ingestion for FeatureStore using SageMaker Core.""" +import logging +import math +import signal +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field +from multiprocessing.pool import AsyncResult +from typing import Any, Dict, Iterable, List, Sequence, Union + +import pandas as pd +from pandas import DataFrame +from pandas.api.types import is_list_like +from pathos.multiprocessing import ProcessingPool + +from sagemaker.core.resources import FeatureGroup as CoreFeatureGroup +from sagemaker.core.shapes import FeatureValue + +logger = logging.getLogger(__name__) + + +class IngestionError(Exception): + """Exception raised for errors during ingestion. + + Attributes: + failed_rows: List of row indices that failed to ingest. + message: Error message. + """ + + def __init__(self, failed_rows: List[int], message: str): + self.failed_rows = failed_rows + self.message = message + super().__init__(self.message) + + +@dataclass +class IngestionManagerPandas: + """Class to manage the multi-threaded data ingestion process. + + This class will manage the data ingestion process which is multi-threaded. + + Attributes: + feature_group_name (str): name of the Feature Group. 
+ feature_definitions (Dict[str, Dict[Any, Any]]): dictionary of feature definitions + where the key is the feature name and the value is the FeatureDefinition. + The FeatureDefinition contains the data type of the feature. + max_workers (int): number of threads to create. + max_processes (int): number of processes to create. Each process spawns + ``max_workers`` threads. + """ + + feature_group_name: str + feature_definitions: Dict[str, Dict[Any, Any]] + max_workers: int = 1 + max_processes: int = 1 + _async_result: AsyncResult = field(default=None, init=False) + _processing_pool: ProcessingPool = field(default=None, init=False) + _failed_indices: List[int] = field(default_factory=list, init=False) + + @property + def failed_rows(self) -> List[int]: + """Get rows that failed to ingest. + + Returns: + List of row indices that failed to be ingested. + """ + return self._failed_indices + + def run( + self, + data_frame: DataFrame, + target_stores: List[str] = None, + wait: bool = True, + timeout: Union[int, float] = None, + ): + """Start the ingestion process. + + Args: + data_frame (DataFrame): source DataFrame to be ingested. + target_stores (List[str]): list of target stores ("OnlineStore", "OfflineStore"). + If None, the default target store is used. + wait (bool): whether to wait for the ingestion to finish or not. + timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised + if timeout is reached. + """ + if self.max_workers == 1 and self.max_processes == 1: + self._run_single_process_single_thread(data_frame=data_frame, target_stores=target_stores) + else: + self._run_multi_process(data_frame=data_frame, target_stores=target_stores, wait=wait, timeout=timeout) + + def wait(self, timeout: Union[int, float] = None): + """Wait for the ingestion process to finish. + + Args: + timeout (Union[int, float]): ``concurrent.futures.TimeoutError`` will be raised + if timeout is reached. 
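+
+        Illustrative usage (a hedged sketch; ``df`` and the feature group name
+        are placeholders, not part of this API)::
+
+            manager = ingest_dataframe("my-fg", df, max_workers=4, wait=False)
+            manager.wait(timeout=600)  # raises IngestionError if any rows failed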
+ """ + try: + results = self._async_result.get(timeout=timeout) + except KeyboardInterrupt as e: + self._processing_pool.terminate() + self._processing_pool.close() + self._processing_pool.clear() + raise e + else: + self._processing_pool.close() + self._processing_pool.clear() + + self._failed_indices = [idx for failed in results for idx in failed] + + if self._failed_indices: + raise IngestionError( + self._failed_indices, + f"Failed to ingest some data into FeatureGroup {self.feature_group_name}", + ) + + def _run_single_process_single_thread( + self, + data_frame: DataFrame, + target_stores: List[str] = None, + ): + """Ingest utilizing a single process and a single thread.""" + logger.info("Started single-threaded ingestion for %d rows", len(data_frame)) + failed_rows = [] + + fg = CoreFeatureGroup(feature_group_name=self.feature_group_name) + + for row in data_frame.itertuples(): + self._ingest_row( + data_frame=data_frame, + row=row, + feature_group=fg, + feature_definitions=self.feature_definitions, + failed_rows=failed_rows, + target_stores=target_stores, + ) + + self._failed_indices = failed_rows + if self._failed_indices: + raise IngestionError( + self._failed_indices, + f"Failed to ingest some data into FeatureGroup {self.feature_group_name}", + ) + + def _run_multi_process( + self, + data_frame: DataFrame, + target_stores: List[str] = None, + wait: bool = True, + timeout: Union[int, float] = None, + ): + """Start the ingestion process with the specified number of processes.""" + batch_size = math.ceil(data_frame.shape[0] / self.max_processes) + + args = [] + for i in range(self.max_processes): + start_index = min(i * batch_size, data_frame.shape[0]) + end_index = min(i * batch_size + batch_size, data_frame.shape[0]) + args.append(( + self.max_workers, + self.feature_group_name, + self.feature_definitions, + data_frame[start_index:end_index], + target_stores, + start_index, + timeout, + )) + + def init_worker(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + + self._processing_pool = ProcessingPool(self.max_processes, init_worker) + self._processing_pool.restart(force=True) + + self._async_result = self._processing_pool.amap( + lambda x: IngestionManagerPandas._run_multi_threaded(*x), + args, + ) + + if wait: + self.wait(timeout=timeout) + + @staticmethod + def _run_multi_threaded( + max_workers: int, + feature_group_name: str, + feature_definitions: Dict[str, Dict[Any, Any]], + data_frame: DataFrame, + target_stores: List[str] = None, + row_offset: int = 0, + timeout: Union[int, float] = None, + ) -> List[int]: + """Start multi-threaded ingestion within a single process.""" + executor = ThreadPoolExecutor(max_workers=max_workers) + batch_size = math.ceil(data_frame.shape[0] / max_workers) + + futures = {} + for i in range(max_workers): + start_index = min(i * batch_size, data_frame.shape[0]) + end_index = min(i * batch_size + batch_size, data_frame.shape[0]) + future = executor.submit( + IngestionManagerPandas._ingest_single_batch, + data_frame=data_frame, + feature_group_name=feature_group_name, + feature_definitions=feature_definitions, + start_index=start_index, + end_index=end_index, + target_stores=target_stores, + ) + futures[future] = (start_index + row_offset, end_index + row_offset) + + failed_indices = [] + for future in as_completed(futures, timeout=timeout): + start, end = futures[future] + failed_rows = future.result() + if not failed_rows: + logger.info("Successfully ingested row %d to %d", start, end) + failed_indices.extend(failed_rows) + + 
executor.shutdown(wait=False) + return failed_indices + + @staticmethod + def _ingest_single_batch( + data_frame: DataFrame, + feature_group_name: str, + feature_definitions: Dict[str, Dict[Any, Any]], + start_index: int, + end_index: int, + target_stores: List[str] = None, + ) -> List[int]: + """Ingest a single batch of DataFrame rows into FeatureStore.""" + logger.info("Started ingesting index %d to %d", start_index, end_index) + failed_rows = [] + + fg = CoreFeatureGroup(feature_group_name=feature_group_name) + + for row in data_frame[start_index:end_index].itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=data_frame, + row=row, + feature_group=fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=target_stores, + ) + + return failed_rows + + @staticmethod + def _ingest_row( + data_frame: DataFrame, + row: Iterable, + feature_group: CoreFeatureGroup, + feature_definitions: Dict[str, Dict[Any, Any]], + failed_rows: List[int], + target_stores: List[str] = None, + ): + """Ingest a single DataFrame row into FeatureStore using SageMaker Core.""" + try: + record = [] + for index in range(1, len(row)): + feature_name = data_frame.columns[index - 1] + feature_value = row[index] + + if not IngestionManagerPandas._feature_value_is_not_none(feature_value): + continue + + if IngestionManagerPandas._is_feature_collection_type(feature_name, feature_definitions): + record.append(FeatureValue( + feature_name=feature_name, + value_as_string_list=IngestionManagerPandas._convert_to_string_list(feature_value), + )) + else: + record.append(FeatureValue( + feature_name=feature_name, + value_as_string=str(feature_value), + )) + + # Use SageMaker Core's put_record directly + feature_group.put_record( + record=record, + target_stores=target_stores, + ) + + except Exception as e: + logger.error("Failed to ingest row %d: %s", row[0], e) + failed_rows.append(row[0]) + + @staticmethod + def _is_feature_collection_type( + feature_name: str, + feature_definitions: Dict[str, Dict[Any, Any]], + ) -> bool: + """Check if the feature is a collection type.""" + feature_def = feature_definitions.get(feature_name) + if feature_def: + return feature_def.get("CollectionType") is not None + return False + + @staticmethod + def _feature_value_is_not_none(feature_value: Any) -> bool: + """Check if the feature value is not None. + + For Collection Type features, we check if the value is not None. + For Scalar values, we use pd.notna() to keep the behavior same. 
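+
+        Rough sketch of the intended behavior (illustrative values only)::
+
+            _feature_value_is_not_none(float("nan"))  # False: scalar NaN is skipped
+            _feature_value_is_not_none("abc")         # True: non-null scalar
+            _feature_value_is_not_none(["a", None])   # True: collections are only skipped when None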
+ """ + if not is_list_like(feature_value): + return pd.notna(feature_value) + return feature_value is not None + + @staticmethod + def _convert_to_string_list(feature_value: List[Any]) -> List[str]: + """Convert a list of feature values to a list of strings.""" + if not is_list_like(feature_value): + raise ValueError( + f"Invalid feature value: {feature_value} for a collection type feature " + f"must be an Array, but was {type(feature_value)}" + ) + return [str(v) if v is not None else None for v in feature_value] diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/inputs.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/inputs.py new file mode 100644 index 0000000000..f264059eb3 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/inputs.py @@ -0,0 +1,60 @@ +"""Enums for FeatureStore operations.""" +from enum import Enum + +class TargetStoreEnum(Enum): + """Store types for put_record.""" + ONLINE_STORE = "OnlineStore" + OFFLINE_STORE = "OfflineStore" + +class OnlineStoreStorageTypeEnum(Enum): + """Storage types for online store.""" + STANDARD = "Standard" + IN_MEMORY = "InMemory" + +class TableFormatEnum(Enum): + """Offline store table formats.""" + GLUE = "Glue" + ICEBERG = "Iceberg" + +class ResourceEnum(Enum): + """Resource types for search.""" + FEATURE_GROUP = "FeatureGroup" + FEATURE_METADATA = "FeatureMetadata" + +class SearchOperatorEnum(Enum): + """Search operators.""" + AND = "And" + OR = "Or" + +class SortOrderEnum(Enum): + """Sort orders.""" + ASCENDING = "Ascending" + DESCENDING = "Descending" + +class FilterOperatorEnum(Enum): + """Filter operators.""" + EQUALS = "Equals" + NOT_EQUALS = "NotEquals" + GREATER_THAN = "GreaterThan" + GREATER_THAN_OR_EQUAL_TO = "GreaterThanOrEqualTo" + LESS_THAN = "LessThan" + LESS_THAN_OR_EQUAL_TO = "LessThanOrEqualTo" + CONTAINS = "Contains" + EXISTS = "Exists" + NOT_EXISTS = "NotExists" + IN = "In" + +class DeletionModeEnum(Enum): + """Deletion modes for delete_record.""" + SOFT_DELETE = "SoftDelete" + HARD_DELETE = "HardDelete" + +class ExpirationTimeResponseEnum(Enum): + """ExpiresAt response toggle.""" + DISABLED = "Disabled" + ENABLED = "Enabled" + +class ThroughputModeEnum(Enum): + """Throughput modes for feature group.""" + ON_DEMAND = "OnDemand" + PROVISIONED = "Provisioned" \ No newline at end of file From 193d16fa2c61440ed1cd4d906e64e2726229ce6e Mon Sep 17 00:00:00 2001 From: adishaa Date: Fri, 16 Jan 2026 11:34:01 -0800 Subject: [PATCH 2/8] Add feature store tests --- .../mlops/feature_store/MIGRATION_GUIDE.md | 513 ++++++++++++++++++ .../sagemaker/mlops/feature_store/__init__.py | 8 +- .../mlops/feature_store/dataset_builder.py | 45 +- .../mlops/feature_store/feature_utils.py | 31 +- .../feature_store/ingestion_manager_pandas.py | 19 +- sagemaker-mlops/tests/__init__.py | 0 sagemaker-mlops/tests/unit/__init__.py | 0 .../tests/unit/sagemaker/__init__.py | 0 .../tests/unit/sagemaker/mlops/__init__.py | 0 .../sagemaker/mlops/feature_store/__init__.py | 2 + .../sagemaker/mlops/feature_store/conftest.py | 80 +++ .../mlops/feature_store/test_athena_query.py | 113 ++++ .../feature_store/test_dataset_builder.py | 345 ++++++++++++ .../feature_store/test_feature_definition.py | 126 +++++ .../mlops/feature_store/test_feature_utils.py | 202 +++++++ .../test_ingestion_manager_pandas.py | 256 +++++++++ .../mlops/feature_store/test_inputs.py | 109 ++++ 17 files changed, 1802 insertions(+), 47 deletions(-) create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md create mode 100644 
sagemaker-mlops/tests/__init__.py create mode 100644 sagemaker-mlops/tests/unit/__init__.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/__init__.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/__init__.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/__init__.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/conftest.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_athena_query.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_dataset_builder.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_definition.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_ingestion_manager_pandas.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_inputs.py diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md b/sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md new file mode 100644 index 0000000000..40942fa6f3 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/MIGRATION_GUIDE.md @@ -0,0 +1,513 @@ +# SageMaker FeatureStore V2 to V3 Migration Guide + +## Overview + +V3 uses **sagemaker-core** as the foundation, which provides: +- Pydantic-based shapes with automatic serialization +- Resource classes that manage boto clients internally +- No need for explicit Session management in most cases + +## File Mapping + +| V2 File | V3 File | Notes | +|---------|---------|-------| +| `feature_group.py` | Re-exported from `sagemaker_core.main.resources` | No wrapper class needed | +| `feature_store.py` | Re-exported from `sagemaker_core.main.resources` | `FeatureStore.search()` available | +| `feature_definition.py` | `feature_definition.py` | Helper factories retained | +| `feature_utils.py` | `feature_utils.py` | Standalone functions | +| `inputs.py` | `inputs.py` | Enums only (shapes from core) | +| `dataset_builder.py` | `dataset_builder.py` | Converted to dataclass | +| N/A | `athena_query.py` | Extracted from feature_group.py | +| N/A | `ingestion_manager_pandas.py` | Extracted from feature_group.py | + +--- + +## FeatureGroup Operations + +### Create FeatureGroup + +**V2:** +```python +from sagemaker.feature_store.feature_group import FeatureGroup +from sagemaker.session import Session + +session = Session() +fg = FeatureGroup(name="my-fg", sagemaker_session=session) +fg.load_feature_definitions(data_frame=df) +fg.create( + s3_uri="s3://bucket/prefix", + record_identifier_name="id", + event_time_feature_name="ts", + role_arn=role, + enable_online_store=True, +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import ( + FeatureGroup, + OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + load_feature_definitions_from_dataframe, +) + +feature_defs = load_feature_definitions_from_dataframe(df) + +FeatureGroup.create( + feature_group_name="my-fg", + feature_definitions=feature_defs, + record_identifier_feature_name="id", + event_time_feature_name="ts", + role_arn=role, + online_store_config=OnlineStoreConfig(enable_online_store=True), + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/prefix") + ), +) +``` + +### Get/Describe FeatureGroup + +**V2:** +```python +fg = FeatureGroup(name="my-fg", sagemaker_session=session) 
+response = fg.describe() +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup + +fg = FeatureGroup.get(feature_group_name="my-fg") +# fg is now a typed object with attributes: +# fg.feature_group_name, fg.feature_definitions, fg.offline_store_config, etc. +``` + +### Delete FeatureGroup + +**V2:** +```python +fg.delete() +``` + +**V3:** +```python +FeatureGroup(feature_group_name="my-fg").delete() +# or +fg = FeatureGroup.get(feature_group_name="my-fg") +fg.delete() +``` + +### Update FeatureGroup + +**V2:** +```python +fg.update( + feature_additions=[FeatureDefinition("new_col", FeatureTypeEnum.STRING)], + throughput_config=ThroughputConfigUpdate(mode=ThroughputModeEnum.ON_DEMAND), +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup, ThroughputConfig + +fg = FeatureGroup.get(feature_group_name="my-fg") +fg.update( + feature_additions=[{"FeatureName": "new_col", "FeatureType": "String"}], + throughput_config=ThroughputConfig(throughput_mode="OnDemand"), +) +``` + +--- + +## Record Operations + +### Put Record + +**V2:** +```python +from sagemaker.feature_store.inputs import FeatureValue + +fg.put_record( + record=[ + FeatureValue(feature_name="id", value_as_string="123"), + FeatureValue(feature_name="name", value_as_string="John"), + ], + target_stores=[TargetStoreEnum.ONLINE_STORE], +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup, FeatureValue + +FeatureGroup(feature_group_name="my-fg").put_record( + record=[ + FeatureValue(feature_name="id", value_as_string="123"), + FeatureValue(feature_name="name", value_as_string="John"), + ], + target_stores=["OnlineStore"], # strings, not enums +) +``` + +### Get Record + +**V2:** +```python +response = fg.get_record(record_identifier_value_as_string="123") +``` + +**V3:** +```python +response = FeatureGroup(feature_group_name="my-fg").get_record( + record_identifier_value_as_string="123" +) +``` + +### Delete Record + +**V2:** +```python +fg.delete_record( + record_identifier_value_as_string="123", + event_time="2024-01-15T00:00:00Z", + deletion_mode=DeletionModeEnum.SOFT_DELETE, +) +``` + +**V3:** +```python +FeatureGroup(feature_group_name="my-fg").delete_record( + record_identifier_value_as_string="123", + event_time="2024-01-15T00:00:00Z", + deletion_mode="SoftDelete", # string, not enum +) +``` + +### Batch Get Record + +**V2:** +```python +from sagemaker.feature_store.feature_store import FeatureStore +from sagemaker.feature_store.inputs import Identifier + +fs = FeatureStore(sagemaker_session=session) +response = fs.batch_get_record( + identifiers=[ + Identifier(feature_group_name="my-fg", record_identifiers_value_as_string=["123", "456"]) + ] +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureGroup + +response = FeatureGroup(feature_group_name="my-fg").batch_get_record( + identifiers=[ + {"FeatureGroupName": "my-fg", "RecordIdentifiersValueAsString": ["123", "456"]} + ] +) +``` + +--- + +## DataFrame Ingestion + +**V2:** +```python +fg.ingest(data_frame=df, max_workers=4, max_processes=2, wait=True) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import ingest_dataframe + +manager = ingest_dataframe( + feature_group_name="my-fg", + data_frame=df, + max_workers=4, + max_processes=2, + wait=True, +) +# Access failed rows: manager.failed_rows +``` + +--- + +## Athena Query + +**V2:** +```python +query = fg.athena_query() +query.run(query_string="SELECT * FROM ...", output_location="s3://...") 
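+# run() also accepts optional kms_key= and workgroup= arguments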
+query.wait() +df = query.as_dataframe() +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import create_athena_query + +query = create_athena_query("my-fg", session) +query.run(query_string="SELECT * FROM ...", output_location="s3://...") +query.wait() +df = query.as_dataframe() +``` + +--- + +## Hive DDL Generation + +**V2:** +```python +ddl = fg.as_hive_ddl(database="mydb", table_name="mytable") +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import as_hive_ddl + +ddl = as_hive_ddl("my-fg", database="mydb", table_name="mytable") +``` + +--- + +## Feature Definitions + +**V2:** +```python +fg.load_feature_definitions(data_frame=df) +# Modifies fg.feature_definitions in place +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import load_feature_definitions_from_dataframe + +defs = load_feature_definitions_from_dataframe(df) +# Returns list, doesn't modify any object +``` + +### Using Helper Factories + +**V2 & V3 (same):** +```python +from sagemaker.mlops.feature_store import ( + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + VectorCollectionType, +) + +defs = [ + IntegralFeatureDefinition("id"), + StringFeatureDefinition("name"), + FractionalFeatureDefinition("embedding", VectorCollectionType(128)), +] +``` + +--- + +## Search + +**V2:** +```python +from sagemaker.feature_store.feature_store import FeatureStore +from sagemaker.feature_store.inputs import Filter, ResourceEnum + +fs = FeatureStore(sagemaker_session=session) +response = fs.search( + resource=ResourceEnum.FEATURE_GROUP, + filters=[Filter(name="FeatureGroupName", value="my-prefix", operator=FilterOperatorEnum.CONTAINS)], +) +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureStore, Filter, SearchExpression + +response = FeatureStore.search( + resource="FeatureGroup", + search_expression=SearchExpression( + filters=[Filter(name="FeatureGroupName", value="my-prefix", operator="Contains")] + ), +) +``` + +--- + +## Feature Metadata + +**V2:** +```python +fg.describe_feature_metadata(feature_name="my-feature") +fg.update_feature_metadata(feature_name="my-feature", description="Updated desc") +fg.list_parameters_for_feature_metadata(feature_name="my-feature") +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import FeatureMetadata + +# Get metadata +metadata = FeatureMetadata.get(feature_group_name="my-fg", feature_name="my-feature") +print(metadata.description) +print(metadata.parameters) + +# Update metadata +metadata.update(description="Updated desc") +``` + +--- + +## Dataset Builder + +**V2:** +```python +from sagemaker.feature_store.feature_store import FeatureStore + +fs = FeatureStore(sagemaker_session=session) +builder = fs.create_dataset( + base=fg, + output_path="s3://bucket/output", +) +builder.with_feature_group(other_fg, target_feature_name_in_base="id") +builder.point_in_time_accurate_join() +df, query = builder.to_dataframe() +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import create_dataset, FeatureGroup + +fg = FeatureGroup.get(feature_group_name="my-fg") +other_fg = FeatureGroup.get(feature_group_name="other-fg") + +builder = create_dataset( + base=fg, + output_path="s3://bucket/output", + session=session, +) +builder.with_feature_group(other_fg, target_feature_name_in_base="id") +builder.point_in_time_accurate_join() +df, query = builder.to_dataframe() +``` + +--- + +## Config Objects (Shapes) + +**V2:** +```python +from sagemaker.feature_store.inputs import ( + 
OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + TtlDuration, +) + +config = OnlineStoreConfig(enable_online_store=True, ttl_duration=TtlDuration(unit="Hours", value=24)) +config.to_dict() # Manual serialization required +``` + +**V3:** +```python +from sagemaker.mlops.feature_store import ( + OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + TtlDuration, +) + +config = OnlineStoreConfig(enable_online_store=True, ttl_duration=TtlDuration(unit="Hours", value=24)) +# No to_dict() needed - Pydantic handles serialization automatically +``` + +--- + +## Key Differences Summary + +| Aspect | V2 | V3 | +|--------|----|----| +| **Session** | Required for most operations | Optional - core manages clients | +| **FeatureGroup** | Wrapper class with session | Direct core resource class | +| **Shapes** | `@attr.s` with `to_dict()` | Pydantic with auto-serialization | +| **Enums** | `TargetStoreEnum.ONLINE_STORE.value` | Just use strings: `"OnlineStore"` | +| **Methods** | Instance methods on FeatureGroup | Standalone functions + core methods | +| **Ingestion** | `fg.ingest(df)` | `ingest_dataframe(name, df)` | +| **Athena** | `fg.athena_query()` | `create_athena_query(name, session)` | +| **DDL** | `fg.as_hive_ddl()` | `as_hive_ddl(name)` | +| **Feature Defs** | `fg.load_feature_definitions(df)` | `load_feature_definitions_from_dataframe(df)` | +| **Imports** | Multiple modules | Single `__init__.py` re-exports all | + +--- + +## Missing in V3 (Intentionally) + +These V2 features are **not wrapped** because core provides them directly: + +- `FeatureGroup.create()` - use `FeatureGroup.create()` from core +- `FeatureGroup.delete()` - use `FeatureGroup(...).delete()` from core +- `FeatureGroup.describe()` - use `FeatureGroup.get()` from core (returns typed object) +- `FeatureGroup.update()` - use `FeatureGroup(...).update()` from core +- `FeatureGroup.put_record()` - use `FeatureGroup(...).put_record()` from core +- `FeatureGroup.get_record()` - use `FeatureGroup(...).get_record()` from core +- `FeatureGroup.delete_record()` - use `FeatureGroup(...).delete_record()` from core +- `FeatureGroup.batch_get_record()` - use `FeatureGroup(...).batch_get_record()` from core +- `FeatureStore.search()` - use `FeatureStore.search()` from core +- `FeatureStore.list_feature_groups()` - use `FeatureGroup.get_all()` from core +- All config shapes (`OnlineStoreConfig`, etc.) 
- re-exported from core + +--- + +## Import Cheatsheet + +```python +# V3 - Everything from one place +from sagemaker.mlops.feature_store import ( + # Resources (from core) + FeatureGroup, + FeatureStore, + FeatureMetadata, + + # Shapes (from core) + OnlineStoreConfig, + OfflineStoreConfig, + S3StorageConfig, + DataCatalogConfig, + TtlDuration, + FeatureValue, + FeatureParameter, + ThroughputConfig, + Filter, + SearchExpression, + + # Enums (local) + TargetStoreEnum, + OnlineStoreStorageTypeEnum, + TableFormatEnum, + DeletionModeEnum, + ThroughputModeEnum, + + # Feature Definition helpers (local) + FeatureDefinition, + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + VectorCollectionType, + + # Utility functions (local) + create_athena_query, + as_hive_ddl, + load_feature_definitions_from_dataframe, + ingest_dataframe, + create_dataset, + + # Classes (local) + DatasetBuilder, +) +``` diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py index ee8cd7d1a3..f15d6d3845 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py @@ -3,11 +3,10 @@ """SageMaker FeatureStore V3 - powered by sagemaker-core.""" # Resources from core -from sagemaker_core.main.resources import FeatureGroup, FeatureMetadata -from sagemaker_core.main.resources import FeatureStore +from sagemaker.core.resources import FeatureGroup, FeatureMetadata # Shapes from core (Pydantic - no to_dict() needed) -from sagemaker_core.main.shapes import ( +from sagemaker.core.shapes import ( DataCatalogConfig, FeatureParameter, FeatureValue, @@ -52,7 +51,6 @@ from sagemaker.mlops.feature_store.feature_utils import ( as_hive_ddl, create_athena_query, - create_dataset, get_session_from_role, ingest_dataframe, load_feature_definitions_from_dataframe, @@ -76,7 +74,6 @@ # Resources "FeatureGroup", "FeatureMetadata", - "FeatureStore", # Shapes "DataCatalogConfig", "FeatureParameter", @@ -113,7 +110,6 @@ # Utility functions "as_hive_ddl", "create_athena_query", - "create_dataset", "get_session_from_role", "ingest_dataframe", "load_feature_definitions_from_dataframe", diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py index f5450663a6..72e9535320 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/dataset_builder.py @@ -178,7 +178,9 @@ def construct_feature_group_to_be_merged( database=catalog_config.database, table_name=catalog_config.table_name, record_identifier_feature_name=record_id, - event_time_identifier_feature=FeatureDefinition(event_time_name, FeatureTypeEnum(event_time_type)), + event_time_identifier_feature=FeatureDefinition( + feature_name=event_time_name, feature_type=FeatureTypeEnum(event_time_type).value + ), target_feature_name_in_base=target_feature_name_in_base, table_type=TableType.FEATURE_GROUP, feature_name_in_target=feature_name_in_target, @@ -256,6 +258,47 @@ class DatasetBuilder: _event_time_ending_timestamp: datetime.datetime = field(default=None, init=False) _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = field(default_factory=list, init=False) + @classmethod + def create( + cls, + base: Union[FeatureGroup, pd.DataFrame], + output_path: str, + session: Session, + record_identifier_feature_name: str = None, + 
event_time_identifier_feature_name: str = None, + included_feature_names: List[str] = None, + kms_key_id: str = None, + ) -> "DatasetBuilder": + """Create a DatasetBuilder for generating a Dataset. + + Args: + base: A FeatureGroup or DataFrame to use as the base. + output_path: S3 URI for output. + session: SageMaker session. + record_identifier_feature_name: Required if base is DataFrame. + event_time_identifier_feature_name: Required if base is DataFrame. + included_feature_names: Features to include in output. + kms_key_id: KMS key for encryption. + + Returns: + DatasetBuilder instance. + """ + if isinstance(base, pd.DataFrame): + if not record_identifier_feature_name or not event_time_identifier_feature_name: + raise ValueError( + "record_identifier_feature_name and event_time_identifier_feature_name " + "are required when base is a DataFrame." + ) + return cls( + _sagemaker_session=session, + _base=base, + _output_path=output_path, + _record_identifier_feature_name=record_identifier_feature_name, + _event_time_identifier_feature_name=event_time_identifier_feature_name, + _included_feature_names=included_feature_names, + _kms_key_id=kms_key_id, + ) + def with_feature_group( self, feature_group: FeatureGroup, diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py index f7f6523b8d..0b7c747515 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py @@ -13,7 +13,6 @@ from sagemaker.mlops.feature_store import FeatureGroup as CoreFeatureGroup, FeatureGroup from sagemaker.core.helper.session_helper import Session from sagemaker.core.s3.client import S3Uploader, S3Downloader -from sagemaker.mlops.feature_store.dataset_builder import DatasetBuilder from sagemaker.mlops.feature_store.feature_definition import ( FeatureDefinition, FractionalFeatureDefinition, @@ -23,7 +22,7 @@ ) from sagemaker.mlops.feature_store.ingestion_manager_pandas import IngestionManagerPandas -from sagemaker import utils +from sagemaker.core.utils import unique_name_from_base logger = logging.getLogger(__name__) @@ -207,7 +206,7 @@ def upload_dataframe_to_s3( Tuple of (s3_folder, temp_table_name). """ - temp_id = utils.unique_name_from_base("dataframe-base") + temp_id = unique_name_from_base("dataframe-base") local_file = f"{temp_id}.csv" s3_folder = os.path.join(output_path, temp_id) @@ -460,29 +459,3 @@ def ingest_dataframe( manager.run(data_frame=data_frame, wait=wait, timeout=timeout) return manager -def create_dataset( - base: Union[FeatureGroup, pd.DataFrame], - output_path: str, - session: Session, - record_identifier_feature_name: str = None, - event_time_identifier_feature_name: str = None, - included_feature_names: Sequence[str] = None, - kms_key_id: str = None, -) -> DatasetBuilder: - """Create a DatasetBuilder for generating a Dataset.""" - if isinstance(base, pd.DataFrame): - if not record_identifier_feature_name or not event_time_identifier_feature_name: - raise ValueError( - "record_identifier_feature_name and event_time_identifier_feature_name " - "are required when base is a DataFrame." 
- ) - return DatasetBuilder( - _sagemaker_session=session, - _base=base, - _output_path=output_path, - _record_identifier_feature_name=record_identifier_feature_name, - _event_time_identifier_feature_name=event_time_identifier_feature_name, - _included_feature_names=included_feature_names, - _kms_key_id=kms_key_id, - ) - diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py index 60d022dab1..4d7b4e5375 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/ingestion_manager_pandas.py @@ -6,13 +6,12 @@ import signal from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field -from multiprocessing.pool import AsyncResult +from multiprocessing import Pool from typing import Any, Dict, Iterable, List, Sequence, Union import pandas as pd from pandas import DataFrame from pandas.api.types import is_list_like -from pathos.multiprocessing import ProcessingPool from sagemaker.core.resources import FeatureGroup as CoreFeatureGroup from sagemaker.core.shapes import FeatureValue @@ -54,8 +53,8 @@ class IngestionManagerPandas: feature_definitions: Dict[str, Dict[Any, Any]] max_workers: int = 1 max_processes: int = 1 - _async_result: AsyncResult = field(default=None, init=False) - _processing_pool: ProcessingPool = field(default=None, init=False) + _async_result: Any = field(default=None, init=False) + _processing_pool: Pool = field(default=None, init=False) _failed_indices: List[int] = field(default_factory=list, init=False) @property @@ -100,12 +99,11 @@ def wait(self, timeout: Union[int, float] = None): results = self._async_result.get(timeout=timeout) except KeyboardInterrupt as e: self._processing_pool.terminate() - self._processing_pool.close() - self._processing_pool.clear() + self._processing_pool.join() raise e else: self._processing_pool.close() - self._processing_pool.clear() + self._processing_pool.join() self._failed_indices = [idx for failed in results for idx in failed] @@ -170,11 +168,10 @@ def _run_multi_process( def init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN) - self._processing_pool = ProcessingPool(self.max_processes, init_worker) - self._processing_pool.restart(force=True) + self._processing_pool = Pool(self.max_processes, init_worker) - self._async_result = self._processing_pool.amap( - lambda x: IngestionManagerPandas._run_multi_threaded(*x), + self._async_result = self._processing_pool.starmap_async( + IngestionManagerPandas._run_multi_threaded, args, ) diff --git a/sagemaker-mlops/tests/__init__.py b/sagemaker-mlops/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/__init__.py b/sagemaker-mlops/tests/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/sagemaker/__init__.py b/sagemaker-mlops/tests/unit/sagemaker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/__init__.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/__init__.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/__init__.py new file mode 100644 index 0000000000..f34bf7d447 --- /dev/null +++ 
b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/conftest.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/conftest.py new file mode 100644 index 0000000000..9b2ec55895 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/conftest.py @@ -0,0 +1,80 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Conftest for feature_store tests.""" +import pytest +from unittest.mock import Mock, MagicMock +import pandas as pd +import numpy as np + + +@pytest.fixture +def mock_session(): + """Create a mock Session.""" + session = Mock() + session.boto_session = Mock() + session.boto_region_name = "us-west-2" + session.sagemaker_client = Mock() + session.sagemaker_runtime_client = Mock() + session.sagemaker_featurestore_runtime_client = Mock() + return session + + +@pytest.fixture +def sample_dataframe(): + """Create a sample DataFrame for testing.""" + return pd.DataFrame({ + "id": pd.Series([1, 2, 3, 4, 5], dtype="int64"), + "value": pd.Series([1.1, 2.2, 3.3, 4.4, 5.5], dtype="float64"), + "name": pd.Series(["a", "b", "c", "d", "e"], dtype="string"), + "event_time": pd.Series( + ["2024-01-01T00:00:00Z"] * 5, + dtype="string" + ), + }) + + +@pytest.fixture +def dataframe_with_collections(): + """Create a DataFrame with collection type columns.""" + return pd.DataFrame({ + "id": pd.Series([1, 2, 3], dtype="int64"), + "tags": pd.Series([["a", "b"], ["c"], ["d", "e", "f"]], dtype="object"), + "scores": pd.Series([[1.0, 2.0], [3.0], [4.0, 5.0]], dtype="object"), + "event_time": pd.Series(["2024-01-01"] * 3, dtype="string"), + }) + + +@pytest.fixture +def feature_definitions_dict(): + """Create a feature definitions dictionary.""" + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "value": {"FeatureName": "value", "FeatureType": "Fractional"}, + "name": {"FeatureName": "name", "FeatureType": "String"}, + "event_time": {"FeatureName": "event_time", "FeatureType": "String"}, + } + + +@pytest.fixture +def mock_feature_group(): + """Create a mock FeatureGroup from core.""" + fg = MagicMock() + fg.feature_group_name = "test-feature-group" + fg.record_identifier_feature_name = "id" + fg.event_time_feature_name = "event_time" + fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + MagicMock(feature_name="value", feature_type="Fractional"), + MagicMock(feature_name="name", feature_type="String"), + MagicMock(feature_name="event_time", feature_type="String"), + ] + fg.offline_store_config = MagicMock() + fg.offline_store_config.s3_storage_config.s3_uri = "s3://bucket/prefix" + fg.offline_store_config.s3_storage_config.resolved_output_s3_uri = "s3://bucket/prefix/resolved" + fg.offline_store_config.data_catalog_config.catalog = "AwsDataCatalog" + fg.offline_store_config.data_catalog_config.database = "sagemaker_featurestore" + fg.offline_store_config.data_catalog_config.table_name = "test_feature_group" + fg.offline_store_config.data_catalog_config.disable_glue_table_creation = False + fg.online_store_config = MagicMock() + fg.online_store_config.enable_online_store = True + return fg diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_athena_query.py 
b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_athena_query.py new file mode 100644 index 0000000000..2fed784208 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_athena_query.py @@ -0,0 +1,113 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for athena_query.py""" +import os +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd + +from sagemaker.mlops.feature_store.athena_query import AthenaQuery + + +class TestAthenaQuery: + @pytest.fixture + def mock_session(self): + session = Mock() + session.boto_session.client.return_value = Mock() + session.boto_region_name = "us-west-2" + return session + + @pytest.fixture + def athena_query(self, mock_session): + return AthenaQuery( + catalog="AwsDataCatalog", + database="sagemaker_featurestore", + table_name="my_feature_group", + sagemaker_session=mock_session, + ) + + def test_initialization(self, athena_query): + assert athena_query.catalog == "AwsDataCatalog" + assert athena_query.database == "sagemaker_featurestore" + assert athena_query.table_name == "my_feature_group" + assert athena_query._current_query_execution_id is None + + @patch("sagemaker.mlops.feature_store.athena_query.start_query_execution") + def test_run_starts_query(self, mock_start, athena_query): + mock_start.return_value = {"QueryExecutionId": "query-123"} + + result = athena_query.run( + query_string="SELECT * FROM table", + output_location="s3://bucket/output", + ) + + assert result == "query-123" + assert athena_query._current_query_execution_id == "query-123" + assert athena_query._result_bucket == "bucket" + assert athena_query._result_file_prefix == "output" + + @patch("sagemaker.mlops.feature_store.athena_query.start_query_execution") + def test_run_with_kms_key(self, mock_start, athena_query): + mock_start.return_value = {"QueryExecutionId": "query-123"} + + athena_query.run( + query_string="SELECT * FROM table", + output_location="s3://bucket/output", + kms_key="arn:aws:kms:us-west-2:123:key/abc", + ) + + mock_start.assert_called_once() + call_kwargs = mock_start.call_args[1] + assert call_kwargs["kms_key"] == "arn:aws:kms:us-west-2:123:key/abc" + + @patch("sagemaker.mlops.feature_store.athena_query.wait_for_athena_query") + def test_wait_calls_helper(self, mock_wait, athena_query): + athena_query._current_query_execution_id = "query-123" + + athena_query.wait() + + mock_wait.assert_called_once_with(athena_query.sagemaker_session, "query-123") + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + def test_get_query_execution(self, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + mock_get.return_value = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}} + + result = athena_query.get_query_execution() + + assert result["QueryExecution"]["Status"]["State"] == "SUCCEEDED" + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + @patch("sagemaker.mlops.feature_store.athena_query.download_athena_query_result") + @patch("pandas.read_csv") + @patch("os.path.join") + def test_as_dataframe_success(self, mock_join, mock_read_csv, mock_download, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + athena_query._result_bucket = "bucket" + athena_query._result_file_prefix = "prefix" + + mock_get.return_value = {"QueryExecution": {"Status": {"State": "SUCCEEDED"}}} + mock_join.return_value = 
"/tmp/query-123.csv" + mock_read_csv.return_value = pd.DataFrame({"col": [1, 2, 3]}) + + with patch("tempfile.gettempdir", return_value="/tmp"): + with patch("os.remove"): + df = athena_query.as_dataframe() + + assert len(df) == 3 + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + def test_as_dataframe_raises_when_running(self, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + mock_get.return_value = {"QueryExecution": {"Status": {"State": "RUNNING"}}} + + with pytest.raises(RuntimeError, match="still executing"): + athena_query.as_dataframe() + + @patch("sagemaker.mlops.feature_store.athena_query.get_query_execution") + def test_as_dataframe_raises_when_failed(self, mock_get, athena_query): + athena_query._current_query_execution_id = "query-123" + mock_get.return_value = {"QueryExecution": {"Status": {"State": "FAILED"}}} + + with pytest.raises(RuntimeError, match="failed"): + athena_query.as_dataframe() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_dataset_builder.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_dataset_builder.py new file mode 100644 index 0000000000..254fb0e196 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_dataset_builder.py @@ -0,0 +1,345 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for dataset_builder.py""" +import datetime +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd + +from sagemaker.mlops.feature_store import FeatureGroup +from sagemaker.mlops.feature_store.dataset_builder import ( + DatasetBuilder, + FeatureGroupToBeMerged, + TableType, + JoinTypeEnum, + JoinComparatorEnum, + construct_feature_group_to_be_merged, +) +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FeatureTypeEnum, +) + + +class TestTableType: + def test_feature_group_value(self): + assert TableType.FEATURE_GROUP.value == "FeatureGroup" + + def test_data_frame_value(self): + assert TableType.DATA_FRAME.value == "DataFrame" + + +class TestJoinTypeEnum: + def test_inner_join(self): + assert JoinTypeEnum.INNER_JOIN.value == "JOIN" + + def test_left_join(self): + assert JoinTypeEnum.LEFT_JOIN.value == "LEFT JOIN" + + def test_right_join(self): + assert JoinTypeEnum.RIGHT_JOIN.value == "RIGHT JOIN" + + def test_full_join(self): + assert JoinTypeEnum.FULL_JOIN.value == "FULL JOIN" + + def test_cross_join(self): + assert JoinTypeEnum.CROSS_JOIN.value == "CROSS JOIN" + + +class TestJoinComparatorEnum: + def test_equals(self): + assert JoinComparatorEnum.EQUALS.value == "=" + + def test_greater_than(self): + assert JoinComparatorEnum.GREATER_THAN.value == ">" + + def test_less_than(self): + assert JoinComparatorEnum.LESS_THAN.value == "<" + + +class TestFeatureGroupToBeMerged: + def test_initialization(self): + fg = FeatureGroupToBeMerged( + features=["id", "value"], + included_feature_names=["id", "value"], + projected_feature_names=["id", "value"], + catalog="AwsDataCatalog", + database="sagemaker_featurestore", + table_name="my_table", + record_identifier_feature_name="id", + event_time_identifier_feature=FeatureDefinition( + feature_name="event_time", + feature_type="String", + ), + ) + + assert fg.features == ["id", "value"] + assert fg.catalog == "AwsDataCatalog" + assert fg.table_name == "my_table" + assert fg.join_type == JoinTypeEnum.INNER_JOIN + assert fg.join_comparator == 
JoinComparatorEnum.EQUALS + + def test_custom_join_settings(self): + fg = FeatureGroupToBeMerged( + features=["id"], + included_feature_names=["id"], + projected_feature_names=["id"], + catalog="AwsDataCatalog", + database="db", + table_name="table", + record_identifier_feature_name="id", + event_time_identifier_feature=FeatureDefinition( + feature_name="ts", + feature_type="String", + ), + join_type=JoinTypeEnum.LEFT_JOIN, + join_comparator=JoinComparatorEnum.GREATER_THAN, + ) + + assert fg.join_type == JoinTypeEnum.LEFT_JOIN + assert fg.join_comparator == JoinComparatorEnum.GREATER_THAN + + +class TestConstructFeatureGroupToBeMerged: + @patch("sagemaker.mlops.feature_store.dataset_builder.FeatureGroup") + def test_constructs_from_feature_group(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.feature_group_name = "test-fg" + mock_fg.record_identifier_feature_name = "id" + mock_fg.event_time_feature_name = "event_time" + mock_fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + MagicMock(feature_name="value", feature_type="Fractional"), + MagicMock(feature_name="event_time", feature_type="String"), + ] + mock_fg.offline_store_config.data_catalog_config.catalog = "MyCatalog" + mock_fg.offline_store_config.data_catalog_config.database = "MyDatabase" + mock_fg.offline_store_config.data_catalog_config.table_name = "MyTable" + mock_fg.offline_store_config.data_catalog_config.disable_glue_table_creation = False + mock_fg_class.get.return_value = mock_fg + + target_fg = MagicMock() + target_fg.feature_group_name = "test-fg" + + result = construct_feature_group_to_be_merged( + target_feature_group=target_fg, + included_feature_names=["id", "value"], + ) + + assert result.table_name == "MyTable" + assert result.database == "MyDatabase" + assert result.record_identifier_feature_name == "id" + assert result.table_type == TableType.FEATURE_GROUP + + @patch("sagemaker.mlops.feature_store.dataset_builder.FeatureGroup") + def test_raises_when_no_metastore(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.feature_group_name = "test-fg" + mock_fg.offline_store_config = None + mock_fg_class.get.return_value = mock_fg + + target_fg = MagicMock() + target_fg.feature_group_name = "test-fg" + + with pytest.raises(RuntimeError, match="No metastore"): + construct_feature_group_to_be_merged(target_fg, None) + + +class TestDatasetBuilder: + @pytest.fixture + def mock_session(self): + return Mock() + + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": [1, 2, 3], + "value": [1.1, 2.2, 3.3], + "event_time": ["2024-01-01", "2024-01-02", "2024-01-03"], + }) + + def test_initialization_with_dataframe(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + assert builder._output_path == "s3://bucket/output" + assert builder._record_identifier_feature_name == "id" + + def test_fluent_api_point_in_time(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.point_in_time_accurate_join() + + assert result is builder + assert builder._point_in_time_accurate_join is True + + def test_fluent_api_include_duplicated(self, 
mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.include_duplicated_records() + + assert result is builder + assert builder._include_duplicated_records is True + + def test_fluent_api_include_deleted(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.include_deleted_records() + + assert result is builder + assert builder._include_deleted_records is True + + def test_fluent_api_number_of_recent_records(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.with_number_of_recent_records_by_record_identifier(5) + + assert result is builder + assert builder._number_of_recent_records == 5 + + def test_fluent_api_number_of_records(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + result = builder.with_number_of_records_from_query_results(100) + + assert result is builder + assert builder._number_of_records == 100 + + def test_fluent_api_as_of(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + timestamp = datetime.datetime(2024, 1, 15, 12, 0, 0) + result = builder.as_of(timestamp) + + assert result is builder + assert builder._write_time_ending_timestamp == timestamp + + def test_fluent_api_event_time_range(self, mock_session, sample_dataframe): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + start = datetime.datetime(2024, 1, 1) + end = datetime.datetime(2024, 1, 31) + result = builder.with_event_time_range(start, end) + + assert result is builder + assert builder._event_time_starting_timestamp == start + assert builder._event_time_ending_timestamp == end + + @patch.object(DatasetBuilder, "_run_query") + @patch("sagemaker.mlops.feature_store.dataset_builder.construct_feature_group_to_be_merged") + def test_with_feature_group(self, mock_construct, mock_run, mock_session, sample_dataframe): + mock_fg_to_merge = MagicMock() + mock_construct.return_value = mock_fg_to_merge + + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base=sample_dataframe, + _output_path="s3://bucket/output", + _record_identifier_feature_name="id", + _event_time_identifier_feature_name="event_time", + ) + + mock_fg = MagicMock() + result = builder.with_feature_group(mock_fg, target_feature_name_in_base="id") + + assert result is builder + assert len(builder._feature_groups_to_be_merged) == 1 + + +class TestDatasetBuilderCreate: + 
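+    """Tests for the DatasetBuilder.create factory classmethod."""
+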
@pytest.fixture + def mock_session(self): + return Mock() + + def test_create_with_feature_group(self, mock_session): + mock_fg = MagicMock(spec=FeatureGroup) + builder = DatasetBuilder.create( + base=mock_fg, + output_path="s3://bucket/output", + session=mock_session, + ) + assert builder._base == mock_fg + assert builder._output_path == "s3://bucket/output" + + def test_create_with_dataframe(self, mock_session): + df = pd.DataFrame({"id": [1], "value": [10]}) + builder = DatasetBuilder.create( + base=df, + output_path="s3://bucket/output", + session=mock_session, + record_identifier_feature_name="id", + event_time_identifier_feature_name="event_time", + ) + assert builder._record_identifier_feature_name == "id" + + def test_create_with_dataframe_requires_identifiers(self, mock_session): + df = pd.DataFrame({"id": [1], "value": [10]}) + with pytest.raises(ValueError, match="record_identifier_feature_name"): + DatasetBuilder.create( + base=df, + output_path="s3://bucket/output", + session=mock_session, + ) + + +class TestDatasetBuilderValidation: + @pytest.fixture + def mock_session(self): + return Mock() + + def test_to_csv_raises_for_invalid_base(self, mock_session): + builder = DatasetBuilder( + _sagemaker_session=mock_session, + _base="invalid", # Not DataFrame or FeatureGroup + _output_path="s3://bucket/output", + ) + + with pytest.raises(ValueError, match="must be either"): + builder.to_csv_file() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_definition.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_definition.py new file mode 100644 index 0000000000..299868b5d2 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_definition.py @@ -0,0 +1,126 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 +"""Unit tests for feature_definition.py""" +import pytest + +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + FeatureTypeEnum, + CollectionTypeEnum, + IntegralFeatureDefinition, + FractionalFeatureDefinition, + StringFeatureDefinition, + VectorCollectionType, + ListCollectionType, + SetCollectionType, +) + + +class TestFeatureTypeEnum: + def test_fractional_value(self): + assert FeatureTypeEnum.FRACTIONAL.value == "Fractional" + + def test_integral_value(self): + assert FeatureTypeEnum.INTEGRAL.value == "Integral" + + def test_string_value(self): + assert FeatureTypeEnum.STRING.value == "String" + + +class TestCollectionTypeEnum: + def test_list_value(self): + assert CollectionTypeEnum.LIST.value == "List" + + def test_set_value(self): + assert CollectionTypeEnum.SET.value == "Set" + + def test_vector_value(self): + assert CollectionTypeEnum.VECTOR.value == "Vector" + + +class TestCollectionTypes: + def test_list_collection_type(self): + collection = ListCollectionType() + assert collection.collection_type == "List" + assert collection.collection_config is None + + def test_set_collection_type(self): + collection = SetCollectionType() + assert collection.collection_type == "Set" + assert collection.collection_config is None + + def test_vector_collection_type(self): + collection = VectorCollectionType(dimension=128) + assert collection.collection_type == "Vector" + assert collection.collection_config is not None + assert collection.collection_config.vector_config.dimension == 128 + + +class TestFeatureDefinitionFactories: + def test_integral_feature_definition(self): + definition = IntegralFeatureDefinition(feature_name="my_int_feature") + assert definition.feature_name == "my_int_feature" + assert definition.feature_type == "Integral" + assert definition.collection_type is None + + def test_fractional_feature_definition(self): + definition = FractionalFeatureDefinition(feature_name="my_float_feature") + assert definition.feature_name == "my_float_feature" + assert definition.feature_type == "Fractional" + assert definition.collection_type is None + + def test_string_feature_definition(self): + definition = StringFeatureDefinition(feature_name="my_string_feature") + assert definition.feature_name == "my_string_feature" + assert definition.feature_type == "String" + assert definition.collection_type is None + + def test_integral_with_list_collection(self): + definition = IntegralFeatureDefinition( + feature_name="my_int_list", + collection_type=ListCollectionType(), + ) + assert definition.feature_name == "my_int_list" + assert definition.feature_type == "Integral" + assert definition.collection_type == "List" + + def test_string_with_set_collection(self): + definition = StringFeatureDefinition( + feature_name="my_string_set", + collection_type=SetCollectionType(), + ) + assert definition.feature_name == "my_string_set" + assert definition.feature_type == "String" + assert definition.collection_type == "Set" + + def test_fractional_with_vector_collection(self): + definition = FractionalFeatureDefinition( + feature_name="my_embedding", + collection_type=VectorCollectionType(dimension=256), + ) + assert definition.feature_name == "my_embedding" + assert definition.feature_type == "Fractional" + assert definition.collection_type == "Vector" + assert definition.collection_config.vector_config.dimension == 256 + + +class TestFeatureDefinitionSerialization: + """Test that FeatureDefinition can be serialized 
(Pydantic model_dump).""" + + def test_simple_definition_serialization(self): + definition = IntegralFeatureDefinition(feature_name="id") + # Pydantic model - use model_dump + data = definition.model_dump(exclude_none=True) + assert data["feature_name"] == "id" + assert data["feature_type"] == "Integral" + + def test_collection_definition_serialization(self): + definition = FractionalFeatureDefinition( + feature_name="vector", + collection_type=VectorCollectionType(dimension=10), + ) + data = definition.model_dump(exclude_none=True) + assert data["feature_name"] == "vector" + assert data["feature_type"] == "Fractional" + assert data["collection_type"] == "Vector" + assert data["collection_config"]["vector_config"]["dimension"] == 10 diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py new file mode 100644 index 0000000000..a9d5408bf6 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_feature_utils.py @@ -0,0 +1,202 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for feature_utils.py""" +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd +import numpy as np + +from sagemaker.mlops.feature_store.feature_utils import ( + load_feature_definitions_from_dataframe, + as_hive_ddl, + create_athena_query, + ingest_dataframe, + get_session_from_role, + _is_collection_column, + _generate_feature_definition, +) +from sagemaker.mlops.feature_store.feature_definition import ( + FeatureDefinition, + ListCollectionType, +) + + +class TestLoadFeatureDefinitionsFromDataframe: + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": pd.Series([1, 2, 3], dtype="int64"), + "value": pd.Series([1.1, 2.2, 3.3], dtype="float64"), + "name": pd.Series(["a", "b", "c"], dtype="string"), + }) + + def test_infers_integral_type(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + id_def = next(d for d in defs if d.feature_name == "id") + assert id_def.feature_type == "Integral" + + def test_infers_fractional_type(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + value_def = next(d for d in defs if d.feature_name == "value") + assert value_def.feature_type == "Fractional" + + def test_infers_string_type(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + name_def = next(d for d in defs if d.feature_name == "name") + assert name_def.feature_type == "String" + + def test_returns_correct_count(self, sample_dataframe): + defs = load_feature_definitions_from_dataframe(sample_dataframe) + assert len(defs) == 3 + + def test_collection_type_with_in_memory_storage(self): + df = pd.DataFrame({ + "id": pd.Series([1, 2], dtype="int64"), + "tags": pd.Series([["a", "b"], ["c"]], dtype="object"), + }) + defs = load_feature_definitions_from_dataframe(df, online_storage_type="InMemory") + tags_def = next(d for d in defs if d.feature_name == "tags") + assert tags_def.collection_type == "List" + + +class TestIsCollectionColumn: + def test_list_column_returns_true(self): + series = pd.Series([[1, 2], [3, 4], [5]]) + assert _is_collection_column(series) == True + + def test_scalar_column_returns_false(self): + series = pd.Series([1, 2, 3]) + assert _is_collection_column(series) == False + + def test_empty_series(self): + 
series = pd.Series([], dtype="object") + assert _is_collection_column(series) == False + + +class TestAsHiveDdl: + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_generates_ddl_string(self, mock_fg_class): + # Setup mock + mock_fg = MagicMock() + mock_fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + MagicMock(feature_name="value", feature_type="Fractional"), + MagicMock(feature_name="name", feature_type="String"), + ] + mock_fg.offline_store_config.s3_storage_config.resolved_output_s3_uri = "s3://bucket/prefix" + mock_fg_class.get.return_value = mock_fg + + ddl = as_hive_ddl("my-feature-group") + + assert "CREATE EXTERNAL TABLE" in ddl + assert "my-feature-group" in ddl + assert "id INT" in ddl + assert "value FLOAT" in ddl + assert "name STRING" in ddl + assert "write_time TIMESTAMP" in ddl + assert "event_time TIMESTAMP" in ddl + assert "is_deleted BOOLEAN" in ddl + assert "s3://bucket/prefix" in ddl + + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_custom_database_and_table(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.feature_definitions = [] + mock_fg.offline_store_config.s3_storage_config.resolved_output_s3_uri = "s3://bucket/prefix" + mock_fg_class.get.return_value = mock_fg + + ddl = as_hive_ddl("my-fg", database="custom_db", table_name="custom_table") + + assert "custom_db.custom_table" in ddl + + +class TestCreateAthenaQuery: + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_creates_athena_query(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.offline_store_config.data_catalog_config.catalog = "MyCatalog" + mock_fg.offline_store_config.data_catalog_config.database = "MyDatabase" + mock_fg.offline_store_config.data_catalog_config.table_name = "MyTable" + mock_fg.offline_store_config.data_catalog_config.disable_glue_table_creation = False + mock_fg_class.get.return_value = mock_fg + + session = Mock() + query = create_athena_query("my-fg", session) + + assert query.catalog == "AwsDataCatalog" # disable_glue=False uses default + assert query.database == "MyDatabase" + assert query.table_name == "MyTable" + + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_raises_when_no_metastore(self, mock_fg_class): + mock_fg = MagicMock() + mock_fg.offline_store_config = None + mock_fg_class.get.return_value = mock_fg + + session = Mock() + with pytest.raises(RuntimeError, match="No metastore"): + create_athena_query("my-fg", session) + + +class TestIngestDataframe: + @patch("sagemaker.mlops.feature_store.feature_utils.IngestionManagerPandas") + @patch("sagemaker.mlops.feature_store.feature_utils.CoreFeatureGroup") + def test_creates_manager_and_runs(self, mock_fg_class, mock_manager_class): + mock_fg = MagicMock() + mock_fg.feature_definitions = [ + MagicMock(feature_name="id", feature_type="Integral"), + ] + mock_fg_class.get.return_value = mock_fg + + mock_manager = MagicMock() + mock_manager_class.return_value = mock_manager + + df = pd.DataFrame({"id": [1, 2, 3]}) + result = ingest_dataframe("my-fg", df, max_workers=2, max_processes=1) + + mock_manager_class.assert_called_once() + mock_manager.run.assert_called_once() + assert result == mock_manager + + def test_raises_on_invalid_max_workers(self): + df = pd.DataFrame({"id": [1, 2, 3]}) + with pytest.raises(ValueError, match="max_workers"): + ingest_dataframe("my-fg", df, max_workers=0) + + def test_raises_on_invalid_max_processes(self): + df = 
pd.DataFrame({"id": [1, 2, 3]}) + with pytest.raises(ValueError, match="max_processes"): + ingest_dataframe("my-fg", df, max_processes=-1) + + +class TestGetSessionFromRole: + @patch("sagemaker.mlops.feature_store.feature_utils.boto3") + @patch("sagemaker.mlops.feature_store.feature_utils.Session") + def test_creates_session_without_role(self, mock_session_class, mock_boto3): + mock_boto_session = MagicMock() + mock_boto3.Session.return_value = mock_boto_session + + get_session_from_role(region="us-west-2") + + mock_boto3.Session.assert_called_with(region_name="us-west-2") + mock_session_class.assert_called_once() + + @patch("sagemaker.mlops.feature_store.feature_utils.boto3") + @patch("sagemaker.mlops.feature_store.feature_utils.Session") + def test_assumes_role_when_provided(self, mock_session_class, mock_boto3): + mock_boto_session = MagicMock() + mock_sts = MagicMock() + mock_sts.assume_role.return_value = { + "Credentials": { + "AccessKeyId": "key", + "SecretAccessKey": "secret", + "SessionToken": "token", + } + } + mock_boto_session.client.return_value = mock_sts + mock_boto3.Session.return_value = mock_boto_session + + get_session_from_role(region="us-west-2", assume_role="arn:aws:iam::123:role/MyRole") + + mock_sts.assume_role.assert_called_once() + assert mock_boto3.Session.call_count == 2 # Initial + after assume diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_ingestion_manager_pandas.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_ingestion_manager_pandas.py new file mode 100644 index 0000000000..2ecf495967 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_ingestion_manager_pandas.py @@ -0,0 +1,256 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 +"""Unit tests for ingestion_manager_pandas.py""" +import pytest +from unittest.mock import Mock, patch, MagicMock +import pandas as pd +import numpy as np + +from sagemaker.mlops.feature_store.ingestion_manager_pandas import ( + IngestionManagerPandas, + IngestionError, +) + + +class TestIngestionError: + def test_stores_failed_rows(self): + error = IngestionError([1, 5, 10], "Some rows failed") + assert error.failed_rows == [1, 5, 10] + assert "Some rows failed" in str(error) + + +class TestIngestionManagerPandas: + @pytest.fixture + def feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "value": {"FeatureName": "value", "FeatureType": "Fractional"}, + "name": {"FeatureName": "name", "FeatureType": "String"}, + } + + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": [1, 2, 3], + "value": [1.1, 2.2, 3.3], + "name": ["a", "b", "c"], + }) + + @pytest.fixture + def manager(self, feature_definitions): + return IngestionManagerPandas( + feature_group_name="test-fg", + feature_definitions=feature_definitions, + max_workers=1, + max_processes=1, + ) + + def test_initialization(self, manager): + assert manager.feature_group_name == "test-fg" + assert manager.max_workers == 1 + assert manager.max_processes == 1 + assert manager.failed_rows == [] + + def test_failed_rows_property(self, manager): + manager._failed_indices = [1, 2, 3] + assert manager.failed_rows == [1, 2, 3] + + +class TestIngestionManagerHelpers: + def test_is_feature_collection_type_true(self): + feature_defs = { + "tags": {"FeatureName": "tags", "FeatureType": "String", "CollectionType": "List"}, + } + assert IngestionManagerPandas._is_feature_collection_type("tags", feature_defs) is True + + def test_is_feature_collection_type_false(self): + feature_defs = { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + } + assert IngestionManagerPandas._is_feature_collection_type("id", feature_defs) is False + + def test_is_feature_collection_type_missing(self): + feature_defs = {} + assert IngestionManagerPandas._is_feature_collection_type("unknown", feature_defs) is False + + def test_feature_value_is_not_none_scalar(self): + assert IngestionManagerPandas._feature_value_is_not_none(5) is True + assert IngestionManagerPandas._feature_value_is_not_none(None) is False + assert IngestionManagerPandas._feature_value_is_not_none(np.nan) is False + + def test_feature_value_is_not_none_list(self): + assert IngestionManagerPandas._feature_value_is_not_none([1, 2, 3]) is True + assert IngestionManagerPandas._feature_value_is_not_none([]) is True + assert IngestionManagerPandas._feature_value_is_not_none(None) is False + + def test_convert_to_string_list(self): + result = IngestionManagerPandas._convert_to_string_list([1, 2, 3]) + assert result == ["1", "2", "3"] + + def test_convert_to_string_list_with_none(self): + result = IngestionManagerPandas._convert_to_string_list([1, None, 3]) + assert result == ["1", None, "3"] + + def test_convert_to_string_list_raises_for_non_list(self): + with pytest.raises(ValueError, match="must be an Array"): + IngestionManagerPandas._convert_to_string_list("not a list") + + +class TestIngestionManagerRun: + @pytest.fixture + def feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "value": {"FeatureName": "value", "FeatureType": "Fractional"}, + } + + @pytest.fixture + def sample_dataframe(self): + return pd.DataFrame({ + "id": [1, 2, 3], 
+ "value": [1.1, 2.2, 3.3], + }) + + @patch.object(IngestionManagerPandas, "_run_single_process_single_thread") + def test_run_single_thread_mode(self, mock_single, feature_definitions, sample_dataframe): + manager = IngestionManagerPandas( + feature_group_name="test-fg", + feature_definitions=feature_definitions, + max_workers=1, + max_processes=1, + ) + + manager.run(sample_dataframe) + + mock_single.assert_called_once() + + @patch.object(IngestionManagerPandas, "_run_multi_process") + def test_run_multi_process_mode(self, mock_multi, feature_definitions, sample_dataframe): + manager = IngestionManagerPandas( + feature_group_name="test-fg", + feature_definitions=feature_definitions, + max_workers=2, + max_processes=2, + ) + + manager.run(sample_dataframe) + + mock_multi.assert_called_once() + + +class TestIngestionManagerIngestRow: + @pytest.fixture + def feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "name": {"FeatureName": "name", "FeatureType": "String"}, + } + + @pytest.fixture + def collection_feature_definitions(self): + return { + "id": {"FeatureName": "id", "FeatureType": "Integral"}, + "tags": {"FeatureName": "tags", "FeatureType": "String", "CollectionType": "List"}, + } + + def test_ingest_row_success(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": ["test"]}) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + mock_fg.put_record.assert_called_once() + assert len(failed_rows) == 0 + + def test_ingest_row_with_collection_type(self, collection_feature_definitions): + df = pd.DataFrame({ + "id": [1], + "tags": [["tag1", "tag2"]], + }) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=collection_feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + mock_fg.put_record.assert_called_once() + call_args = mock_fg.put_record.call_args + record = call_args[1]["record"] + + # Find the tags feature value + tags_value = next(v for v in record if v.feature_name == "tags") + assert tags_value.value_as_string_list == ["tag1", "tag2"] + + def test_ingest_row_failure_appends_to_failed(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": ["test"]}) + mock_fg = MagicMock() + mock_fg.put_record.side_effect = Exception("API Error") + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + assert len(failed_rows) == 1 + assert failed_rows[0] == 0 # Index of failed row + + def test_ingest_row_with_target_stores(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": ["test"]}) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=["OnlineStore"], + ) + + call_args = mock_fg.put_record.call_args + assert call_args[1]["target_stores"] == ["OnlineStore"] + + def test_ingest_row_skips_none_values(self, feature_definitions): + df = pd.DataFrame({"id": [1], "name": 
[None]}) + mock_fg = MagicMock() + failed_rows = [] + + for row in df.itertuples(): + IngestionManagerPandas._ingest_row( + data_frame=df, + row=row, + feature_group=mock_fg, + feature_definitions=feature_definitions, + failed_rows=failed_rows, + target_stores=None, + ) + + call_args = mock_fg.put_record.call_args + record = call_args[1]["record"] + # Only id should be in record, name is None + assert len(record) == 1 + assert record[0].feature_name == "id" diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_inputs.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_inputs.py new file mode 100644 index 0000000000..44e3ec6085 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_inputs.py @@ -0,0 +1,109 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""Unit tests for inputs.py (enums).""" +import pytest + +from sagemaker.mlops.feature_store.inputs import ( + TargetStoreEnum, + OnlineStoreStorageTypeEnum, + TableFormatEnum, + ResourceEnum, + SearchOperatorEnum, + SortOrderEnum, + FilterOperatorEnum, + DeletionModeEnum, + ExpirationTimeResponseEnum, + ThroughputModeEnum, +) + + +class TestTargetStoreEnum: + def test_online_store(self): + assert TargetStoreEnum.ONLINE_STORE.value == "OnlineStore" + + def test_offline_store(self): + assert TargetStoreEnum.OFFLINE_STORE.value == "OfflineStore" + + +class TestOnlineStoreStorageTypeEnum: + def test_standard(self): + assert OnlineStoreStorageTypeEnum.STANDARD.value == "Standard" + + def test_in_memory(self): + assert OnlineStoreStorageTypeEnum.IN_MEMORY.value == "InMemory" + + +class TestTableFormatEnum: + def test_glue(self): + assert TableFormatEnum.GLUE.value == "Glue" + + def test_iceberg(self): + assert TableFormatEnum.ICEBERG.value == "Iceberg" + + +class TestResourceEnum: + def test_feature_group(self): + assert ResourceEnum.FEATURE_GROUP.value == "FeatureGroup" + + def test_feature_metadata(self): + assert ResourceEnum.FEATURE_METADATA.value == "FeatureMetadata" + + +class TestSearchOperatorEnum: + def test_and(self): + assert SearchOperatorEnum.AND.value == "And" + + def test_or(self): + assert SearchOperatorEnum.OR.value == "Or" + + +class TestSortOrderEnum: + def test_ascending(self): + assert SortOrderEnum.ASCENDING.value == "Ascending" + + def test_descending(self): + assert SortOrderEnum.DESCENDING.value == "Descending" + + +class TestFilterOperatorEnum: + def test_equals(self): + assert FilterOperatorEnum.EQUALS.value == "Equals" + + def test_not_equals(self): + assert FilterOperatorEnum.NOT_EQUALS.value == "NotEquals" + + def test_greater_than(self): + assert FilterOperatorEnum.GREATER_THAN.value == "GreaterThan" + + def test_contains(self): + assert FilterOperatorEnum.CONTAINS.value == "Contains" + + def test_exists(self): + assert FilterOperatorEnum.EXISTS.value == "Exists" + + def test_in(self): + assert FilterOperatorEnum.IN.value == "In" + + +class TestDeletionModeEnum: + def test_soft_delete(self): + assert DeletionModeEnum.SOFT_DELETE.value == "SoftDelete" + + def test_hard_delete(self): + assert DeletionModeEnum.HARD_DELETE.value == "HardDelete" + + +class TestExpirationTimeResponseEnum: + def test_disabled(self): + assert ExpirationTimeResponseEnum.DISABLED.value == "Disabled" + + def test_enabled(self): + assert ExpirationTimeResponseEnum.ENABLED.value == "Enabled" + + +class TestThroughputModeEnum: + def test_on_demand(self): + assert ThroughputModeEnum.ON_DEMAND.value == 
"OnDemand" + + def test_provisioned(self): + assert ThroughputModeEnum.PROVISIONED.value == "Provisioned" From f38579455720e0a70a5b7d4dc215e8e095489b18 Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Mon, 26 Jan 2026 16:32:37 -0800 Subject: [PATCH 3/8] feat(feature_store): Add Lake Formation support to Feature Group - Add LakeFormationConfig class to configure Lake Formation governance on offline stores - Implement FeatureGroup subclass with Lake Formation integration capabilities - Add helper methods for S3 URI/ARN conversion and Lake Formation role management - Add S3 deny policy generation for Lake Formation access control - Implement Lake Formation resource registration and S3 bucket policy setup - Add integration tests for Lake Formation feature store workflows - Add unit tests for Lake Formation configuration and policy generation - Update feature_store module exports to include FeatureGroup and LakeFormationConfig - Update API documentation to include Feature Store section in sagemaker_mlops.rst - Enable fine-grained access control for feature store offline stores using AWS Lake Formation --- docs/api/sagemaker_mlops.rst | 8 + .../sagemaker/mlops/feature_store/__init__.py | 6 +- .../mlops/feature_store/feature_group.py | 711 ++++++ .../integ/test_featureStore_lakeformation.py | 660 +++++ .../mlops/feature_store/test_lakeformation.py | 2141 +++++++++++++++++ 5 files changed, 3525 insertions(+), 1 deletion(-) create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group.py create mode 100644 sagemaker-mlops/tests/integ/test_featureStore_lakeformation.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_lakeformation.py diff --git a/docs/api/sagemaker_mlops.rst b/docs/api/sagemaker_mlops.rst index f67879111d..d9f911068e 100644 --- a/docs/api/sagemaker_mlops.rst +++ b/docs/api/sagemaker_mlops.rst @@ -21,6 +21,14 @@ Workflow Management :undoc-members: :show-inheritance: +Feature Store +------------- + +.. automodule:: sagemaker.mlops.feature_store + :members: + :undoc-members: + :show-inheritance: + Local Development ----------------- diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py index f15d6d3845..1b635df2a7 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py @@ -2,8 +2,11 @@ # Licensed under the Apache License, Version 2.0 """SageMaker FeatureStore V3 - powered by sagemaker-core.""" +# FeatureGroup with Lake Formation support (local subclass) +from sagemaker.mlops.feature_store.feature_group import FeatureGroup, LakeFormationConfig + # Resources from core -from sagemaker.core.resources import FeatureGroup, FeatureMetadata +from sagemaker.core.resources import FeatureMetadata # Shapes from core (Pydantic - no to_dict() needed) from sagemaker.core.shapes import ( @@ -79,6 +82,7 @@ "FeatureParameter", "FeatureValue", "Filter", + "LakeFormationConfig", "OfflineStoreConfig", "OnlineStoreConfig", "OnlineStoreSecurityConfig", diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group.py new file mode 100644 index 0000000000..c5fcb9211a --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group.py @@ -0,0 +1,711 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 +"""FeatureGroup with Lake Formation support.""" + +import logging +from typing import List, Optional + +import botocore.exceptions + +from sagemaker.core.resources import FeatureGroup as CoreFeatureGroup +from sagemaker.core.resources import Base +from sagemaker.core.shapes import ( + FeatureDefinition, + OfflineStoreConfig, + OnlineStoreConfig, + Tag, + ThroughputConfig, +) +from sagemaker.core.shapes import Unassigned +from sagemaker.core.helper.pipeline_variable import StrPipeVar +from sagemaker.core.s3.utils import parse_s3_url +from sagemaker.core.common_utils import aws_partition +from boto3 import Session + + +logger = logging.getLogger(__name__) + + +class LakeFormationConfig: + """Configuration for Lake Formation governance on Feature Group offline stores. + + Attributes: + enabled: If True, enables Lake Formation governance for the offline store. + Requires offline_store_config and role_arn to be set on the Feature Group. + use_service_linked_role: Whether to use the Lake Formation service-linked role + for S3 registration. If True, Lake Formation uses its service-linked role. + If False, registration_role_arn must be provided. Default is True. + registration_role_arn: IAM role ARN to use for S3 registration with Lake Formation. + Required when use_service_linked_role is False. This can be different from the + Feature Group's execution role. + show_s3_policy: If True, prints the S3 deny policy to the console after successful + Lake Formation setup. This policy should be added to your S3 bucket to restrict + access to only the allowed principals. Default is False. + """ + + enabled: bool = False + use_service_linked_role: bool = True + registration_role_arn: Optional[str] = None + show_s3_policy: bool = False + + +class FeatureGroup(CoreFeatureGroup): + + # Inherit parent docstring and append our additions + if CoreFeatureGroup.__doc__ and __doc__: + __doc__ = CoreFeatureGroup.__doc__ + + @staticmethod + def _s3_uri_to_arn(s3_uri: str, region: Optional[str] = None) -> str: + """ + Convert S3 URI to S3 ARN format for Lake Formation. + + Args: + s3_uri: S3 URI in format s3://bucket/path or already an ARN + region: AWS region name (e.g., 'us-west-2'). Used to determine the correct + partition for the ARN. If not provided, defaults to 'aws' partition. + + Returns: + S3 ARN in format arn:{partition}:s3:::bucket/path + + Note: + This format is specifically used for Lake Formation resource registration. + The triple colon (:::) after 's3' is correct - S3 ARNs don't include + region or account ID fields. + """ + if s3_uri.startswith("arn:"): + return s3_uri + + # Determine partition based on region + partition = aws_partition(region) if region else "aws" + + bucket, key = parse_s3_url(s3_uri) + # Reconstruct as ARN - key may be empty string + s3_path = f"{bucket}/{key}" if key else bucket + return f"arn:{partition}:s3:::{s3_path}" + + @staticmethod + def _extract_account_id_from_arn(arn: str) -> str: + """ + Extract AWS account ID from an ARN. 
+ + Args: + arn: AWS ARN in format arn:aws:service:region:account:resource + + Returns: + AWS account ID (the 5th colon-separated field) + + Raises: + ValueError: If ARN format is invalid (fewer than 5 colon-separated parts) + """ + parts = arn.split(":") + if len(parts) < 5: + raise ValueError(f"Invalid ARN format: {arn}") + return parts[4] + + @staticmethod + def _get_lake_formation_service_linked_role_arn( + account_id: str, region: Optional[str] = None + ) -> str: + """ + Generate the Lake Formation service-linked role ARN for an account. + + Args: + account_id: AWS account ID + region: AWS region name (e.g., 'us-west-2'). Used to determine the correct + partition for the ARN. If not provided, defaults to 'aws' partition. + + Returns: + Lake Formation service-linked role ARN in format: + arn:{partition}:iam::{account}:role/aws-service-role/lakeformation.amazonaws.com/ + AWSServiceRoleForLakeFormationDataAccess + """ + partition = aws_partition(region) if region else "aws" + return ( + f"arn:{partition}:iam::{account_id}:role/aws-service-role/" + f"lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" + ) + + def _generate_s3_deny_policy( + self, + bucket_name: str, + s3_prefix: str, + lake_formation_role_arn: str, + feature_store_role_arn: str, + ) -> dict: + """ + Generate an S3 deny policy for Lake Formation governance. + + This policy denies S3 access to the offline store data prefix except for + the Lake Formation role and Feature Store execution role. + + Args: + bucket_name: S3 bucket name. + s3_prefix: S3 prefix path (without bucket name). + lake_formation_role_arn: Lake Formation registration role ARN. + feature_store_role_arn: Feature Store execution role ARN. + + Returns: + S3 bucket policy as a dict with valid JSON structure containing: + - Version: "2012-10-17" + - Statement: List with two deny statements: + 1. Deny GetObject, PutObject, DeleteObject on data prefix except allowed principals + 2. Deny ListBucket on bucket with prefix condition except allowed principals + """ + policy = { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "DenyAllAccessToFeatureStorePrefixExceptAllowedPrincipals", + "Effect": "Deny", + "Principal": "*", + "Action": ["s3:GetObject", "s3:PutObject", "s3:DeleteObject"], + "Resource": f"arn:aws:s3:::{bucket_name}/{s3_prefix}/*", + "Condition": { + "StringNotEquals": { + "aws:PrincipalArn": [ + lake_formation_role_arn, + feature_store_role_arn, + ] + } + }, + }, + { + "Sid": "DenyListOnPrefixExceptAllowedPrincipals", + "Effect": "Deny", + "Principal": "*", + "Action": "s3:ListBucket", + "Resource": f"arn:aws:s3:::{bucket_name}", + "Condition": { + "StringLike": {"s3:prefix": f"{s3_prefix}/*"}, + "StringNotEquals": { + "aws:PrincipalArn": [ + lake_formation_role_arn, + feature_store_role_arn, + ] + }, + }, + }, + ], + } + return policy + + def _get_lake_formation_client( + self, + session: Optional[Session] = None, + region: Optional[str] = None, + ): + """ + Get a Lake Formation client. + + Args: + session: Boto3 session. If not provided, a new session will be created. + region: AWS region name. + + Returns: + A boto3 Lake Formation client. 
+ """ + # TODO: don't create w new client for each call + boto_session = session or Session() + return boto_session.client("lakeformation", region_name=region) + + def _register_s3_with_lake_formation( + self, + s3_location: str, + session: Optional[Session] = None, + region: Optional[str] = None, + use_service_linked_role: bool = True, + role_arn: Optional[str] = None, + ) -> bool: + """ + Register an S3 location with Lake Formation. + + Args: + s3_location: S3 URI or ARN to register. + session: Boto3 session. + region: AWS region. If not provided, will be inferred from the session. + use_service_linked_role: Whether to use the Lake Formation service-linked role. + If True, Lake Formation uses its service-linked role for registration. + If False, role_arn must be provided. + role_arn: IAM role ARN to use for registration. Required when + use_service_linked_role is False. + + Returns: + True if registration succeeded or location already registered. + + Raises: + ValueError: If use_service_linked_role is False but role_arn is not provided. + ClientError: If registration fails for unexpected reasons. + """ + if not use_service_linked_role and not role_arn: + raise ValueError("role_arn must be provided when use_service_linked_role is False") + + # Get region from session if not provided + if region is None and session is not None: + region = session.region_name() + + client = self._get_lake_formation_client(session, region) + resource_arn = self._s3_uri_to_arn(s3_location, region) + + try: + register_params = {"ResourceArn": resource_arn} + + if use_service_linked_role: + register_params["UseServiceLinkedRole"] = True + else: + register_params["RoleArn"] = role_arn + + client.register_resource(**register_params) + logger.info(f"Successfully registered S3 location: {resource_arn}") + return True + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "AlreadyExistsException": + logger.info(f"S3 location already registered: {resource_arn}") + return True + raise + + def _revoke_iam_allowed_principal( + self, + database_name: str, + table_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> bool: + """ + Revoke IAMAllowedPrincipal permissions from a Glue table. + + Args: + database_name: Glue database name. + table_name: Glue table name. + session: Boto3 session. + region: AWS region. If not provided, will be inferred from the session. + + Returns: + True if revocation succeeded or permissions didn't exist. + + Raises: + ClientError: If revocation fails for unexpected reasons. 
+ """ + # Get region from session if not provided + if region is None and session is not None: + region = session.region_name() + + client = self._get_lake_formation_client(session, region) + + try: + client.revoke_permissions( + Principal={"DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS"}, + Resource={ + "Table": { + "DatabaseName": database_name, + "Name": table_name, + } + }, + Permissions=["ALL"], + ) + logger.info(f"Revoked IAMAllowedPrincipal from table: {database_name}.{table_name}") + return True + except botocore.exceptions.ClientError as e: + # if the Table doesn't have that permission because the user already revoked it + # then just return True + if e.response["Error"]["Code"] == "InvalidInputException": + logger.info( + f"IAMAllowedPrincipal permissions may not exist on: {database_name}.{table_name}" + ) + return True + raise + + def _grant_lake_formation_permissions( + self, + role_arn: str, + database_name: str, + table_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> bool: + """ + Grant permissions to a role on a Glue table via Lake Formation. + + Args: + role_arn: IAM role ARN to grant permissions to. + database_name: Glue database name. + table_name: Glue table name. + session: Boto3 session. + region: AWS region. If not provided, will be inferred from the session. + + Returns: + True if grant succeeded or permissions already exist. + + Raises: + ClientError: If grant fails for unexpected reasons. + """ + # Get region from session if not provided + if region is None and session is not None: + region = session.region_name() + + client = self._get_lake_formation_client(session, region) + permissions = ["SELECT", "INSERT", "DELETE", "DESCRIBE", "ALTER"] + + try: + client.grant_permissions( + Principal={"DataLakePrincipalIdentifier": role_arn}, + Resource={ + "Table": { + "DatabaseName": database_name, + "Name": table_name, + } + }, + Permissions=permissions, + PermissionsWithGrantOption=[], + ) + logger.info(f"Granted permissions to {role_arn} on table: {database_name}.{table_name}") + return True + except botocore.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "InvalidInputException": + logger.info( + f"Permissions may already exist for {role_arn} on: {database_name}.{table_name}" + ) + return True + raise + + @Base.add_validate_call + def enable_lake_formation( + self, + session: Optional[Session] = None, + region: Optional[str] = None, + use_service_linked_role: bool = True, + registration_role_arn: Optional[str] = None, + wait_for_active: bool = False, + show_s3_policy: bool = False, + ) -> dict: + """ + Enable Lake Formation governance for this Feature Group's offline store. + + This method: + 1. Optionally waits for Feature Group to reach 'Created' status + 2. Validates Feature Group status is 'Created' + 3. Registers the offline store S3 location as data lake location + 4. Grants the execution role permissions on the Glue table + 5. Revokes IAMAllowedPrincipal permissions from the Glue table + + The role ARN is automatically extracted from the Feature Group's configuration. + Each phase depends on the success of the previous phase - if any phase fails, + subsequent phases are not executed. + + Parameters: + session: Boto3 session. + region: Region name. + use_service_linked_role: Whether to use the Lake Formation service-linked role + for S3 registration. If True, Lake Formation uses its service-linked role. + If False, registration_role_arn must be provided. Default is True. 
+            registration_role_arn: IAM role ARN to use for S3 registration with Lake Formation.
+                Required when use_service_linked_role is False. This can be different from the
+                Feature Group's execution role (role_arn).
+            wait_for_active: If True, waits for the Feature Group to reach 'Created' status
+                before enabling Lake Formation. Default is False.
+            show_s3_policy: If True, prints the S3 deny policy to the console after successful
+                Lake Formation setup. This policy should be added to your S3 bucket to restrict
+                access to only the allowed principals. Default is False.
+
+        Returns:
+            Dict with status of each Lake Formation operation:
+            - s3_registration: bool
+            - iam_principal_revoked: bool
+            - permissions_granted: bool
+
+        Raises:
+            ValueError: If the Feature Group has no offline store configured,
+                if role_arn is not set on the Feature Group, if use_service_linked_role
+                is False but registration_role_arn is not provided, or if the Feature Group
+                is not in 'Created' status.
+            ClientError: If Lake Formation operations fail.
+            RuntimeError: If a phase fails and subsequent phases cannot proceed.
+        """
+        # Get region from session if not provided
+        if region is None and session is not None:
+            region = session.region_name
+
+        # Wait for Created status if requested
+        if wait_for_active:
+            self.wait_for_status(target_status="Created")
+
+        # Refresh to get latest state
+        self.refresh()
+
+        # Validate Feature Group status
+        if self.feature_group_status not in ["Created"]:
+            raise ValueError(
+                f"Feature Group '{self.feature_group_name}' must be in 'Created' status "
+                f"to enable Lake Formation. Current status: '{self.feature_group_status}'. "
+                f"Use wait_for_active=True to automatically wait for the Feature Group to be ready."
+            )
+
+        # Validate offline store exists
+        if self.offline_store_config is None or self.offline_store_config == Unassigned():
+            raise ValueError(
+                f"Feature Group '{self.feature_group_name}' does not have an offline store configured. "
+                "Lake Formation can only be enabled for Feature Groups with offline stores."
+            )
+
+        # Get role ARN from Feature Group config
+        if self.role_arn is None or self.role_arn == Unassigned():
+            raise ValueError(
+                f"Feature Group '{self.feature_group_name}' does not have a role_arn configured. "
+                "Lake Formation requires a role ARN to grant permissions."
+            )
+        if not use_service_linked_role and registration_role_arn is None:
+            raise ValueError(
+                "Either 'use_service_linked_role' must be True or 'registration_role_arn' must be provided."
+            )
+
+        # Extract required configuration
+        s3_config = self.offline_store_config.s3_storage_config
+        if s3_config is None:
+            raise ValueError("Offline store S3 configuration is missing")
+
+        resolved_s3_uri = s3_config.resolved_output_s3_uri
+        if resolved_s3_uri is None or resolved_s3_uri == Unassigned():
+            raise ValueError(
+                "Resolved S3 URI not available. Ensure the Feature Group is in 'Created' status."
+ ) + + data_catalog_config = self.offline_store_config.data_catalog_config + if data_catalog_config is None: + raise ValueError("Data catalog configuration is missing from offline store config") + + database_name = data_catalog_config.database + table_name = data_catalog_config.table_name + + if not database_name or not table_name: + raise ValueError("Database name and table name are required from data catalog config") + + # Convert to str to handle PipelineVariable types + resolved_s3_uri_str = str(resolved_s3_uri) + database_name_str = str(database_name) + table_name_str = str(table_name) + role_arn_str = str(self.role_arn) + + # Execute Lake Formation setup with fail-fast behavior + results = { + "s3_registration": False, + "iam_principal_revoked": False, + "permissions_granted": False, + } + + # Phase 1: Register S3 with Lake Formation + try: + results["s3_registration"] = self._register_s3_with_lake_formation( + resolved_s3_uri_str, + session, + region, + use_service_linked_role=use_service_linked_role, + role_arn=registration_role_arn, + ) + except Exception as e: + raise RuntimeError( + f"Failed to register S3 location with Lake Formation. " + f"Subsequent phases skipped. Results: {results}. Error: {e}" + ) from e + + if not results["s3_registration"]: + raise RuntimeError( + f"Failed to register S3 location with Lake Formation. " + f"Subsequent phases skipped. Results: {results}" + ) + + # Phase 2: Grant Lake Formation permissions to the role + try: + results["permissions_granted"] = self._grant_lake_formation_permissions( + role_arn_str, database_name_str, table_name_str, session, region + ) + except Exception as e: + raise RuntimeError( + f"Failed to grant Lake Formation permissions. " + f"Subsequent phases skipped. Results: {results}. Error: {e}" + ) from e + + if not results["permissions_granted"]: + raise RuntimeError( + f"Failed to grant Lake Formation permissions. " + f"Subsequent phases skipped. Results: {results}" + ) + + # Phase 3: Revoke IAMAllowedPrincipal permissions + try: + results["iam_principal_revoked"] = self._revoke_iam_allowed_principal( + database_name_str, table_name_str, session, region + ) + except Exception as e: + raise RuntimeError( + f"Failed to revoke IAMAllowedPrincipal permissions. Results: {results}. Error: {e}" + ) from e + + if not results["iam_principal_revoked"]: + raise RuntimeError( + f"Failed to revoke IAMAllowedPrincipal permissions. 
Results: {results}" + ) + + logger.info(f"Lake Formation setup complete for {self.feature_group_name}: {results}") + + # Generate and optionally print S3 deny policy + if show_s3_policy: + # Extract bucket name and prefix from resolved S3 URI using core utility + bucket_name, s3_prefix = parse_s3_url(resolved_s3_uri_str) + + # Extract account ID from Feature Group ARN + feature_group_arn_str = str(self.feature_group_arn) if self.feature_group_arn else "" + account_id = self._extract_account_id_from_arn(feature_group_arn_str) + + # Determine Lake Formation role ARN based on use_service_linked_role flag + if use_service_linked_role: + lf_role_arn = self._get_lake_formation_service_linked_role_arn(account_id, region) + else: + # registration_role_arn is validated earlier when use_service_linked_role is False + lf_role_arn = str(registration_role_arn) + + # Generate the S3 deny policy + policy = self._generate_s3_deny_policy( + bucket_name=bucket_name, + s3_prefix=s3_prefix, + lake_formation_role_arn=lf_role_arn, + feature_store_role_arn=role_arn_str, + ) + + # Print policy with clear instructions + import json + + print("\n" + "=" * 80) + print("S3 Bucket Policy Update recommended") + print("=" * 80) + print( + "\nTo complete Lake Formation setup, add the following deny policy to your S3 bucket." + ) + print( + "This policy restricts access to the offline store data to only the allowed principals." + ) + print("\nBucket:", bucket_name) + print("\nPolicy to add:") + print("-" * 40) + print(json.dumps(policy, indent=2)) + print("-" * 40) + print("\nNote: Merge this with your existing bucket policy if one exists.") + print("=" * 80 + "\n") + + return results + + @classmethod + @Base.add_validate_call + def create( + cls, + feature_group_name: StrPipeVar, + record_identifier_feature_name: StrPipeVar, + event_time_feature_name: StrPipeVar, + feature_definitions: List[FeatureDefinition], + online_store_config: Optional[OnlineStoreConfig] = None, + offline_store_config: Optional[OfflineStoreConfig] = None, + throughput_config: Optional[ThroughputConfig] = None, + role_arn: Optional[StrPipeVar] = None, + description: Optional[StrPipeVar] = None, + tags: Optional[List[Tag]] = None, + use_pre_prod_offline_store_replicator_lambda: Optional[bool] = None, + lake_formation_config: Optional[LakeFormationConfig] = None, + session: Optional[Session] = None, + region: Optional[StrPipeVar] = None, + ) -> Optional["FeatureGroup"]: + """ + Create a FeatureGroup resource with optional Lake Formation governance. + + Parameters: + feature_group_name: The name of the FeatureGroup. + record_identifier_feature_name: The name of the Feature whose value uniquely + identifies a Record. + event_time_feature_name: The name of the feature that stores the EventTime. + feature_definitions: A list of Feature names and types. + online_store_config: Configuration for the OnlineStore. + offline_store_config: Configuration for the OfflineStore. + throughput_config: Throughput configuration. + role_arn: IAM execution role ARN for the OfflineStore. + description: A free-form description of the FeatureGroup. + tags: Tags used to identify Features in each FeatureGroup. + use_pre_prod_offline_store_replicator_lambda: Pre-prod replicator flag. + lake_formation_config: Optional LakeFormationConfig to configure Lake Formation + governance. When enabled=True, requires offline_store_config and role_arn. + session: Boto3 session. + region: Region name. + + Returns: + The FeatureGroup resource. 
+ + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. + For example, you might have too many training jobs created. + ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + """ + # Validation for Lake Formation + if lake_formation_config is not None and lake_formation_config.enabled: + if offline_store_config is None: + raise ValueError( + "lake_formation_config with enabled=True requires offline_store_config to be configured" + ) + if role_arn is None: + raise ValueError( + "lake_formation_config with enabled=True requires role_arn to be specified" + ) + if ( + not lake_formation_config.use_service_linked_role + and not lake_formation_config.registration_role_arn + ): + raise ValueError( + "registration_role_arn must be provided in lake_formation_config " + "when use_service_linked_role is False" + ) + + # Build kwargs, only including non-None values so parent uses its defaults + create_kwargs = { + "feature_group_name": feature_group_name, + "record_identifier_feature_name": record_identifier_feature_name, + "event_time_feature_name": event_time_feature_name, + "feature_definitions": feature_definitions, + "session": session, + "region": region, + } + if online_store_config is not None: + create_kwargs["online_store_config"] = online_store_config + if offline_store_config is not None: + create_kwargs["offline_store_config"] = offline_store_config + if throughput_config is not None: + create_kwargs["throughput_config"] = throughput_config + if role_arn is not None: + create_kwargs["role_arn"] = role_arn + if description is not None: + create_kwargs["description"] = description + if tags is not None: + create_kwargs["tags"] = tags + if use_pre_prod_offline_store_replicator_lambda is not None: + create_kwargs["use_pre_prod_offline_store_replicator_lambda"] = use_pre_prod_offline_store_replicator_lambda + + feature_group = super().create(**create_kwargs) + + # Enable Lake Formation if requested + if lake_formation_config is not None and lake_formation_config.enabled: + feature_group.wait_for_status(target_status="Created") + feature_group.enable_lake_formation( + session=session, + region=region, + use_service_linked_role=lake_formation_config.use_service_linked_role, + registration_role_arn=lake_formation_config.registration_role_arn, + show_s3_policy=lake_formation_config.show_s3_policy, + ) + return feature_group diff --git a/sagemaker-mlops/tests/integ/test_featureStore_lakeformation.py b/sagemaker-mlops/tests/integ/test_featureStore_lakeformation.py new file mode 100644 index 0000000000..dc6f12181a --- /dev/null +++ b/sagemaker-mlops/tests/integ/test_featureStore_lakeformation.py @@ -0,0 +1,660 @@ +""" +Integration tests for Lake Formation with FeatureGroup. 
+ +These tests require: +- AWS credentials with Lake Formation and SageMaker permissions +- An S3 bucket for offline store (uses default SageMaker bucket) +- An IAM role for Feature Store (uses execution role) + +Run with: pytest tests/integ/test_featureStore_lakeformation.py -v -m integ +""" + +import uuid + +import boto3 +import pytest +from botocore.exceptions import ClientError + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.mlops.feature_store import ( + FeatureGroup, + LakeFormationConfig, + OfflineStoreConfig, + OnlineStoreConfig, + S3StorageConfig, + StringFeatureDefinition, + FractionalFeatureDefinition, +) + +feature_definitions = [ + StringFeatureDefinition(feature_name="record_id"), + StringFeatureDefinition(feature_name="event_time"), + FractionalFeatureDefinition(feature_name="feature_value"), +] + + +@pytest.fixture(scope="module") +def sagemaker_session(): + return Session() + + +@pytest.fixture(scope="module") +def role(sagemaker_session): + return get_execution_role(sagemaker_session) + + +@pytest.fixture(scope="module") +def s3_uri(sagemaker_session): + bucket = sagemaker_session.default_bucket() + return f"s3://{bucket}/feature-store-test" + + +@pytest.fixture(scope="module") +def region(): + return "us-west-2" + + +@pytest.fixture(scope="module") +def shared_feature_group_for_negative_tests(s3_uri, role, region): + """ + Create a single FeatureGroup for negative tests that only need to verify + error conditions without modifying the resource. + + This fixture is module-scoped to be created once and shared across tests, + reducing test execution time. + """ + fg_name = f"test-lf-negative-{uuid.uuid4().hex[:8]}" + fg = None + + try: + fg = create_test_feature_group(fg_name, s3_uri, role, region) + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + yield fg + finally: + if fg: + cleanup_feature_group(fg) + + +def generate_feature_group_name(): + """Generate a unique feature group name for testing.""" + return f"test-lf-fg-{uuid.uuid4().hex[:8]}" + + +def create_test_feature_group(name: str, s3_uri: str, role_arn: str, region: str) -> FeatureGroup: + """Create a FeatureGroup with offline store for testing.""" + + offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) + + fg = FeatureGroup.create( + feature_group_name=name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + role_arn=role_arn, + region=region, + ) + + return fg + + +def cleanup_feature_group(fg: FeatureGroup): + """ + Delete a FeatureGroup and its associated Glue table. + + Args: + fg: The FeatureGroup to delete. 
+ """ + try: + # Delete the Glue table if it exists + if fg.offline_store_config is not None: + try: + fg.refresh() # Ensure we have latest config + data_catalog_config = fg.offline_store_config.data_catalog_config + if data_catalog_config is not None: + database_name = data_catalog_config.database + table_name = data_catalog_config.table_name + + if database_name and table_name: + glue_client = boto3.client("glue") + try: + glue_client.delete_table(DatabaseName=database_name, Name=table_name) + except ClientError as e: + # Ignore if table doesn't exist + if e.response["Error"]["Code"] != "EntityNotFoundException": + raise + except Exception: + # Don't fail cleanup if Glue table deletion fails + pass + + # Delete the FeatureGroup + fg.delete() + except ClientError: + # Don't fail cleanup if Glue table deletion fails + pass + + +@pytest.mark.serial +@pytest.mark.slow_test +def test_create_feature_group_and_enable_lake_formation(s3_uri, role, region): + """ + Test creating a FeatureGroup and enabling Lake Formation governance. + + This test: + 1. Creates a new FeatureGroup with offline store + 2. Waits for it to reach Created status + 3. Enables Lake Formation governance (registers S3, grants permissions, revokes IAM principals) + 4. Cleans up the FeatureGroup + """ + + fg_name = generate_feature_group_name() + fg = None + + try: + # Create the FeatureGroup + fg = create_test_feature_group(fg_name, s3_uri, role, region) + assert fg is not None + + # Wait for Created status + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + assert fg.feature_group_status == "Created" + + # Enable Lake Formation governance + result = fg.enable_lake_formation() + + # Verify all phases completed successfully + assert result["s3_registration"] is True + assert result["permissions_granted"] is True + assert result["iam_principal_revoked"] is True + + finally: + print('done') + # Cleanup + if fg: + cleanup_feature_group(fg) + + +@pytest.mark.serial +@pytest.mark.slow_test +def test_create_feature_group_with_lake_formation_enabled(s3_uri, role, region): + """ + Test creating a FeatureGroup with lake_formation_config.enabled=True. + + This test verifies the integrated workflow where Lake Formation is enabled + automatically during FeatureGroup creation: + 1. Creates a new FeatureGroup with lake_formation_config.enabled=True + 2. Verifies the FeatureGroup is created and Lake Formation is configured + 3. 
Cleans up the FeatureGroup + """ + + fg_name = generate_feature_group_name() + fg = None + + try: + # Create the FeatureGroup with Lake Formation enabled + + offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) + lake_formation_config = LakeFormationConfig() + lake_formation_config.enabled = True + + fg = FeatureGroup.create( + feature_group_name=fg_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + role_arn=role, + lake_formation_config=lake_formation_config, + ) + + # Verify the FeatureGroup was created + assert fg is not None + assert fg.feature_group_name == fg_name + assert fg.feature_group_status == "Created" + + # Verify Lake Formation is configured by checking we can refresh without errors + fg.refresh() + assert fg.offline_store_config is not None + + finally: + # Cleanup + if fg: + cleanup_feature_group(fg) + + +@pytest.mark.serial +def test_create_feature_group_without_lake_formation(s3_uri, role, region): + """ + Test creating a FeatureGroup without Lake Formation enabled. + + This test verifies that when lake_formation_config is not provided or enabled=False, + the FeatureGroup is created successfully without any Lake Formation operations: + 1. Creates a new FeatureGroup without lake_formation_config + 2. Verifies the FeatureGroup is created successfully + 3. Verifies no Lake Formation operations were performed + 4. Cleans up the FeatureGroup + """ + fg_name = generate_feature_group_name() + fg = None + + try: + # Create the FeatureGroup without Lake Formation + offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) + + # Create without lake_formation_config (default behavior) + fg = FeatureGroup.create( + feature_group_name=fg_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + role_arn=role, + ) + + # Verify the FeatureGroup was created + assert fg is not None + assert fg.feature_group_name == fg_name + + # Wait for Created status to ensure it's fully provisioned + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + assert fg.feature_group_status == "Created" + + # Verify offline store is configured + fg.refresh() + assert fg.offline_store_config is not None + assert fg.offline_store_config.s3_storage_config is not None + + finally: + # Cleanup + if fg: + cleanup_feature_group(fg) + + +# ============================================================================ +# Negative Integration Tests +# ============================================================================ + + +def test_create_feature_group_with_lake_formation_fails_without_offline_store(role, region): + """ + Test that creating a FeatureGroup with enable_lake_formation=True fails + when no offline store is configured. + + Expected behavior: ValueError should be raised indicating offline store is required. 
+ """ + fg_name = generate_feature_group_name() + + lake_formation_config = LakeFormationConfig() + lake_formation_config.enabled = True + + # Attempt to create without offline store but with Lake Formation enabled + with pytest.raises(ValueError) as exc_info: + FeatureGroup.create( + feature_group_name=fg_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + role_arn=role, + lake_formation_config=lake_formation_config, + ) + + # Verify error message mentions offline_store_config requirement + assert "lake_formation_config with enabled=True requires offline_store_config to be configured" in str( + exc_info.value + ) + + +def test_create_feature_group_with_lake_formation_fails_without_role(s3_uri, region): + """ + Test that creating a FeatureGroup with lake_formation_config.enabled=True fails + when no role_arn is provided. + + Expected behavior: ValueError should be raised indicating role_arn is required. + """ + fg_name = generate_feature_group_name() + + offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) + lake_formation_config = LakeFormationConfig() + lake_formation_config.enabled = True + + # Attempt to create without role_arn but with Lake Formation enabled + with pytest.raises(ValueError) as exc_info: + FeatureGroup.create( + feature_group_name=fg_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + lake_formation_config=lake_formation_config, + ) + + # Verify error message mentions role_arn requirement + assert "lake_formation_config with enabled=True requires role_arn to be specified" in str(exc_info.value) + + +def test_enable_lake_formation_fails_for_non_created_status(s3_uri, role, region): + """ + Test that enable_lake_formation() fails when called on a FeatureGroup + that is not in 'Created' status. + + Expected behavior: ValueError should be raised indicating the Feature Group + must be in 'Created' status. + + Note: This test creates its own FeatureGroup because it needs to test + behavior during the 'Creating' status, which requires a fresh resource. + """ + fg_name = generate_feature_group_name() + fg = None + + try: + # Create the FeatureGroup + fg = create_test_feature_group(fg_name, s3_uri, role, region) + assert fg is not None + + # Immediately try to enable Lake Formation without waiting for Created status + # The Feature Group will be in 'Creating' status + with pytest.raises(ValueError) as exc_info: + fg.enable_lake_formation(wait_for_active=False) + + # Verify error message mentions status requirement + error_msg = str(exc_info.value) + assert "must be in 'Created' status to enable Lake Formation" in error_msg + + finally: + # Cleanup + if fg: + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + cleanup_feature_group(fg) + + +def test_enable_lake_formation_without_offline_store(role, region): + """ + Test that enable_lake_formation() fails when called on a FeatureGroup + without an offline store configured. + + Expected behavior: ValueError should be raised indicating offline store is required. + + Note: This test creates a FeatureGroup with only online store, which is a valid + configuration, but Lake Formation cannot be enabled for it. 
+ """ + fg_name = generate_feature_group_name() + fg = None + + try: + # Create a FeatureGroup with only online store (no offline store) + online_store_config = OnlineStoreConfig(enable_online_store=True) + + fg = FeatureGroup.create( + feature_group_name=fg_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + online_store_config=online_store_config, + role_arn=role, + ) + + # Wait for Created status + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + + # Attempt to enable Lake Formation + with pytest.raises(ValueError) as exc_info: + fg.enable_lake_formation() + # Verify error message mentions offline store requirement + assert "does not have an offline store configured" in str(exc_info.value) + + finally: + # Cleanup + if fg: + cleanup_feature_group(fg) + + +def test_enable_lake_formation_fails_with_invalid_registration_role( + shared_feature_group_for_negative_tests, +): + """ + Test that enable_lake_formation() fails when use_service_linked_role=False + but no registration_role_arn is provided. + + Expected behavior: ValueError should be raised indicating registration_role_arn + is required when not using service-linked role. + """ + fg = shared_feature_group_for_negative_tests + + # Attempt to enable Lake Formation without service-linked role and without registration_role_arn + with pytest.raises(ValueError) as exc_info: + fg.enable_lake_formation( + use_service_linked_role=False, + registration_role_arn=None, + ) + + # Verify error message mentions role requirement + error_msg = str(exc_info.value) + assert "registration_role_arn" in error_msg + + +def test_enable_lake_formation_fails_with_nonexistent_role( + shared_feature_group_for_negative_tests, role +): + """ + Test that enable_lake_formation() properly bubbles errors when using + a nonexistent role ARN for Lake Formation registration. + + Expected behavior: RuntimeError or ClientError should be raised with details + about the registration failure. + + Note: This test uses a nonexistent role ARN (current role with random suffix) + to trigger an error during S3 registration with Lake Formation. + """ + fg = shared_feature_group_for_negative_tests + + # Generate a nonexistent role ARN by appending a random string to the current role + nonexistent_role = f"{role}-nonexistent-{uuid.uuid4().hex[:8]}" + + with pytest.raises(RuntimeError) as exc_info: + fg.enable_lake_formation( + use_service_linked_role=False, + registration_role_arn=nonexistent_role, + ) + + # Verify we got an appropriate error + error_msg = str(exc_info.value) + print(exc_info) + # Should mention role-related issues (not found, invalid, access denied, etc.) + assert "EntityNotFoundException" in error_msg + + +# ============================================================================ +# Full Flow Integration Tests with Policy Output +# ============================================================================ + + +@pytest.mark.serial +@pytest.mark.slow_test +def test_enable_lake_formation_full_flow_with_policy_output(s3_uri, role, region, capsys): + """ + Test the full Lake Formation flow with S3 deny policy output. + + This test verifies: + 1. Creates a FeatureGroup with offline store + 2. Enables Lake Formation with show_s3_policy=True + 3. Verifies all Lake Formation phases complete successfully + 4. Verifies the S3 deny policy is printed to the console + 5. 
Verifies the policy structure contains expected elements + + This validates Requirements 6.1-6.9 from the design document. + """ + fg_name = generate_feature_group_name() + fg = None + + try: + # Create the FeatureGroup + fg = create_test_feature_group(fg_name, s3_uri, role, region) + assert fg is not None + + # Wait for Created status + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + assert fg.feature_group_status == "Created" + + # Enable Lake Formation governance with policy output + result = fg.enable_lake_formation(show_s3_policy=True) + + # Verify all phases completed successfully + assert result["s3_registration"] is True + assert result["permissions_granted"] is True + assert result["iam_principal_revoked"] is True + + # Capture the printed output + captured = capsys.readouterr() + output = captured.out + + # Re-print the output so it's visible in terminal with -s flag + print(output) + + # Verify the policy header is printed + assert "S3 Bucket Policy Update recommended" in output + assert "=" * 80 in output + + # Verify bucket information is printed + # Extract bucket name from s3_uri (s3://bucket/path -> bucket) + expected_bucket = s3_uri.replace("s3://", "").split("/")[0] + assert f"Bucket: {expected_bucket}" in output + + # Verify policy structure elements are present + assert '"Version": "2012-10-17"' in output + assert '"Statement"' in output + assert '"Effect": "Deny"' in output + assert '"Principal": "*"' in output + + # Verify the deny actions are present + assert "s3:GetObject" in output + assert "s3:PutObject" in output + assert "s3:DeleteObject" in output + assert "s3:ListBucket" in output + + # Verify the condition structure is present + assert "StringNotEquals" in output + assert "aws:PrincipalArn" in output + + # Verify the role ARN is in the allowed principals + assert role in output + + # Verify the service-linked role pattern is present (default use_service_linked_role=True) + assert "aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" in output + + # Verify instructions are printed + assert "Merge this with your existing bucket policy" in output + + finally: + # Cleanup + if fg: + cleanup_feature_group(fg) + + +@pytest.mark.serial +@pytest.mark.slow_test +def test_enable_lake_formation_no_policy_output_by_default(s3_uri, role, region, capsys): + """ + Test that S3 deny policy is NOT printed when show_s3_policy=False (default). + + This test verifies: + 1. Creates a FeatureGroup with offline store + 2. Enables Lake Formation without show_s3_policy (defaults to False) + 3. Verifies all Lake Formation phases complete successfully + 4. Verifies the S3 deny policy is NOT printed to the console + + This validates Requirement 6.2 from the design document. 
+ """ + fg_name = generate_feature_group_name() + fg = None + + try: + # Create the FeatureGroup + fg = create_test_feature_group(fg_name, s3_uri, role, region) + assert fg is not None + + # Wait for Created status + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + assert fg.feature_group_status == "Created" + + # Enable Lake Formation governance WITHOUT policy output (default) + result = fg.enable_lake_formation() + + # Verify all phases completed successfully + assert result["s3_registration"] is True + assert result["permissions_granted"] is True + assert result["iam_principal_revoked"] is True + + # Capture the printed output + captured = capsys.readouterr() + output = captured.out + + # Verify the policy is NOT printed + assert "S3 Bucket Policy Update recommended" not in output + assert '"Version": "2012-10-17"' not in output + assert "s3:GetObject" not in output + + finally: + # Cleanup + if fg: + cleanup_feature_group(fg) + + +@pytest.mark.serial +@pytest.mark.slow_test +def test_enable_lake_formation_with_custom_role_policy_output(s3_uri, role, region, capsys): + """ + Test the full Lake Formation flow with custom registration role and policy output. + + This test verifies: + 1. Creates a FeatureGroup with offline store + 2. Enables Lake Formation with use_service_linked_role=False and a custom registration_role_arn + 3. Verifies the S3 deny policy uses the custom role ARN instead of service-linked role + + This validates Requirements 6.4, 6.5 from the design document. + + Note: This test uses the same execution role as the registration role for simplicity. + In production, these would typically be different roles. + """ + fg_name = generate_feature_group_name() + fg = None + + try: + # Create the FeatureGroup + fg = create_test_feature_group(fg_name, s3_uri, role, region) + assert fg is not None + + # Wait for Created status + fg.wait_for_status(target_status="Created", poll=30, timeout=300) + assert fg.feature_group_status == "Created" + + # Enable Lake Formation with custom registration role and policy output + # Using the same role for both execution and registration for test simplicity + result = fg.enable_lake_formation( + use_service_linked_role=False, + registration_role_arn=role, + show_s3_policy=True, + ) + + # Verify all phases completed successfully + assert result["s3_registration"] is True + assert result["permissions_granted"] is True + assert result["iam_principal_revoked"] is True + + # Capture the printed output + captured = capsys.readouterr() + output = captured.out + + # Verify the policy header is printed + assert "S3 Bucket Policy Update recommended" in output + + # Verify the custom role ARN is used in the policy (appears twice - once for each principal) + # The role should appear as both the Lake Formation role and the Feature Store role + assert output.count(role) >= 2 + + # Verify the service-linked role is NOT used + assert "aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" not in output + + finally: + # Cleanup + if fg: + cleanup_feature_group(fg) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_lakeformation.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_lakeformation.py new file mode 100644 index 0000000000..e4d44df37a --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_lakeformation.py @@ -0,0 +1,2141 @@ +"""Unit tests for Lake Formation integration with FeatureGroup.""" +from unittest.mock import MagicMock, patch + 
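+# The tests in this module assume the Lake Formation flow asserted below (a summary
+# of the test expectations, not an authoritative description of the implementation):
+#   Phase 1: _register_s3_with_lake_formation  - register the offline store S3 ARN
+#            (AlreadyExistsException is treated as success)
+#   Phase 2: _grant_lake_formation_permissions - grant SELECT/INSERT/DELETE/DESCRIBE/ALTER
+#            on the Glue table to the Feature Store execution role
+#            (InvalidInputException is treated as success)
+#   Phase 3: _revoke_iam_allowed_principal     - revoke ALL from IAM_ALLOWED_PRINCIPALS
+#            (InvalidInputException is treated as success)
+# enable_lake_formation() runs these phases in order and fails fast with a RuntimeError
+# naming the failed phase; any other botocore ClientError is propagated unchanged.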
+import botocore.exceptions +import pytest + +from sagemaker.mlops.feature_store import FeatureGroup, LakeFormationConfig + + +class TestS3UriToArn: + """Tests for _s3_uri_to_arn static method.""" + + def test_converts_s3_uri_to_arn(self): + """Test S3 URI is converted to ARN format.""" + uri = "s3://my-bucket/my-prefix/data" + result = FeatureGroup._s3_uri_to_arn(uri) + assert result == "arn:aws:s3:::my-bucket/my-prefix/data" + + def test_handles_bucket_only_uri(self): + """Test S3 URI with bucket only.""" + uri = "s3://my-bucket" + result = FeatureGroup._s3_uri_to_arn(uri) + assert result == "arn:aws:s3:::my-bucket" + + def test_returns_arn_unchanged(self): + """Test ARN input is returned unchanged (idempotent).""" + arn = "arn:aws:s3:::my-bucket/path" + result = FeatureGroup._s3_uri_to_arn(arn) + assert result == arn + + def test_uses_region_for_partition(self): + """Test that region is used to determine partition.""" + uri = "s3://my-bucket/path" + result = FeatureGroup._s3_uri_to_arn(uri, region="cn-north-1") + assert result.startswith("arn:aws-cn:s3:::") + + + +class TestGetLakeFormationClient: + """Tests for _get_lake_formation_client method.""" + + @patch("sagemaker.mlops.feature_store.feature_group.Session") + def test_creates_client_with_default_session(self, mock_session_class): + """Test client creation with default session.""" + mock_session = MagicMock() + mock_client = MagicMock() + mock_session.client.return_value = mock_client + mock_session_class.return_value = mock_session + + fg = MagicMock(spec=FeatureGroup) + fg._get_lake_formation_client = FeatureGroup._get_lake_formation_client.__get__(fg) + + client = fg._get_lake_formation_client(region="us-west-2") + + mock_session.client.assert_called_with("lakeformation", region_name="us-west-2") + assert client == mock_client + + def test_creates_client_with_provided_session(self): + """Test client creation with provided session.""" + mock_session = MagicMock() + mock_client = MagicMock() + mock_session.client.return_value = mock_client + + fg = MagicMock(spec=FeatureGroup) + fg._get_lake_formation_client = FeatureGroup._get_lake_formation_client.__get__(fg) + + client = fg._get_lake_formation_client(session=mock_session, region="us-west-2") + + mock_session.client.assert_called_with("lakeformation", region_name="us-west-2") + assert client == mock_client + + +class TestRegisterS3WithLakeFormation: + """Tests for _register_s3_with_lake_formation method.""" + + def setup_method(self): + """Set up test fixtures.""" + self.fg = MagicMock(spec=FeatureGroup) + self.fg._s3_uri_to_arn = FeatureGroup._s3_uri_to_arn + self.fg._register_s3_with_lake_formation = ( + FeatureGroup._register_s3_with_lake_formation.__get__(self.fg) + ) + self.mock_client = MagicMock() + self.fg._get_lake_formation_client = MagicMock(return_value=self.mock_client) + + def test_successful_registration_returns_true(self): + """Test successful S3 registration returns True.""" + self.mock_client.register_resource.return_value = {} + + result = self.fg._register_s3_with_lake_formation("s3://test-bucket/prefix") + + assert result is True + self.mock_client.register_resource.assert_called_with( + ResourceArn="arn:aws:s3:::test-bucket/prefix", + UseServiceLinkedRole=True, + ) + + def test_already_exists_exception_returns_true(self): + """Test AlreadyExistsException is handled gracefully.""" + self.mock_client.register_resource.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "AlreadyExistsException", "Message": "Already exists"}}, + 
"RegisterResource", + ) + + result = self.fg._register_s3_with_lake_formation("s3://test-bucket/prefix") + + assert result is True + + def test_other_exceptions_are_propagated(self): + """Test non-AlreadyExistsException errors are propagated.""" + self.mock_client.register_resource.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "AccessDeniedException", "Message": "Access denied"}}, + "RegisterResource", + ) + + with pytest.raises(botocore.exceptions.ClientError) as exc_info: + self.fg._register_s3_with_lake_formation("s3://test-bucket/prefix") + + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + + def test_uses_service_linked_role(self): + """Test UseServiceLinkedRole is set to True.""" + self.mock_client.register_resource.return_value = {} + + self.fg._register_s3_with_lake_formation("s3://bucket/path") + + call_args = self.mock_client.register_resource.call_args + assert call_args[1]["UseServiceLinkedRole"] is True + + def test_uses_custom_role_arn_when_service_linked_role_disabled(self): + """Test custom role ARN is used when use_service_linked_role is False.""" + self.mock_client.register_resource.return_value = {} + custom_role = "arn:aws:iam::123456789012:role/CustomLakeFormationRole" + + self.fg._register_s3_with_lake_formation( + "s3://bucket/path", + use_service_linked_role=False, + role_arn=custom_role, + ) + + call_args = self.mock_client.register_resource.call_args + assert call_args[1]["RoleArn"] == custom_role + assert "UseServiceLinkedRole" not in call_args[1] + + def test_raises_error_when_role_arn_missing_and_service_linked_role_disabled(self): + """Test ValueError when use_service_linked_role is False but role_arn not provided.""" + with pytest.raises(ValueError) as exc_info: + self.fg._register_s3_with_lake_formation( + "s3://bucket/path", use_service_linked_role=False + ) + + assert "role_arn must be provided when use_service_linked_role is False" in str( + exc_info.value + ) + + + +class TestRevokeIamAllowedPrincipal: + """Tests for _revoke_iam_allowed_principal method.""" + + def setup_method(self): + """Set up test fixtures.""" + self.fg = MagicMock(spec=FeatureGroup) + self.fg._revoke_iam_allowed_principal = FeatureGroup._revoke_iam_allowed_principal.__get__( + self.fg + ) + self.mock_client = MagicMock() + self.fg._get_lake_formation_client = MagicMock(return_value=self.mock_client) + + def test_successful_revocation_returns_true(self): + """Test successful revocation returns True.""" + self.mock_client.revoke_permissions.return_value = {} + + result = self.fg._revoke_iam_allowed_principal("test_database", "test_table") + + assert result is True + self.mock_client.revoke_permissions.assert_called_once() + + def test_revoke_permissions_call_structure(self): + """Test that revoke_permissions is called with correct parameters.""" + self.mock_client.revoke_permissions.return_value = {} + database_name = "my_database" + table_name = "my_table" + + self.fg._revoke_iam_allowed_principal(database_name, table_name) + + call_args = self.mock_client.revoke_permissions.call_args + assert call_args[1]["Principal"] == { + "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS" + } + assert call_args[1]["Permissions"] == ["ALL"] + assert call_args[1]["Resource"] == { + "Table": { + "DatabaseName": database_name, + "Name": table_name, + } + } + + def test_invalid_input_exception_returns_true(self): + """Test InvalidInputException is handled gracefully (permissions may not exist).""" + self.mock_client.revoke_permissions.side_effect 
= botocore.exceptions.ClientError( + {"Error": {"Code": "InvalidInputException", "Message": "Permissions not found"}}, + "RevokePermissions", + ) + + result = self.fg._revoke_iam_allowed_principal("test_database", "test_table") + + assert result is True + + def test_other_exceptions_are_propagated(self): + """Test non-InvalidInputException errors are propagated.""" + self.mock_client.revoke_permissions.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "AccessDeniedException", "Message": "Access denied"}}, + "RevokePermissions", + ) + + with pytest.raises(botocore.exceptions.ClientError) as exc_info: + self.fg._revoke_iam_allowed_principal("test_database", "test_table") + + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + + def test_passes_session_and_region_to_client(self): + """Test session and region are passed to get_lake_formation_client.""" + self.mock_client.revoke_permissions.return_value = {} + mock_session = MagicMock() + + self.fg._revoke_iam_allowed_principal( + "test_database", "test_table", session=mock_session, region="us-west-2" + ) + + self.fg._get_lake_formation_client.assert_called_with(mock_session, "us-west-2") + + + +class TestGrantLakeFormationPermissions: + """Tests for _grant_lake_formation_permissions method.""" + + def setup_method(self): + """Set up test fixtures.""" + self.fg = MagicMock(spec=FeatureGroup) + self.fg._grant_lake_formation_permissions = ( + FeatureGroup._grant_lake_formation_permissions.__get__(self.fg) + ) + self.mock_client = MagicMock() + self.fg._get_lake_formation_client = MagicMock(return_value=self.mock_client) + + def test_successful_grant_returns_true(self): + """Test successful permission grant returns True.""" + self.mock_client.grant_permissions.return_value = {} + + result = self.fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" + ) + + assert result is True + self.mock_client.grant_permissions.assert_called_once() + + def test_grant_permissions_call_structure(self): + """Test that grant_permissions is called with correct parameters.""" + self.mock_client.grant_permissions.return_value = {} + role_arn = "arn:aws:iam::123456789012:role/MyExecutionRole" + + self.fg._grant_lake_formation_permissions(role_arn, "my_database", "my_table") + + call_args = self.mock_client.grant_permissions.call_args + assert call_args[1]["Principal"] == {"DataLakePrincipalIdentifier": role_arn} + assert call_args[1]["Resource"] == { + "Table": { + "DatabaseName": "my_database", + "Name": "my_table", + } + } + assert call_args[1]["Permissions"] == ["SELECT", "INSERT", "DELETE", "DESCRIBE", "ALTER"] + assert call_args[1]["PermissionsWithGrantOption"] == [] + + def test_invalid_input_exception_returns_true(self): + """Test InvalidInputException is handled gracefully (permissions may exist).""" + self.mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "InvalidInputException", "Message": "Permissions already exist"}}, + "GrantPermissions", + ) + + result = self.fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" + ) + + assert result is True + + def test_other_exceptions_are_propagated(self): + """Test non-InvalidInputException errors are propagated.""" + self.mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "AccessDeniedException", "Message": "Access denied"}}, + "GrantPermissions", + ) + + with 
pytest.raises(botocore.exceptions.ClientError) as exc_info: + self.fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" + ) + + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + + def test_passes_session_and_region_to_client(self): + """Test session and region are passed to get_lake_formation_client.""" + self.mock_client.grant_permissions.return_value = {} + mock_session = MagicMock() + + self.fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", + "test_database", + "test_table", + session=mock_session, + region="us-west-2", + ) + + self.fg._get_lake_formation_client.assert_called_with(mock_session, "us-west-2") + + + +class TestEnableLakeFormationValidation: + """Tests for enable_lake_formation validation logic.""" + + @patch.object(FeatureGroup, "refresh") + def test_raises_error_when_no_offline_store(self, mock_refresh): + """Test that enable_lake_formation raises ValueError when no offline store is configured.""" + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = None + fg.feature_group_status = "Created" + + with pytest.raises(ValueError, match="does not have an offline store configured"): + fg.enable_lake_formation() + + # Verify refresh was called + mock_refresh.assert_called_once() + + @patch.object(FeatureGroup, "refresh") + def test_raises_error_when_no_role_arn(self, mock_refresh): + """Test that enable_lake_formation raises ValueError when no role_arn is configured.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = None + fg.feature_group_status = "Created" + + with pytest.raises(ValueError, match="does not have a role_arn configured"): + fg.enable_lake_formation() + + # Verify refresh was called + mock_refresh.assert_called_once() + + @patch.object(FeatureGroup, "refresh") + def test_raises_error_when_invalid_status(self, mock_refresh): + """Test enable_lake_formation raises ValueError when Feature Group not in Created status.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/TestRole" + fg.feature_group_status = "Creating" + + with pytest.raises(ValueError, match="must be in 'Created' status"): + fg.enable_lake_formation() + + # Verify refresh was called + mock_refresh.assert_called_once() + + @patch.object(FeatureGroup, "wait_for_status") + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + def test_wait_for_active_calls_wait_for_status( + self, mock_revoke, mock_grant, mock_register, mock_refresh, 
mock_wait + ): + """Test that wait_for_active=True calls wait_for_status with 'Created' target.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/TestRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + + # Call with wait_for_active=True + fg.enable_lake_formation(wait_for_active=True) + + # Verify wait_for_status was called with "Created" + mock_wait.assert_called_once_with(target_status="Created") + # Verify refresh was called after wait + mock_refresh.assert_called_once() + + @patch.object(FeatureGroup, "wait_for_status") + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + def test_wait_for_active_false_does_not_call_wait( + self, mock_revoke, mock_grant, mock_register, mock_refresh, mock_wait + ): + """Test that wait_for_active=False does not call wait_for_status.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/TestRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + + # Call with wait_for_active=False (default) + fg.enable_lake_formation(wait_for_active=False) + + # Verify wait_for_status was NOT called + mock_wait.assert_not_called() + # Verify refresh was still called + mock_refresh.assert_called_once() + + + @pytest.mark.parametrize( + "feature_group_name,role_arn,s3_uri,database_name,table_name", + [ + ("test-fg", "TestRole", "path1", "db1", "table1"), + ("my_feature_group", "ExecutionRole", "data/features", "feature_db", "feature_table"), + ("fg123", "MyRole123", "ml/features/v1", "analytics", "features_v1"), + ("simple", "SimpleRole", "simple-path", "simple_db", "simple_table"), + ( + "complex-name", + "ComplexExecutionRole", + "complex/path/structure", + "complex_database", + "complex_table_name", + ), + ( + "underscore_name", + "Underscore_Role", + "underscore_path", + "underscore_db", + "underscore_table", + ), + ("mixed-123", "Mixed123Role", "mixed/path/123", "mixed_db_123", "mixed_table_123"), + ("x", "XRole", "x", "x", "x"), + ( + "very-long-name", + "VeryLongRoleName", + "very/long/path/structure", + "very_long_database_name", + 
"very_long_table_name", + ), + ], + ) + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + def test_fail_fast_phase_execution( + self, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + feature_group_name, + role_arn, + s3_uri, + database_name, + table_name, + ): + """ + Test fail-fast behavior for Lake Formation phases. + + If Phase 1 (S3 registration) fails, Phase 2 and 3 should not execute. + If Phase 2 fails, Phase 3 should not execute. + RuntimeError should indicate which phase failed. + """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name=feature_group_name) + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri=f"s3://test-bucket/{s3_uri}", + resolved_output_s3_uri=f"s3://test-bucket/resolved-{s3_uri}", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database=database_name, table_name=table_name + ), + ) + fg.role_arn = f"arn:aws:iam::123456789012:role/{role_arn}" + fg.feature_group_status = "Created" + + # Test Phase 1 failure - subsequent phases should not be called + mock_register.side_effect = Exception("Phase 1 failed") + mock_grant.return_value = True + mock_revoke.return_value = True + + with pytest.raises( + RuntimeError, match="Failed to register S3 location with Lake Formation" + ): + fg.enable_lake_formation() + + # Verify Phase 1 was called but Phase 2 and 3 were not + mock_register.assert_called_once() + mock_grant.assert_not_called() + mock_revoke.assert_not_called() + + # Reset mocks for Phase 2 failure test + mock_register.reset_mock() + mock_grant.reset_mock() + mock_revoke.reset_mock() + + # Test Phase 2 failure - Phase 3 should not be called + mock_register.side_effect = None + mock_register.return_value = True + mock_grant.side_effect = Exception("Phase 2 failed") + mock_revoke.return_value = True + + with pytest.raises(RuntimeError, match="Failed to grant Lake Formation permissions"): + fg.enable_lake_formation() + + # Verify Phase 1 and 2 were called but Phase 3 was not + mock_register.assert_called_once() + mock_grant.assert_called_once() + mock_revoke.assert_not_called() + + # Reset mocks for Phase 3 failure test + mock_register.reset_mock() + mock_grant.reset_mock() + mock_revoke.reset_mock() + + # Test Phase 3 failure - all phases should be called + mock_register.side_effect = None + mock_register.return_value = True + mock_grant.side_effect = None + mock_grant.return_value = True + mock_revoke.side_effect = Exception("Phase 3 failed") + + with pytest.raises(RuntimeError, match="Failed to revoke IAMAllowedPrincipal permissions"): + fg.enable_lake_formation() + + # Verify all phases were called + mock_register.assert_called_once() + mock_grant.assert_called_once() + mock_revoke.assert_called_once() + + + +class TestUnhandledExceptionPropagation: + """Tests for proper propagation of unhandled boto3 exceptions.""" + + def test_register_s3_propagates_unhandled_exceptions(self): + """ + Non-AlreadyExists Errors Propagate from S3 Registration + + For any error from Lake Formation's register_resource API that is not + AlreadyExistsException, the error should be propagated to the caller unchanged. 
+ + """ + fg = MagicMock(spec=FeatureGroup) + fg._s3_uri_to_arn = FeatureGroup._s3_uri_to_arn + fg._register_s3_with_lake_formation = FeatureGroup._register_s3_with_lake_formation.__get__( + fg + ) + mock_client = MagicMock() + fg._get_lake_formation_client = MagicMock(return_value=mock_client) + + # Configure mock to raise an unhandled error + mock_client.register_resource.side_effect = botocore.exceptions.ClientError( + { + "Error": { + "Code": "AccessDeniedException", + "Message": "User does not have permission", + } + }, + "RegisterResource", + ) + + # Verify the exception is propagated unchanged + with pytest.raises(botocore.exceptions.ClientError) as exc_info: + fg._register_s3_with_lake_formation("s3://test-bucket/path") + + # Verify error details are preserved + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + assert exc_info.value.response["Error"]["Message"] == "User does not have permission" + assert exc_info.value.operation_name == "RegisterResource" + + def test_revoke_iam_principal_propagates_unhandled_exceptions(self): + """ + Non-InvalidInput Errors Propagate from IAM Principal Revocation + + For any error from Lake Formation's revoke_permissions API that is not + InvalidInputException, the error should be propagated to the caller unchanged. + + """ + fg = MagicMock(spec=FeatureGroup) + fg._revoke_iam_allowed_principal = FeatureGroup._revoke_iam_allowed_principal.__get__(fg) + mock_client = MagicMock() + fg._get_lake_formation_client = MagicMock(return_value=mock_client) + + # Configure mock to raise an unhandled error + mock_client.revoke_permissions.side_effect = botocore.exceptions.ClientError( + { + "Error": { + "Code": "AccessDeniedException", + "Message": "User does not have permission", + } + }, + "RevokePermissions", + ) + + # Verify the exception is propagated unchanged + with pytest.raises(botocore.exceptions.ClientError) as exc_info: + fg._revoke_iam_allowed_principal("test_database", "test_table") + + # Verify error details are preserved + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + assert exc_info.value.response["Error"]["Message"] == "User does not have permission" + assert exc_info.value.operation_name == "RevokePermissions" + + def test_grant_permissions_propagates_unhandled_exceptions(self): + """ + Non-InvalidInput Errors Propagate from Permission Grant + + For any error from Lake Formation's grant_permissions API that is not + InvalidInputException, the error should be propagated to the caller unchanged. 
+ + """ + fg = MagicMock(spec=FeatureGroup) + fg._grant_lake_formation_permissions = ( + FeatureGroup._grant_lake_formation_permissions.__get__(fg) + ) + mock_client = MagicMock() + fg._get_lake_formation_client = MagicMock(return_value=mock_client) + + # Configure mock to raise an unhandled error + mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( + { + "Error": { + "Code": "AccessDeniedException", + "Message": "User does not have permission", + } + }, + "GrantPermissions", + ) + + # Verify the exception is propagated unchanged + with pytest.raises(botocore.exceptions.ClientError) as exc_info: + fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" + ) + + # Verify error details are preserved + assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" + assert exc_info.value.response["Error"]["Message"] == "User does not have permission" + assert exc_info.value.operation_name == "GrantPermissions" + + def test_handled_exceptions_do_not_propagate(self): + """ + Verify that specifically handled exceptions (AlreadyExistsException, InvalidInputException) + do NOT propagate but return True instead, while all other exceptions are propagated. + """ + fg = MagicMock(spec=FeatureGroup) + fg._s3_uri_to_arn = FeatureGroup._s3_uri_to_arn + fg._register_s3_with_lake_formation = FeatureGroup._register_s3_with_lake_formation.__get__( + fg + ) + fg._revoke_iam_allowed_principal = FeatureGroup._revoke_iam_allowed_principal.__get__(fg) + fg._grant_lake_formation_permissions = ( + FeatureGroup._grant_lake_formation_permissions.__get__(fg) + ) + mock_client = MagicMock() + fg._get_lake_formation_client = MagicMock(return_value=mock_client) + + # Test AlreadyExistsException is handled (not propagated) + mock_client.register_resource.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "AlreadyExistsException", "Message": "Already exists"}}, + "RegisterResource", + ) + result = fg._register_s3_with_lake_formation("s3://test-bucket/path") + assert result is True # Should return True, not raise + + # Test InvalidInputException is handled for revoke (not propagated) + mock_client.revoke_permissions.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "InvalidInputException", "Message": "Invalid input"}}, + "RevokePermissions", + ) + result = fg._revoke_iam_allowed_principal("db", "table") + assert result is True # Should return True, not raise + + # Test InvalidInputException is handled for grant (not propagated) + mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( + {"Error": {"Code": "InvalidInputException", "Message": "Invalid input"}}, + "GrantPermissions", + ) + result = fg._grant_lake_formation_permissions( + "arn:aws:iam::123456789012:role/TestRole", "db", "table" + ) + assert result is True # Should return True, not raise + + + +class TestCreateWithLakeFormation: + """Tests for create() method with Lake Formation integration.""" + + @pytest.mark.parametrize( + "feature_group_name,record_id_feature,event_time_feature", + [ + ("test-fg", "record_id", "event_time"), + ("my_feature_group", "id", "timestamp"), + ("fg123", "identifier", "time"), + ("simple", "rec_id", "evt_time"), + ("complex-name", "record_identifier", "event_timestamp"), + ("underscore_name", "record_id_field", "event_time_field"), + ("mixed-123", "id_123", "time_123"), + ("x", "x_id", "x_time"), + ("very-long-name", "very_long_record_id", "very_long_event_time"), + ], + ) + 
@patch("sagemaker.core.resources.Base.get_sagemaker_client") + @patch.object(FeatureGroup, "get") + @patch.object(FeatureGroup, "wait_for_status") + @patch.object(FeatureGroup, "enable_lake_formation") + def test_no_lake_formation_operations_when_disabled( + self, + mock_enable_lf, + mock_wait, + mock_get, + mock_get_client, + feature_group_name, + record_id_feature, + event_time_feature, + ): + """ + No Lake Formation Operations When Disabled + + For any call to FeatureGroup.create() where lake_formation_config is None or has enabled=False, + no Lake Formation client methods should be invoked. + + """ + from sagemaker.core.shapes import FeatureDefinition + + # Mock the SageMaker client + mock_client = MagicMock() + mock_client.create_feature_group.return_value = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" + } + mock_get_client.return_value = mock_client + + # Mock the get method to return a feature group + mock_fg = MagicMock(spec=FeatureGroup) + mock_fg.feature_group_name = feature_group_name + mock_get.return_value = mock_fg + + # Create feature definitions + feature_definitions = [ + FeatureDefinition(feature_name=record_id_feature, feature_type="String"), + FeatureDefinition(feature_name=event_time_feature, feature_type="String"), + ] + + # Test 1: lake_formation_config with enabled=False (explicit) + lf_config = LakeFormationConfig() + lf_config.enabled = False + result = FeatureGroup.create( + feature_group_name=feature_group_name, + record_identifier_feature_name=record_id_feature, + event_time_feature_name=event_time_feature, + feature_definitions=feature_definitions, + lake_formation_config=lf_config, + ) + + # Verify enable_lake_formation was NOT called + mock_enable_lf.assert_not_called() + # Verify wait_for_status was NOT called + mock_wait.assert_not_called() + # Verify the feature group was returned + assert result == mock_fg + + # Reset mocks for next test + mock_enable_lf.reset_mock() + mock_wait.reset_mock() + mock_get.reset_mock() + mock_get.return_value = mock_fg + + # Test 2: lake_formation_config not specified (defaults to None) + result = FeatureGroup.create( + feature_group_name=feature_group_name, + record_identifier_feature_name=record_id_feature, + event_time_feature_name=event_time_feature, + feature_definitions=feature_definitions, + # lake_formation_config not specified, should default to None + ) + + # Verify enable_lake_formation was NOT called + mock_enable_lf.assert_not_called() + # Verify wait_for_status was NOT called + mock_wait.assert_not_called() + # Verify the feature group was returned + assert result == mock_fg + + @pytest.mark.parametrize( + "feature_group_name,record_id_feature,event_time_feature,role_arn,s3_uri,database,table", + [ + ("test-fg", "record_id", "event_time", "TestRole", "path1", "db1", "table1"), + ( + "my_feature_group", + "id", + "timestamp", + "ExecutionRole", + "data/features", + "feature_db", + "feature_table", + ), + ( + "fg123", + "identifier", + "time", + "MyRole123", + "ml/features/v1", + "analytics", + "features_v1", + ), + ], + ) + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + @patch.object(FeatureGroup, "get") + @patch.object(FeatureGroup, "wait_for_status") + @patch.object(FeatureGroup, "enable_lake_formation") + def test_enable_lake_formation_called_when_enabled( + self, + mock_enable_lf, + mock_wait, + mock_get, + mock_get_client, + feature_group_name, + record_id_feature, + event_time_feature, + role_arn, + s3_uri, + database, + table, + ): + """ + Test that 
enable_lake_formation is called when lake_formation_config has enabled=True. + + This verifies the integration between create() and enable_lake_formation(). + """ + from sagemaker.core.shapes import ( + FeatureDefinition, + OfflineStoreConfig, + S3StorageConfig, + DataCatalogConfig, + ) + + # Mock the SageMaker client + mock_client = MagicMock() + mock_client.create_feature_group.return_value = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" + } + mock_get_client.return_value = mock_client + + # Mock the get method to return a feature group + mock_fg = MagicMock(spec=FeatureGroup) + mock_fg.feature_group_name = feature_group_name + mock_fg.wait_for_status = mock_wait + mock_fg.enable_lake_formation = mock_enable_lf + mock_get.return_value = mock_fg + + # Create feature definitions + feature_definitions = [ + FeatureDefinition(feature_name=record_id_feature, feature_type="String"), + FeatureDefinition(feature_name=event_time_feature, feature_type="String"), + ] + + # Create offline store config + offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://test-bucket/{s3_uri}"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database=database, table_name=table + ), + ) + + # Create LakeFormationConfig with enabled=True + lf_config = LakeFormationConfig() + lf_config.enabled = True + + # Create with lake_formation_config enabled=True + result = FeatureGroup.create( + feature_group_name=feature_group_name, + record_identifier_feature_name=record_id_feature, + event_time_feature_name=event_time_feature, + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + role_arn=f"arn:aws:iam::123456789012:role/{role_arn}", + lake_formation_config=lf_config, + ) + + # Verify wait_for_status was called with "Created" + mock_wait.assert_called_once_with(target_status="Created") + # Verify enable_lake_formation was called with default use_service_linked_role=True + mock_enable_lf.assert_called_once_with( + session=None, + region=None, + use_service_linked_role=True, + registration_role_arn=None, + show_s3_policy=False, + ) + # Verify the feature group was returned + assert result == mock_fg + + @pytest.mark.parametrize( + "feature_group_name,record_id_feature,event_time_feature", + [ + ("test-fg", "record_id", "event_time"), + ("my_feature_group", "id", "timestamp"), + ("fg123", "identifier", "time"), + ], + ) + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_validation_error_when_lake_formation_enabled_without_offline_store( + self, mock_get_client, feature_group_name, record_id_feature, event_time_feature + ): + """Test create() raises ValueError when lake_formation_config enabled=True without offline_store.""" + from sagemaker.core.shapes import FeatureDefinition + + # Mock the SageMaker client + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + # Create feature definitions + feature_definitions = [ + FeatureDefinition(feature_name=record_id_feature, feature_type="String"), + FeatureDefinition(feature_name=event_time_feature, feature_type="String"), + ] + + # Create LakeFormationConfig with enabled=True + lf_config = LakeFormationConfig() + lf_config.enabled = True + + # Test with lake_formation_config enabled=True but no offline_store_config + with pytest.raises( + ValueError, + match="lake_formation_config with enabled=True requires offline_store_config to be configured", + ): + FeatureGroup.create( + 
feature_group_name=feature_group_name, + record_identifier_feature_name=record_id_feature, + event_time_feature_name=event_time_feature, + feature_definitions=feature_definitions, + lake_formation_config=lf_config, + # offline_store_config not provided + ) + + @pytest.mark.parametrize( + "feature_group_name,record_id_feature,event_time_feature,s3_uri,database,table", + [ + ("test-fg", "record_id", "event_time", "path1", "db1", "table1"), + ("my_feature_group", "id", "timestamp", "data/features", "feature_db", "feature_table"), + ("fg123", "identifier", "time", "ml/features/v1", "analytics", "features_v1"), + ], + ) + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_validation_error_when_lake_formation_enabled_without_role_arn( + self, + mock_get_client, + feature_group_name, + record_id_feature, + event_time_feature, + s3_uri, + database, + table, + ): + """Test create() raises ValueError when lake_formation_config enabled=True without role_arn.""" + from sagemaker.core.shapes import ( + FeatureDefinition, + OfflineStoreConfig, + S3StorageConfig, + DataCatalogConfig, + ) + + # Mock the SageMaker client + mock_client = MagicMock() + mock_get_client.return_value = mock_client + + # Create feature definitions + feature_definitions = [ + FeatureDefinition(feature_name=record_id_feature, feature_type="String"), + FeatureDefinition(feature_name=event_time_feature, feature_type="String"), + ] + + # Create offline store config + offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://test-bucket/{s3_uri}"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database=database, table_name=table + ), + ) + + # Create LakeFormationConfig with enabled=True + lf_config = LakeFormationConfig() + lf_config.enabled = True + + # Test with lake_formation_config enabled=True but no role_arn + with pytest.raises( + ValueError, match="lake_formation_config with enabled=True requires role_arn to be specified" + ): + FeatureGroup.create( + feature_group_name=feature_group_name, + record_identifier_feature_name=record_id_feature, + event_time_feature_name=event_time_feature, + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + lake_formation_config=lf_config, + # role_arn not provided + ) + + + @pytest.mark.parametrize( + "feature_group_name,record_id_feature,event_time_feature,role_arn,s3_uri,database,table,use_slr", + [ + ("test-fg", "record_id", "event_time", "TestRole", "path1", "db1", "table1", True), + ("my_feature_group", "id", "timestamp", "ExecutionRole", "data/features", "feature_db", "feature_table", False), + ("fg123", "identifier", "time", "MyRole123", "ml/features/v1", "analytics", "features_v1", True), + ], + ) + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + @patch.object(FeatureGroup, "get") + @patch.object(FeatureGroup, "wait_for_status") + @patch.object(FeatureGroup, "enable_lake_formation") + def test_use_service_linked_role_extraction_from_config( + self, + mock_enable_lf, + mock_wait, + mock_get, + mock_get_client, + feature_group_name, + record_id_feature, + event_time_feature, + role_arn, + s3_uri, + database, + table, + use_slr, + ): + """ + Test that use_service_linked_role is correctly extracted from lake_formation_config. 
+ + Verifies: + - use_service_linked_role defaults to True when not specified + - use_service_linked_role is passed correctly to enable_lake_formation() + """ + from sagemaker.core.shapes import ( + FeatureDefinition, + OfflineStoreConfig, + S3StorageConfig, + DataCatalogConfig, + ) + + # Mock the SageMaker client + mock_client = MagicMock() + mock_client.create_feature_group.return_value = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" + } + mock_get_client.return_value = mock_client + + # Mock the get method to return a feature group + mock_fg = MagicMock(spec=FeatureGroup) + mock_fg.feature_group_name = feature_group_name + mock_fg.wait_for_status = mock_wait + mock_fg.enable_lake_formation = mock_enable_lf + mock_get.return_value = mock_fg + + # Create feature definitions + feature_definitions = [ + FeatureDefinition(feature_name=record_id_feature, feature_type="String"), + FeatureDefinition(feature_name=event_time_feature, feature_type="String"), + ] + + # Create offline store config + offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://test-bucket/{s3_uri}"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database=database, table_name=table + ), + ) + + # Build LakeFormationConfig with use_service_linked_role + lf_config = LakeFormationConfig() + lf_config.enabled = True + lf_config.use_service_linked_role = use_slr + # When use_service_linked_role is False, registration_role_arn is required + expected_registration_role = None + if not use_slr: + lf_config.registration_role_arn = "arn:aws:iam::123456789012:role/LFRegistrationRole" + expected_registration_role = "arn:aws:iam::123456789012:role/LFRegistrationRole" + + # Create with lake_formation_config + result = FeatureGroup.create( + feature_group_name=feature_group_name, + record_identifier_feature_name=record_id_feature, + event_time_feature_name=event_time_feature, + feature_definitions=feature_definitions, + offline_store_config=offline_store_config, + role_arn=f"arn:aws:iam::123456789012:role/{role_arn}", + lake_formation_config=lf_config, + ) + + # Verify enable_lake_formation was called with correct use_service_linked_role value + mock_enable_lf.assert_called_once_with( + session=None, + region=None, + use_service_linked_role=use_slr, + registration_role_arn=expected_registration_role, + show_s3_policy=False, + ) + # Verify the feature group was returned + assert result == mock_fg + + +class TestExtractAccountIdFromArn: + """Tests for _extract_account_id_from_arn static method.""" + + def test_extracts_account_id_from_sagemaker_arn(self): + """Test extracting account ID from a SageMaker Feature Group ARN.""" + arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/my-feature-group" + result = FeatureGroup._extract_account_id_from_arn(arn) + assert result == "123456789012" + + def test_raises_value_error_for_invalid_arn_too_few_parts(self): + """Test that ValueError is raised for ARN with fewer than 5 colon-separated parts.""" + invalid_arn = "arn:aws:sagemaker:us-west-2" # Only 4 parts + with pytest.raises(ValueError, match="Invalid ARN format"): + FeatureGroup._extract_account_id_from_arn(invalid_arn) + + def test_raises_value_error_for_empty_string(self): + """Test that ValueError is raised for empty string.""" + with pytest.raises(ValueError, match="Invalid ARN format"): + FeatureGroup._extract_account_id_from_arn("") + + def test_raises_value_error_for_non_arn_string(self): + """Test that ValueError is raised for 
non-ARN string.""" + with pytest.raises(ValueError, match="Invalid ARN format"): + FeatureGroup._extract_account_id_from_arn("not-an-arn") + + def test_raises_value_error_for_s3_uri(self): + """Test that ValueError is raised for S3 URI (not ARN).""" + with pytest.raises(ValueError, match="Invalid ARN format"): + FeatureGroup._extract_account_id_from_arn("s3://my-bucket/my-prefix") + + def test_handles_arn_with_resource_path(self): + """Test extracting account ID from ARN with complex resource path.""" + arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/my-fg/version/1" + result = FeatureGroup._extract_account_id_from_arn(arn) + assert result == "123456789012" + + +class TestGetLakeFormationServiceLinkedRoleArn: + """Tests for _get_lake_formation_service_linked_role_arn static method.""" + + def test_generates_correct_service_linked_role_arn(self): + """Test that the method generates the correct service-linked role ARN format.""" + account_id = "123456789012" + result = FeatureGroup._get_lake_formation_service_linked_role_arn(account_id) + expected = "arn:aws:iam::123456789012:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" + assert result == expected + + def test_uses_region_for_partition(self): + """Test that region is used to determine partition.""" + account_id = "123456789012" + result = FeatureGroup._get_lake_formation_service_linked_role_arn(account_id, region="cn-north-1") + assert result.startswith("arn:aws-cn:iam::") + + + +class TestGenerateS3DenyPolicy: + """Tests for _generate_s3_deny_policy method.""" + + def setup_method(self): + """Set up test fixtures.""" + self.fg = MagicMock(spec=FeatureGroup) + self.fg._generate_s3_deny_policy = FeatureGroup._generate_s3_deny_policy.__get__(self.fg) + + def test_policy_includes_correct_bucket_arn_in_object_statement(self): + """Test that the policy includes correct bucket ARN and prefix in object actions statement.""" + bucket_name = "my-feature-store-bucket" + s3_prefix = "feature-store/data/my-feature-group" + lf_role_arn = "arn:aws:iam::123456789012:role/LakeFormationRole" + fs_role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + + policy = self.fg._generate_s3_deny_policy( + bucket_name=bucket_name, + s3_prefix=s3_prefix, + lake_formation_role_arn=lf_role_arn, + feature_store_role_arn=fs_role_arn, + ) + + # Verify the object actions statement has correct Resource ARN + object_statement = policy["Statement"][0] + expected_resource = f"arn:aws:s3:::{bucket_name}/{s3_prefix}/*" + assert object_statement["Resource"] == expected_resource + + def test_policy_includes_correct_bucket_arn_in_list_statement(self): + """Test that the policy includes correct bucket ARN in ListBucket statement.""" + bucket_name = "my-feature-store-bucket" + s3_prefix = "feature-store/data/my-feature-group" + lf_role_arn = "arn:aws:iam::123456789012:role/LakeFormationRole" + fs_role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + + policy = self.fg._generate_s3_deny_policy( + bucket_name=bucket_name, + s3_prefix=s3_prefix, + lake_formation_role_arn=lf_role_arn, + feature_store_role_arn=fs_role_arn, + ) + + # Verify the ListBucket statement has correct Resource ARN (bucket only) + list_statement = policy["Statement"][1] + expected_resource = f"arn:aws:s3:::{bucket_name}" + assert list_statement["Resource"] == expected_resource + + def test_policy_includes_correct_prefix_condition_in_list_statement(self): + """Test that the policy includes correct prefix condition in ListBucket statement.""" 
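+        # Sketch of the ListBucket deny statement these tests expect, pieced together
+        # from the assertions in this class (illustrative only, not the exact policy):
+        #   {"Sid": "DenyListOnPrefixExceptAllowedPrincipals",
+        #    "Effect": "Deny",
+        #    "Principal": "*",
+        #    "Action": "s3:ListBucket",
+        #    "Resource": "arn:aws:s3:::<bucket>",
+        #    "Condition": {"StringLike": {"s3:prefix": "<prefix>/*"},
+        #                  "StringNotEquals": {"aws:PrincipalArn": [<LF role>, <FS role>]}}}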
+ bucket_name = "my-feature-store-bucket" + s3_prefix = "feature-store/data/my-feature-group" + lf_role_arn = "arn:aws:iam::123456789012:role/LakeFormationRole" + fs_role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + + policy = self.fg._generate_s3_deny_policy( + bucket_name=bucket_name, + s3_prefix=s3_prefix, + lake_formation_role_arn=lf_role_arn, + feature_store_role_arn=fs_role_arn, + ) + + # Verify the ListBucket statement has correct prefix condition + list_statement = policy["Statement"][1] + expected_prefix = f"{s3_prefix}/*" + assert list_statement["Condition"]["StringLike"]["s3:prefix"] == expected_prefix + + def test_policy_preserves_bucket_name_exactly(self): + """Test that bucket name is preserved exactly without modification.""" + # Test with various bucket name formats + test_cases = [ + "simple-bucket", + "bucket.with.dots", + "bucket-with-dashes-123", + "mybucket", + "a" * 63, # Max bucket name length + ] + + for bucket_name in test_cases: + policy = self.fg._generate_s3_deny_policy( + bucket_name=bucket_name, + s3_prefix="prefix", + lake_formation_role_arn="arn:aws:iam::123456789012:role/LFRole", + feature_store_role_arn="arn:aws:iam::123456789012:role/FSRole", + ) + + # Verify bucket name is preserved in both statements + assert bucket_name in policy["Statement"][0]["Resource"] + assert bucket_name in policy["Statement"][1]["Resource"] + + def test_policy_preserves_prefix_exactly(self): + """Test that S3 prefix is preserved exactly without modification.""" + # Test with various prefix formats + test_cases = [ + "simple-prefix", + "path/to/data", + "feature-store/account-id/region/feature-group-name", + "deep/nested/path/structure/data", + "prefix_with_underscores", + "prefix-with-dashes", + ] + + for s3_prefix in test_cases: + policy = self.fg._generate_s3_deny_policy( + bucket_name="test-bucket", + s3_prefix=s3_prefix, + lake_formation_role_arn="arn:aws:iam::123456789012:role/LFRole", + feature_store_role_arn="arn:aws:iam::123456789012:role/FSRole", + ) + + # Verify prefix is preserved in object statement Resource + assert f"{s3_prefix}/*" in policy["Statement"][0]["Resource"] + # Verify prefix is preserved in list statement Condition + assert policy["Statement"][1]["Condition"]["StringLike"]["s3:prefix"] == f"{s3_prefix}/*" + + def test_policy_has_correct_s3_arn_format(self): + """Test that the policy uses correct S3 ARN format (arn:aws:s3:::bucket/path).""" + bucket_name = "test-bucket" + s3_prefix = "test/prefix" + + policy = self.fg._generate_s3_deny_policy( + bucket_name=bucket_name, + s3_prefix=s3_prefix, + lake_formation_role_arn="arn:aws:iam::123456789012:role/LFRole", + feature_store_role_arn="arn:aws:iam::123456789012:role/FSRole", + ) + + # Verify object statement Resource starts with correct ARN prefix + object_resource = policy["Statement"][0]["Resource"] + assert object_resource.startswith("arn:aws:s3:::") + assert object_resource == f"arn:aws:s3:::{bucket_name}/{s3_prefix}/*" + + # Verify list statement Resource is bucket-only ARN + list_resource = policy["Statement"][1]["Resource"] + assert list_resource.startswith("arn:aws:s3:::") + assert list_resource == f"arn:aws:s3:::{bucket_name}" + + def test_policy_structure_validation(self): + """Test that the policy has correct overall structure.""" + policy = self.fg._generate_s3_deny_policy( + bucket_name="test-bucket", + s3_prefix="test/prefix", + lake_formation_role_arn="arn:aws:iam::123456789012:role/LFRole", + feature_store_role_arn="arn:aws:iam::123456789012:role/FSRole", + ) + + # Verify 
policy version + assert policy["Version"] == "2012-10-17" + + # Verify exactly two statements + assert len(policy["Statement"]) == 2 + + # Verify first statement structure (object actions) + object_statement = policy["Statement"][0] + assert object_statement["Sid"] == "DenyAllAccessToFeatureStorePrefixExceptAllowedPrincipals" + assert object_statement["Effect"] == "Deny" + assert object_statement["Principal"] == "*" + assert "Condition" in object_statement + assert "StringNotEquals" in object_statement["Condition"] + + # Verify second statement structure (list bucket) + list_statement = policy["Statement"][1] + assert list_statement["Sid"] == "DenyListOnPrefixExceptAllowedPrincipals" + assert list_statement["Effect"] == "Deny" + assert list_statement["Principal"] == "*" + assert "Condition" in list_statement + assert "StringLike" in list_statement["Condition"] + assert "StringNotEquals" in list_statement["Condition"] + + def test_policy_includes_both_principals_in_allowed_list(self): + """Test that both Lake Formation role and Feature Store role are in allowed principals.""" + lf_role_arn = "arn:aws:iam::123456789012:role/LakeFormationRole" + fs_role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + + policy = self.fg._generate_s3_deny_policy( + bucket_name="test-bucket", + s3_prefix="test/prefix", + lake_formation_role_arn=lf_role_arn, + feature_store_role_arn=fs_role_arn, + ) + + # Verify both principals in object statement + object_principals = policy["Statement"][0]["Condition"]["StringNotEquals"]["aws:PrincipalArn"] + assert lf_role_arn in object_principals + assert fs_role_arn in object_principals + assert len(object_principals) == 2 + + # Verify both principals in list statement + list_principals = policy["Statement"][1]["Condition"]["StringNotEquals"]["aws:PrincipalArn"] + assert lf_role_arn in list_principals + assert fs_role_arn in list_principals + assert len(list_principals) == 2 + + def test_policy_has_correct_actions_in_each_statement(self): + """Test that each statement has the correct S3 actions.""" + policy = self.fg._generate_s3_deny_policy( + bucket_name="test-bucket", + s3_prefix="test/prefix", + lake_formation_role_arn="arn:aws:iam::123456789012:role/LFRole", + feature_store_role_arn="arn:aws:iam::123456789012:role/FSRole", + ) + + # Verify object statement has correct actions + object_actions = policy["Statement"][0]["Action"] + assert "s3:GetObject" in object_actions + assert "s3:PutObject" in object_actions + assert "s3:DeleteObject" in object_actions + assert len(object_actions) == 3 + + # Verify list statement has correct action + list_action = policy["Statement"][1]["Action"] + assert list_action == "s3:ListBucket" + + + +class TestEnableLakeFormationServiceLinkedRoleInPolicy: + """Tests for service-linked role ARN usage in S3 deny policy generation.""" + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch.object(FeatureGroup, "_generate_s3_deny_policy") + @patch("builtins.print") + def test_uses_service_linked_role_arn_when_use_service_linked_role_true( + self, + mock_print, + mock_generate_policy, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that enable_lake_formation uses the auto-generated service-linked role ARN + when use_service_linked_role=True. 
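+ The expected ARN has the form arn:aws:iam::<account-id>:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess.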
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} + + # Call with use_service_linked_role=True (default) + fg.enable_lake_formation(use_service_linked_role=True, show_s3_policy=True) + + # Verify _generate_s3_deny_policy was called with the service-linked role ARN + expected_slr_arn = "arn:aws:iam::123456789012:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" + mock_generate_policy.assert_called_once() + call_kwargs = mock_generate_policy.call_args[1] + assert call_kwargs["lake_formation_role_arn"] == expected_slr_arn + assert call_kwargs["feature_store_role_arn"] == fg.role_arn + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch.object(FeatureGroup, "_generate_s3_deny_policy") + @patch("builtins.print") + def test_uses_service_linked_role_arn_by_default( + self, + mock_print, + mock_generate_policy, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that enable_lake_formation uses the service-linked role ARN by default + (when use_service_linked_role is not explicitly specified). 
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::987654321098:role/MyFeatureStoreRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-east-1:987654321098:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} + + # Call without specifying use_service_linked_role (should default to True) + fg.enable_lake_formation(show_s3_policy=True) + + # Verify _generate_s3_deny_policy was called with the service-linked role ARN + expected_slr_arn = "arn:aws:iam::987654321098:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" + mock_generate_policy.assert_called_once() + call_kwargs = mock_generate_policy.call_args[1] + assert call_kwargs["lake_formation_role_arn"] == expected_slr_arn + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch.object(FeatureGroup, "_generate_s3_deny_policy") + @patch("builtins.print") + def test_service_linked_role_arn_uses_correct_account_id( + self, + mock_print, + mock_generate_policy, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that the service-linked role ARN is generated with the correct account ID + extracted from the Feature Group ARN. 
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + # Use a specific account ID to verify it's extracted correctly + account_id = "111222333444" + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = f"arn:aws:iam::{account_id}:role/FeatureStoreRole" + fg.feature_group_arn = f"arn:aws:sagemaker:us-west-2:{account_id}:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} + + # Call with use_service_linked_role=True + fg.enable_lake_formation(use_service_linked_role=True, show_s3_policy=True) + + # Verify the service-linked role ARN contains the correct account ID + expected_slr_arn = f"arn:aws:iam::{account_id}:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" + mock_generate_policy.assert_called_once() + call_kwargs = mock_generate_policy.call_args[1] + assert call_kwargs["lake_formation_role_arn"] == expected_slr_arn + assert account_id in call_kwargs["lake_formation_role_arn"] + + + +class TestRegistrationRoleArnUsedWhenServiceLinkedRoleFalse: + """Tests for verifying registration_role_arn is used when use_service_linked_role=False.""" + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch.object(FeatureGroup, "_generate_s3_deny_policy") + @patch("builtins.print") + def test_uses_registration_role_arn_when_use_service_linked_role_false( + self, + mock_print, + mock_generate_policy, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that when use_service_linked_role=False, the registration_role_arn is used + in the S3 deny policy instead of the auto-generated service-linked role ARN. 
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + # Set up Feature Group with required configuration + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} + + # Custom registration role ARN + custom_registration_role = "arn:aws:iam::123456789012:role/CustomLakeFormationRole" + + # Call with use_service_linked_role=False and registration_role_arn + fg.enable_lake_formation( + use_service_linked_role=False, + registration_role_arn=custom_registration_role, + show_s3_policy=True, + ) + + # Verify _generate_s3_deny_policy was called with the custom registration role ARN + mock_generate_policy.assert_called_once() + call_kwargs = mock_generate_policy.call_args[1] + assert call_kwargs["lake_formation_role_arn"] == custom_registration_role + + # Verify it's NOT the service-linked role ARN + service_linked_role_pattern = "aws-service-role/lakeformation.amazonaws.com" + assert service_linked_role_pattern not in call_kwargs["lake_formation_role_arn"] + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch.object(FeatureGroup, "_generate_s3_deny_policy") + @patch("builtins.print") + def test_registration_role_arn_passed_to_s3_registration( + self, + mock_print, + mock_generate_policy, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that when use_service_linked_role=False, the registration_role_arn is also + passed to _register_s3_with_lake_formation. 
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + # Set up Feature Group with required configuration + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} + + # Custom registration role ARN + custom_registration_role = "arn:aws:iam::123456789012:role/CustomLakeFormationRole" + + # Call with use_service_linked_role=False and registration_role_arn + fg.enable_lake_formation( + use_service_linked_role=False, + registration_role_arn=custom_registration_role, + show_s3_policy=True, + ) + + # Verify _register_s3_with_lake_formation was called with the correct parameters + mock_register.assert_called_once() + call_args = mock_register.call_args + assert call_args[1]["use_service_linked_role"] == False + assert call_args[1]["role_arn"] == custom_registration_role + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch.object(FeatureGroup, "_generate_s3_deny_policy") + @patch("builtins.print") + def test_different_registration_role_arns_produce_different_policies( + self, + mock_print, + mock_generate_policy, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that different registration_role_arn values result in different + lake_formation_role_arn values in the generated policy. 
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + # Set up Feature Group with required configuration + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} + + # First call with one registration role + first_role = "arn:aws:iam::123456789012:role/FirstLakeFormationRole" + fg.enable_lake_formation( + use_service_linked_role=False, + registration_role_arn=first_role, + show_s3_policy=True, + ) + + first_call_kwargs = mock_generate_policy.call_args[1] + first_lf_role = first_call_kwargs["lake_formation_role_arn"] + + # Reset mocks + mock_generate_policy.reset_mock() + mock_register.reset_mock() + mock_grant.reset_mock() + mock_revoke.reset_mock() + + # Second call with different registration role + second_role = "arn:aws:iam::123456789012:role/SecondLakeFormationRole" + fg.enable_lake_formation( + use_service_linked_role=False, + registration_role_arn=second_role, + show_s3_policy=True, + ) + + second_call_kwargs = mock_generate_policy.call_args[1] + second_lf_role = second_call_kwargs["lake_formation_role_arn"] + + # Verify different roles were used + assert first_lf_role == first_role + assert second_lf_role == second_role + assert first_lf_role != second_lf_role + + + +class TestPolicyPrintedWithClearInstructions: + """Tests for verifying the S3 deny policy is printed with clear instructions.""" + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch("builtins.print") + def test_policy_printed_with_header_and_instructions( + self, + mock_print, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that enable_lake_formation prints the S3 deny policy with clear + header and instructions for the user. 
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + # Set up Feature Group with required configuration + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + + # Call enable_lake_formation with show_s3_policy=True + fg.enable_lake_formation(show_s3_policy=True) + + # Collect all print calls + print_calls = [str(call) for call in mock_print.call_args_list] + all_printed_text = " ".join(print_calls) + + # Verify header is printed + assert "S3 Bucket Policy" in all_printed_text, "Header should mention 'S3 Bucket Policy'" + + # Verify instructions are printed + assert ( + "Lake Formation" in all_printed_text + or "deny policy" in all_printed_text + ), "Instructions should mention Lake Formation or deny policy" + + # Verify bucket name is printed + assert "test-bucket" in all_printed_text, "Bucket name should be printed" + + # Verify note about merging with existing policy is printed + assert ( + "Merge" in all_printed_text or "existing" in all_printed_text + ), "Note about merging with existing policy should be printed" + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch("builtins.print") + def test_policy_json_is_printed( + self, + mock_print, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that the S3 deny policy JSON is printed to the console when show_s3_policy=True. 
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + # Set up Feature Group with required configuration + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + + # Call enable_lake_formation with show_s3_policy=True + fg.enable_lake_formation(show_s3_policy=True) + + # Collect all print calls + print_calls = [str(call) for call in mock_print.call_args_list] + all_printed_text = " ".join(print_calls) + + # Verify policy JSON structure elements are printed + assert "Version" in all_printed_text, "Policy JSON should contain 'Version'" + assert "Statement" in all_printed_text, "Policy JSON should contain 'Statement'" + assert "Effect" in all_printed_text, "Policy JSON should contain 'Effect'" + assert "Deny" in all_printed_text, "Policy JSON should contain 'Deny' effect" + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch("builtins.print") + def test_policy_printed_only_after_successful_setup( + self, + mock_print, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that the S3 deny policy is only printed after all Lake Formation + phases complete successfully. 
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + # Set up Feature Group with required configuration + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock Phase 1 failure + mock_register.side_effect = Exception("Phase 1 failed") + mock_grant.return_value = True + mock_revoke.return_value = True + + # Call enable_lake_formation with show_s3_policy=True - should fail + with pytest.raises(RuntimeError): + fg.enable_lake_formation(show_s3_policy=True) + + # Collect all print calls + print_calls = [str(call) for call in mock_print.call_args_list] + all_printed_text = " ".join(print_calls) + + # Verify policy was NOT printed when setup failed + assert "S3 Bucket Policy" not in all_printed_text, "Policy should not be printed when setup fails" + + # Reset mocks + mock_print.reset_mock() + mock_register.reset_mock() + mock_register.side_effect = None + mock_register.return_value = True + + # Mock Phase 2 failure + mock_grant.side_effect = Exception("Phase 2 failed") + + # Call enable_lake_formation with show_s3_policy=True - should fail + with pytest.raises(RuntimeError): + fg.enable_lake_formation(show_s3_policy=True) + + # Collect all print calls + print_calls = [str(call) for call in mock_print.call_args_list] + all_printed_text = " ".join(print_calls) + + # Verify policy was NOT printed when setup fails at Phase 2 + assert "S3 Bucket Policy" not in all_printed_text, "Policy should not be printed when Phase 2 fails" + + # Reset mocks + mock_print.reset_mock() + mock_grant.reset_mock() + mock_grant.side_effect = None + mock_grant.return_value = True + + # Mock Phase 3 failure + mock_revoke.side_effect = Exception("Phase 3 failed") + + # Call enable_lake_formation with show_s3_policy=True - should fail + with pytest.raises(RuntimeError): + fg.enable_lake_formation(show_s3_policy=True) + + # Collect all print calls + print_calls = [str(call) for call in mock_print.call_args_list] + all_printed_text = " ".join(print_calls) + + # Verify policy was NOT printed when setup fails at Phase 3 + assert "S3 Bucket Policy" not in all_printed_text, "Policy should not be printed when Phase 3 fails" + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch("builtins.print") + def test_policy_includes_both_allowed_principals( + self, + mock_print, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that the printed policy includes both the Lake Formation role + and the Feature Store execution role as allowed principals. 
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + # Set up Feature Group with required configuration + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + feature_store_role = "arn:aws:iam::123456789012:role/FeatureStoreRole" + fg.role_arn = feature_store_role + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + + # Call enable_lake_formation with service-linked role and show_s3_policy=True + fg.enable_lake_formation(use_service_linked_role=True, show_s3_policy=True) + + # Collect all print calls + print_calls = [str(call) for call in mock_print.call_args_list] + all_printed_text = " ".join(print_calls) + + # Verify Feature Store role is in the printed output + assert feature_store_role in all_printed_text, "Feature Store role should be in printed policy" + + # Verify Lake Formation service-linked role pattern is in the printed output + assert "AWSServiceRoleForLakeFormationDataAccess" in all_printed_text, \ + "Lake Formation service-linked role should be in printed policy" + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch("builtins.print") + def test_policy_not_printed_when_show_s3_policy_false( + self, + mock_print, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that the S3 deny policy is NOT printed when show_s3_policy=False (default). 
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + # Set up Feature Group with required configuration + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + + # Call enable_lake_formation with show_s3_policy=False (default) + fg.enable_lake_formation(show_s3_policy=False) + + # Collect all print calls + print_calls = [str(call) for call in mock_print.call_args_list] + all_printed_text = " ".join(print_calls) + + # Verify policy was NOT printed + assert "S3 Bucket Policy" not in all_printed_text, "Policy should not be printed when show_s3_policy=False" + assert "Version" not in all_printed_text, "Policy JSON should not be printed when show_s3_policy=False" + + @patch.object(FeatureGroup, "refresh") + @patch.object(FeatureGroup, "_register_s3_with_lake_formation") + @patch.object(FeatureGroup, "_grant_lake_formation_permissions") + @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") + @patch("builtins.print") + def test_policy_not_printed_by_default( + self, + mock_print, + mock_revoke, + mock_grant, + mock_register, + mock_refresh, + ): + """ + Test that the S3 deny policy is NOT printed by default (when show_s3_policy is not specified). 
+ """ + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + # Set up Feature Group with required configuration + fg = FeatureGroup(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig( + s3_uri="s3://test-bucket/path", + resolved_output_s3_uri="s3://test-bucket/resolved-path/data", + ), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + ) + fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" + fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" + fg.feature_group_status = "Created" + + # Mock successful Lake Formation operations + mock_register.return_value = True + mock_grant.return_value = True + mock_revoke.return_value = True + + # Call enable_lake_formation without specifying show_s3_policy (should default to False) + fg.enable_lake_formation() + + # Collect all print calls + print_calls = [str(call) for call in mock_print.call_args_list] + all_printed_text = " ".join(print_calls) + + # Verify policy was NOT printed + assert "S3 Bucket Policy" not in all_printed_text, "Policy should not be printed by default" + assert "Version" not in all_printed_text, "Policy JSON should not be printed by default" From a984e4704520df3416b5028a4674d385e8f97f84 Mon Sep 17 00:00:00 2001 From: Basssem Halim Date: Wed, 28 Jan 2026 17:06:06 -0800 Subject: [PATCH 4/8] docs(feature_store): Add Lake Formation governance example notebook --- .../v3-feature-store-lake-formation.ipynb | 594 ++++++++++++++++++ 1 file changed, 594 insertions(+) create mode 100644 v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb diff --git a/v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb b/v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb new file mode 100644 index 0000000000..4f6efa03c4 --- /dev/null +++ b/v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb @@ -0,0 +1,594 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Feature Store with Lake Formation Governance\n", + "\n", + "This notebook demonstrates two workflows for using SageMaker Feature Store with Lake Formation governance:\n", + "\n", + "1. **Example 1**: Create Feature Group with Lake Formation enabled at creation time\n", + "2. **Example 2**: Create Feature Group first, then enable Lake Formation separately\n", + "\n", + "Both workflows include record ingestion to verify everything works end-to-end.\n", + "\n", + "## Prerequisites\n", + "\n", + "- AWS credentials configured with permissions for SageMaker, S3, Glue, and Lake Formation\n", + "- An S3 bucket for the offline store\n", + "- An IAM role with Feature Store permissions\n", + "\n", + "## Required IAM Permissions\n", + "\n", + "TODO: Add the required IAM policy document here\n", + "\n", + "```json\n", + "{\n", + " \"Version\": \"2012-10-17\",\n", + " \"Statement\": [\n", + " // TODO: Add required permissions\n", + " 1. FS permissions to create FG\n", + " 2. LF permissons\n", + " ]\n", + "}\n", + "```\n", + "\n", + "## Lake Formation Admin Requirements\n", + "\n", + "The person enabling Lake Formation governance must be a **Data Lake Administrator** in Lake Formation. 
There are two options depending on your organization's setup:\n", + "\n", + "### Option 1: Single User (Data Lake Admin + Feature Store Admin)\n", + "\n", + "If the caller has both:\n", + "- Data Lake Administrator privileges in Lake Formation\n", + "- Permissions to create Feature Groups in SageMaker\n", + "\n", + "Then they can use `FeatureGroup.create()` with `lake_formation_config` to enable governance at creation time (Example 1).\n", + "\n", + "### Option 2: Separate Roles (ML Engineer + Data Lake Admin)\n", + "\n", + "If the person creating the Feature Group is different from the Data Lake Administrator:\n", + "\n", + "1. **ML Engineer** creates the Feature Group without Lake Formation using `FeatureGroup.create()`\n", + "2. **Data Lake Admin** later enables governance by calling `enable_lake_formation()` on the existing Feature Group (Example 2)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "from datetime import datetime\n", + "from datetime import timezone\n", + "\n", + "import boto3\n", + "from botocore.exceptions import ClientError\n", + "\n", + "# Import the FeatureGroup with Lake Formation support\n", + "from sagemaker.mlops.feature_store.feature_group import FeatureGroup, LakeFormationConfig\n", + "from sagemaker.core.shapes import (\n", + " FeatureDefinition,\n", + " FeatureValue,\n", + " OfflineStoreConfig,\n", + " OnlineStoreConfig,\n", + " S3StorageConfig,\n", + ")\n", + "from sagemaker.core.helper.session_helper import Session as SageMakerSession, get_execution_role" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Use SageMaker session to get default bucket and execution role\n", + "sagemaker_session = SageMakerSession()\n", + "S3_BUCKET = sagemaker_session.default_bucket()\n", + "ROLE_ARN = get_execution_role(sagemaker_session)\n", + "REGION = sagemaker_session.boto_session.region_name\n", + "\n", + "print(f\"S3 Bucket: {S3_BUCKET}\")\n", + "print(f\"Role ARN: {ROLE_ARN}\")\n", + "print(f\"Region: {REGION}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Common Feature Definitions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "feature_definitions = [\n", + " FeatureDefinition(feature_name=\"customer_id\", feature_type=\"String\"),\n", + " FeatureDefinition(feature_name=\"event_time\", feature_type=\"String\"),\n", + " FeatureDefinition(feature_name=\"age\", feature_type=\"Integral\"),\n", + " FeatureDefinition(feature_name=\"total_purchases\", feature_type=\"Integral\"),\n", + " FeatureDefinition(feature_name=\"avg_order_value\", feature_type=\"Fractional\"),\n", + "]\n", + "\n", + "print(\"Feature Definitions:\")\n", + "for fd in feature_definitions:\n", + " print(f\" - {fd.feature_name}: {fd.feature_type}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Helper Function: Ingest Records" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def ingest_sample_records(feature_group, num_records=3):\n", + " \"\"\"\n", + " Ingest sample records into the Feature Group.\n", + " \n", + " Args:\n", + " feature_group: The FeatureGroup to ingest records 
into\n", + " num_records: Number of sample records to ingest\n", + " \"\"\"\n", + " print(f\"\\nIngesting {num_records} sample records...\")\n", + " \n", + " for i in range(num_records):\n", + " event_time = datetime.now(timezone.utc).isoformat()\n", + " record = [\n", + " FeatureValue(feature_name=\"customer_id\", value_as_string=f\"cust_{i+1}\"),\n", + " FeatureValue(feature_name=\"event_time\", value_as_string=event_time),\n", + " FeatureValue(feature_name=\"age\", value_as_string=str(25 + i * 5)),\n", + " FeatureValue(feature_name=\"total_purchases\", value_as_string=str(10 + i * 3)),\n", + " FeatureValue(feature_name=\"avg_order_value\", value_as_string=str(50.0 + i * 10.5)),\n", + " ]\n", + " \n", + " feature_group.put_record(record=record)\n", + " print(f\" Ingested record for customer: cust_{i+1}\")\n", + " \n", + " print(f\"Successfully ingested {num_records} records!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Helper Function: Cleanup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def cleanup_feature_group(fg):\n", + " \"\"\"\n", + " Delete a FeatureGroup and its associated Glue table.\n", + " \n", + " Args:\n", + " fg: The FeatureGroup to delete.\n", + " \"\"\"\n", + " try:\n", + " # Delete the Glue table if it exists\n", + " if fg.offline_store_config is not None:\n", + " try:\n", + " fg.refresh() # Ensure we have latest config\n", + " data_catalog_config = fg.offline_store_config.data_catalog_config\n", + " if data_catalog_config is not None:\n", + " database_name = data_catalog_config.database\n", + " table_name = data_catalog_config.table_name\n", + "\n", + " if database_name and table_name:\n", + " glue_client = boto3.client(\"glue\")\n", + " try:\n", + " glue_client.delete_table(DatabaseName=database_name, Name=table_name)\n", + " print(f\"Deleted Glue table: {database_name}.{table_name}\")\n", + " except ClientError as e:\n", + " # Ignore if table doesn't exist\n", + " if e.response[\"Error\"][\"Code\"] != \"EntityNotFoundException\":\n", + " raise\n", + " except Exception as e:\n", + " # Don't fail cleanup if Glue table deletion fails\n", + " print(f\"Warning: Could not delete Glue table: {e}\")\n", + "\n", + " # Delete the FeatureGroup\n", + " fg.delete()\n", + " print(f\"Deleted Feature Group: {fg.feature_group_name}\")\n", + " except ClientError as e:\n", + " print(f\"Error during cleanup: {e}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# Example 1: Create Feature Group with Lake Formation Enabled\n", + "\n", + "This example creates a Feature Group with Lake Formation governance enabled at creation time using `LakeFormationConfig`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate unique name for example 1\n", + "timestamp = datetime.now().strftime(\"%Y%m%d%H%M%S\")\n", + "FG_NAME_WORKFLOW1 = f\"lf-demo-workflow1-{timestamp}\"\n", + "\n", + "print(f\"Example 1 Feature Group: {FG_NAME_WORKFLOW1}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configure online and offline stores\n", + "online_store_config = OnlineStoreConfig(enable_online_store=True)\n", + "\n", + "offline_store_config_1 = OfflineStoreConfig(\n", + " s3_storage_config=S3StorageConfig(\n", + " s3_uri=f\"s3://{S3_BUCKET}/feature-store-demo/\"\n", + " )\n", + ")\n", + "\n", + "# Configure Lake Formation - enabled at creation\n", + "lake_formation_config = LakeFormationConfig()\n", + "lake_formation_config.enabled = True\n", + "lake_formation_config.use_service_linked_role = True\n", + "lake_formation_config.show_s3_policy = True\n", + "\n", + "print(\"Store Config:\")\n", + "print(f\" Online Store: enabled\")\n", + "print(f\" Offline Store S3: s3://{S3_BUCKET}/feature-store-demo/\")\n", + "print(\"\\nLake Formation Config:\")\n", + "print(f\" enabled: {lake_formation_config.enabled}\")\n", + "print(f\" use_service_linked_role: {lake_formation_config.use_service_linked_role}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create Feature Group with Lake Formation enabled\n", + "print(\"Creating Feature Group with Lake Formation enabled...\")\n", + "print(\"This will:\")\n", + "print(\" 1. Create the Feature Group with online + offline stores\")\n", + "print(\" 2. Wait for 'Created' status\")\n", + "print(\" 3. Register S3 with Lake Formation\")\n", + "print(\" 4. Grant permissions to execution role\")\n", + "print(\" 5. 
Revoke IAMAllowedPrincipal permissions\")\n", + "print()\n", + "\n", + "fg_workflow1 = FeatureGroup.create(\n", + " feature_group_name=FG_NAME_WORKFLOW1,\n", + " record_identifier_feature_name=\"customer_id\",\n", + " event_time_feature_name=\"event_time\",\n", + " feature_definitions=feature_definitions,\n", + " online_store_config=online_store_config,\n", + " offline_store_config=offline_store_config_1,\n", + " role_arn=ROLE_ARN,\n", + " description=\"Workflow 1: Lake Formation enabled at creation\",\n", + " lake_formation_config=lake_formation_config, # new field\n", + " region=REGION,\n", + ")\n", + "\n", + "print(f\"\\nFeature Group created: {fg_workflow1.feature_group_name}\")\n", + "print(f\"Status: {fg_workflow1.feature_group_status}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Verify Feature Group status\n", + "fg_workflow1.refresh()\n", + "print(f\"Feature Group: {fg_workflow1.feature_group_name}\")\n", + "print(f\"Status: {fg_workflow1.feature_group_status}\")\n", + "print(f\"ARN: {fg_workflow1.feature_group_arn}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Ingest sample records to verify everything works\n", + "ingest_sample_records(fg_workflow1, num_records=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Retrieve a sample record from the online store\n", + "print(\"Retrieving record for customer 'cust_1' from online store...\")\n", + "response = fg_workflow1.get_record(record_identifier_value_as_string=\"cust_1\")\n", + "\n", + "print(f\"\\nRecord retrieved successfully!\")\n", + "print(f\"Features:\")\n", + "for feature in response.record:\n", + " print(f\" {feature.feature_name}: {feature.value_as_string}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# Example 2: Create Feature Group, Then Enable Lake Formation\n", + "\n", + "This example creates a Feature Group first without Lake Formation, then enables it separately using `enable_lake_formation()`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate unique name for example 2\n", + "timestamp = datetime.now().strftime(\"%Y%m%d%H%M%S\")\n", + "FG_NAME_WORKFLOW2 = f\"lf-demo-workflow2-{timestamp}\"\n", + "\n", + "print(f\"Example 2 Feature Group: {FG_NAME_WORKFLOW2}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Configure online and offline stores\n", + "online_store_config_2 = OnlineStoreConfig(enable_online_store=True)\n", + "\n", + "offline_store_config_2 = OfflineStoreConfig(\n", + " s3_storage_config=S3StorageConfig(\n", + " s3_uri=f\"s3://{S3_BUCKET}/feature-store-demo/\"\n", + " ),\n", + " table_format=\"Iceberg\"\n", + ")\n", + "\n", + "# Step 1: Create Feature Group WITHOUT Lake Formation\n", + "print(\"Step 1: Creating Feature Group without Lake Formation...\")\n", + "\n", + "fg_workflow2 = FeatureGroup.create(\n", + " feature_group_name=FG_NAME_WORKFLOW2,\n", + " record_identifier_feature_name=\"customer_id\",\n", + " event_time_feature_name=\"event_time\",\n", + " feature_definitions=feature_definitions,\n", + " online_store_config=online_store_config_2,\n", + " offline_store_config=offline_store_config_2,\n", + " role_arn=ROLE_ARN,\n", + " description=\"Workflow 2: Lake Formation enabled after creation\",\n", + " region=REGION,\n", + ")\n", + "\n", + "print(f\"Feature Group created: {fg_workflow2.feature_group_name}\")\n", + "print(f\"Status: {fg_workflow2.feature_group_status}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Step 2: Wait for Feature Group to be ready\n", + "print(\"Step 2: Waiting for Feature Group to reach 'Created' status...\")\n", + "fg_workflow2.wait_for_status(target_status=\"Created\", poll=10, timeout=300)\n", + "print(f\"Status: {fg_workflow2.feature_group_status}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Step 3: Enable Lake Formation governance\n", + "print(\"Step 3: Enabling Lake Formation governance...\")\n", + "print(\"This will:\")\n", + "print(\" 1. Register S3 with Lake Formation\")\n", + "print(\" 2. Grant permissions to execution role\")\n", + "print(\" 3. 
Revoke IAMAllowedPrincipal permissions\")\n", + "print()\n", + "\n", + "result = fg_workflow2.enable_lake_formation( # new method\n", + " use_service_linked_role=True\n", + ")\n", + "\n", + "print(f\"\\nLake Formation setup results:\")\n", + "print(f\" s3_registration: {result['s3_registration']}\")\n", + "print(f\" permissions_granted: {result['permissions_granted']}\")\n", + "print(f\" iam_principal_revoked: {result['iam_principal_revoked']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Step 4: Ingest sample records to verify everything works\n", + "print(\"Step 4: Ingesting records to verify Lake Formation setup...\")\n", + "ingest_sample_records(fg_workflow2, num_records=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Step 5: Retrieve a sample record from the online store\n", + "print(\"Step 5: Retrieving record for customer 'cust_1' from online store...\")\n", + "response = fg_workflow2.get_record(record_identifier_value_as_string=\"cust_1\")\n", + "\n", + "print(f\"\\nRecord retrieved successfully!\")\n", + "print(f\"Features:\")\n", + "for feature in response.record:\n", + " print(f\" {feature.feature_name}: {feature.value_as_string}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Verify Feature Group status\n", + "fg_workflow2.refresh()\n", + "print(f\"Feature Group: {fg_workflow2.feature_group_name}\")\n", + "print(f\"Status: {fg_workflow2.feature_group_status}\")\n", + "print(f\"ARN: {fg_workflow2.feature_group_arn}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# Cleanup\n", + "\n", + "Delete the Feature Groups and associated Glue tables created in this demo." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment to delete the Feature Groups\n", + "cleanup_feature_group(fg_workflow1)\n", + "# cleanup_feature_group(fg_workflow2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# Summary\n", + "\n", + "This notebook demonstrated two workflows:\n", + "\n", + "**Example 1: Lake Formation at Creation**\n", + "- Use `LakeFormationConfig` with `enabled=True` in `FeatureGroup.create()`\n", + "- Lake Formation is automatically configured after Feature Group creation\n", + "- Both online and offline stores enabled\n", + "\n", + "**Example 2: Enable Lake Formation Later**\n", + "- Create Feature Group normally without Lake Formation\n", + "- Call `enable_lake_formation()` method after creation\n", + "- More control over when Lake Formation is enabled\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "# FAQ:\n", + "\n", + "## What is the S3 deny policy for?\n", + "\n", + "When you enable Lake Formation governance, you control access to data through Lake Formation permissions. 
However, **IAM roles that already have direct S3 access will continue to have access** to the underlying data files, bypassing Lake Formation entirely.\n", + "\n", + "The S3 deny policy closes this access path by explicitly denying S3 access to all principals except:\n", + "- The Lake Formation service-linked role (for data access)\n", + "- The Feature Store offline store role provided during Feature Group creation\n", + "\n", + "## Why don't we apply the S3 deny policy automatically?\n", + "\n", + "We provide the policy as a **recommendation** rather than applying it automatically for several important reasons:\n", + "\n", + "### 1. Protect existing SageMaker workflows from breaking\n", + "\n", + "Many customers already have SageMaker training and processing jobs wired directly to S3 URIs. An automatic S3 deny could cause those jobs to fail the moment governance is enabled on a table.\n", + "\n", + "### 2. Support different personas and trust levels\n", + "\n", + "Different users have different access needs:\n", + "- **Analysts / BI users** - should only see data through governed surfaces (Lake Formation tables, Athena, Redshift, etc.)\n", + "- **ML / Data engineers** - often need raw S3 access for training, feature engineering, and debugging\n", + "\n", + "### 3. Enable gradual migration to stronger governance\n", + "\n", + "Many customers want to phase in Lake Formation governance:\n", + "1. Start by governing table access only\n", + "2. Later tighten S3 access once they've refactored jobs and validated behavior\n", + "\n", + "### 4. Avoid breaking existing bucket policies\n", + "\n", + "Automatically modifying bucket policies could:\n", + "- Conflict with existing policy statements\n", + "- Lock out users or services unexpectedly\n", + "- Cause cascading failures across multiple applications sharing the bucket\n", + "\n", + "Therefore, the S3 policy is provided as a starting point that should be validated by the user. \n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 6af324019b376d456b15de8a85ca91c67d2d6248 Mon Sep 17 00:00:00 2001 From: BassemHalim Date: Thu, 29 Jan 2026 14:45:37 -0800 Subject: [PATCH 5/8] add role policy to notebook --- .../v3-feature-store-lake-formation.ipynb | 97 +++++++++++++++++-- 1 file changed, 89 insertions(+), 8 deletions(-) diff --git a/v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb b/v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb index 4f6efa03c4..f8c5a8ced0 100644 --- a/v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb +++ b/v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb @@ -21,15 +21,89 @@ "\n", "## Required IAM Permissions\n", "\n", - "TODO: Add the required IAM policy document here\n", + "This notebook uses two separate IAM roles:\n", + "1. **Execution Role**: The SageMaker execution role running this notebook\n", + "2. 
**Offline Store Role**: A dedicated role for Feature Store S3 access\n", + "\n", + "### Execution Role Policy\n", + "\n", + "The execution role needs permissions to manage Feature Groups and configure Lake Formation:\n", "\n", "```json\n", "{\n", " \"Version\": \"2012-10-17\",\n", " \"Statement\": [\n", - " // TODO: Add required permissions\n", - " 1. FS permissions to create FG\n", - " 2. LF permissons\n", + " {\n", + " \"Sid\": \"FeatureGroupManagement\",\n", + " \"Effect\": \"Allow\",\n", + " \"Action\": [\n", + " \"sagemaker:*\"\n", + " ],\n", + " \"Resource\": \"arn:aws:sagemaker:*:*:feature-group/*\"\n", + " },\n", + " {\n", + " \"Sid\": \"LakeFormation\",\n", + " \"Effect\": \"Allow\",\n", + " \"Action\": [\n", + " \"lakeformation:RegisterResource\",\n", + " \"lakeformation:DeregisterResource\",\n", + " \"lakeformation:GrantPermissions\",\n", + " \"lakeformation:RevokePermissions\"\n", + " ],\n", + " \"Resource\": \"*\"\n", + " },\n", + " {\n", + " \"Sid\": \"GlueCatalogRead\",\n", + " \"Effect\": \"Allow\",\n", + " \"Action\": [\n", + " \"glue:GetTable\",\n", + " \"glue:GetDatabase\",\n", + " \"glue:DeleteTable\"\n", + " ],\n", + " \"Resource\": [\n", + " \"arn:aws:glue:*:*:catalog\",\n", + " \"arn:aws:glue:*:*:database/sagemaker_featurestore\",\n", + " \"arn:aws:glue:*:*:table/sagemaker_featurestore/*\"\n", + " ]\n", + " },\n", + " {\n", + " \"Sid\": \"PassOfflineStoreRole\",\n", + " \"Effect\": \"Allow\",\n", + " \"Action\": \"iam:PassRole\",\n", + " \"Resource\": \"arn:aws:iam::*:role/SagemakerFeatureStoreOfflineRole\"\n", + " },\n", + " {\n", + " \"Sid\": \"LakeFormationServiceLinkedRole\",\n", + " \"Effect\": \"Allow\",\n", + " \"Action\": [\n", + " \"iam:GetRole\",\n", + " \"iam:PutRolePolicy\",\n", + " \"iam:GetRolePolicy\"\n", + " ],\n", + " \"Resource\": \"arn:aws:iam::*:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess\"\n", + " },\n", + " {\n", + " \"Sid\": \"S3SagemakerDefaultBucket\",\n", + " \"Effect\": \"Allow\",\n", + " \"Action\": [\n", + " \"s3:CreateBucket\",\n", + " \"s3:GetBucketAcl\",\n", + " \"s3:ListBucket\"\n", + " ],\n", + " \"Resource\": [\n", + " \"arn:aws:s3:::sagemaker-*\"\n", + " ]\n", + " },\n", + " {\n", + " \"Sid\": \"CreateGlueTable\",\n", + " \"Effect\": \"Allow\",\n", + " \"Action\": [\n", + " \"glue:CreateTable\"\n", + " ],\n", + " \"Resource\": [\n", + " \"*\"\n", + " ]\n", + " }\n", " ]\n", "}\n", "```\n", @@ -102,11 +176,18 @@ "# Use SageMaker session to get default bucket and execution role\n", "sagemaker_session = SageMakerSession()\n", "S3_BUCKET = sagemaker_session.default_bucket()\n", - "ROLE_ARN = get_execution_role(sagemaker_session)\n", "REGION = sagemaker_session.boto_session.region_name\n", "\n", + "# Execution role (for running this notebook)\n", + "EXECUTION_ROLE_ARN = get_execution_role(sagemaker_session)\n", + "\n", + "# Offline store role (dedicated role for Feature Store S3 access)\n", + "# Replace with your dedicated offline store role ARN\n", + "OFFLINE_STORE_ROLE_ARN = \"arn:aws:iam:::role/\"\n", + "\n", "print(f\"S3 Bucket: {S3_BUCKET}\")\n", - "print(f\"Role ARN: {ROLE_ARN}\")\n", + "print(f\"Execution Role ARN: {EXECUTION_ROLE_ARN}\")\n", + "print(f\"Offline Store Role ARN: {OFFLINE_STORE_ROLE_ARN}\")\n", "print(f\"Region: {REGION}\")" ] }, @@ -300,7 +381,7 @@ " feature_definitions=feature_definitions,\n", " online_store_config=online_store_config,\n", " offline_store_config=offline_store_config_1,\n", - " role_arn=ROLE_ARN,\n", + " role_arn=OFFLINE_STORE_ROLE_ARN,\n", " 
description=\"Workflow 1: Lake Formation enabled at creation\",\n", " lake_formation_config=lake_formation_config, # new field\n", " region=REGION,\n", @@ -398,7 +479,7 @@ " feature_definitions=feature_definitions,\n", " online_store_config=online_store_config_2,\n", " offline_store_config=offline_store_config_2,\n", - " role_arn=ROLE_ARN,\n", + " role_arn=OFFLINE_STORE_ROLE_ARN,\n", " description=\"Workflow 2: Lake Formation enabled after creation\",\n", " region=REGION,\n", ")\n", From 91711eb97dba9928f568f8e20c7fefb005067f09 Mon Sep 17 00:00:00 2001 From: BassemHalim Date: Mon, 9 Feb 2026 12:43:08 -0800 Subject: [PATCH 6/8] feat(feature_store): Add Feature Processor implementation - Most code was ported from V2 and the imports were updated - Some usage of session.describe_feature_group was converted to FeatureGroup.get() from sagemaker.core which calls describe_feature_Group under the hood - The feature_scheduler to_pipeline was refactored to replace Estimator from V2 to use ModelTrainer from V3. (Estimator was removed from V3) --- .../feature_processor/__init__.py | 45 + .../feature_processor/_config_uploader.py | 209 ++++ .../feature_processor/_constants.py | 54 + .../feature_processor/_data_source.py | 154 +++ .../feature_store/feature_processor/_enums.py | 33 + .../feature_store/feature_processor/_env.py | 78 ++ .../_event_bridge_rule_helper.py | 305 +++++ .../_event_bridge_scheduler_helper.py | 118 ++ .../feature_processor/_exceptions.py | 18 + .../feature_processor/_factory.py | 167 +++ .../_feature_processor_config.py | 72 ++ .../_feature_processor_pipeline_events.py | 29 + .../feature_processor/_input_loader.py | 366 ++++++ .../feature_processor/_input_offset_parser.py | 129 ++ .../feature_processor/_params_loader.py | 83 ++ .../feature_processor/_spark_factory.py | 202 +++ .../feature_processor/_udf_arg_provider.py | 239 ++++ .../feature_processor/_udf_output_receiver.py | 98 ++ .../feature_processor/_udf_wrapper.py | 88 ++ .../feature_processor/_validation.py | 210 ++++ .../feature_processor/feature_processor.py | 129 ++ .../feature_processor/feature_scheduler.py | 1105 +++++++++++++++++ .../feature_processor/lineage/__init__.py | 0 .../lineage/_feature_group_contexts.py | 31 + .../_feature_group_lineage_entity_handler.py | 184 +++ .../lineage/_feature_processor_lineage.py | 759 +++++++++++ .../_feature_processor_lineage_name_helper.py | 101 ++ .../lineage/_lineage_association_handler.py | 300 +++++ .../_pipeline_lineage_entity_handler.py | 105 ++ .../lineage/_pipeline_schedule.py | 44 + .../lineage/_pipeline_trigger.py | 36 + ...pipeline_version_lineage_entity_handler.py | 92 ++ .../lineage/_s3_lineage_entity_handler.py | 316 +++++ .../lineage/_transformation_code.py | 31 + .../feature_processor/lineage/constants.py | 43 + 35 files changed, 5973 insertions(+) create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/__init__.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_config_uploader.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_constants.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_data_source.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_enums.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_env.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_rule_helper.py create mode 
100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_scheduler_helper.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_exceptions.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_factory.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_config.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_pipeline_events.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_offset_parser.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_params_loader.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_output_receiver.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_wrapper.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_validation.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_processor.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/__init__.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_contexts.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage_name_helper.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_lineage_association_handler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_lineage_entity_handler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_schedule.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_trigger.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_version_lineage_entity_handler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_transformation_code.py create mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/constants.py diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/__init__.py new file mode 100644 index 0000000000..1051096d0e --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/__init__.py @@ 
-0,0 +1,45 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Exported classes for the sagemaker.mlops.feature_store.feature_processor module.""" +from __future__ import absolute_import + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( # noqa: F401 + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, + PySparkDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._exceptions import ( # noqa: F401 + IngestionError, +) +from sagemaker.mlops.feature_store.feature_processor.feature_processor import ( # noqa: F401 + feature_processor, +) +from sagemaker.mlops.feature_store.feature_processor.feature_scheduler import ( # noqa: F401 + to_pipeline, + schedule, + describe, + put_trigger, + delete_trigger, + enable_trigger, + disable_trigger, + delete_schedule, + list_pipelines, + execute, + TransformationCode, + FeatureProcessorPipelineEvents, +) +from sagemaker.mlops.feature_store.feature_processor._enums import ( # noqa: F401 + FeatureProcessorPipelineExecutionStatus, +) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_config_uploader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_config_uploader.py new file mode 100644 index 0000000000..d181218fb5 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_config_uploader.py @@ -0,0 +1,209 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
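Before the internal helper modules, a minimal sketch of how the public API re-exported by the package __init__ above is intended to be used. The bucket, feature group ARN, pipeline name, and schedule expression are illustrative assumptions, and the to_pipeline/schedule signatures follow the V2 API this patch ports, so details may differ:

    from sagemaker.mlops.feature_store.feature_processor import (
        CSVDataSource,
        feature_processor,
        to_pipeline,
        schedule,
    )

    @feature_processor(
        inputs=[CSVDataSource(s3_uri="s3://example-bucket/raw/")],  # assumed bucket/prefix
        output="arn:aws:sagemaker:us-west-2:123456789012:feature-group/example-fg",  # assumed ARN
    )
    def transform(raw_df):
        # PySpark mode: the function receives and returns Spark DataFrames.
        return raw_df.dropna()

    # Promote the decorated function to a SageMaker Pipeline, then schedule it.
    # Required role arguments are omitted here for brevity.
    pipeline_arn = to_pipeline(pipeline_name="example-fp-pipeline", step=transform)
    schedule(pipeline_name="example-fp-pipeline", schedule_expression="rate(1 day)")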
+"""Contains classes for preparing and uploading configs for a scheduled feature processor.""" +from __future__ import absolute_import +from typing import Callable, Dict, Optional, Tuple, List, Union + +import attr + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._constants import ( + SPARK_JAR_FILES_PATH, + SPARK_PY_FILES_PATH, + SPARK_FILES_PATH, + S3_DATA_DISTRIBUTION_TYPE, +) +from sagemaker.core.inputs import TrainingInput +from sagemaker.core.shapes import Channel, DataSource, S3DataSource +from sagemaker.core.remote_function.core.stored_function import StoredFunction +from sagemaker.core.remote_function.job import ( + _prepare_and_upload_workspace, + _prepare_and_upload_runtime_scripts, + _JobSettings, + RUNTIME_SCRIPTS_CHANNEL_NAME, + REMOTE_FUNCTION_WORKSPACE, + SPARK_CONF_CHANNEL_NAME, + _prepare_and_upload_spark_dependent_files, +) +from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentManager, +) +from sagemaker.core.remote_function.spark_config import SparkConfig +from sagemaker.core.remote_function.custom_file_filter import CustomFileFilter +from sagemaker.core.s3 import s3_path_join + + +@attr.s +class ConfigUploader: + """Prepares and uploads customer provided configs to S3""" + + remote_decorator_config: _JobSettings = attr.ib() + runtime_env_manager: RuntimeEnvironmentManager = attr.ib() + + def prepare_step_input_channel_for_spark_mode( + self, func: Callable, s3_base_uri: str, sagemaker_session: Session + ) -> Tuple[List[Channel], Dict]: + """Prepares input channels for SageMaker Pipeline Step. + + Returns: + Tuple of (List[Channel], spark_dependency_paths dict) + """ + self._prepare_and_upload_callable(func, s3_base_uri, sagemaker_session) + bootstrap_scripts_s3uri = self._prepare_and_upload_runtime_scripts( + self.remote_decorator_config.spark_config, + s3_base_uri, + self.remote_decorator_config.s3_kms_key, + sagemaker_session, + ) + dependencies_list_path = self.runtime_env_manager.snapshot( + self.remote_decorator_config.dependencies + ) + user_workspace_s3uri = self._prepare_and_upload_workspace( + dependencies_list_path, + self.remote_decorator_config.include_local_workdir, + self.remote_decorator_config.pre_execution_commands, + self.remote_decorator_config.pre_execution_script, + s3_base_uri, + self.remote_decorator_config.s3_kms_key, + sagemaker_session, + self.remote_decorator_config.custom_file_filter, + ) + + ( + submit_jars_s3_paths, + submit_py_files_s3_paths, + submit_files_s3_path, + config_file_s3_uri, + ) = self._prepare_and_upload_spark_dependent_files( + self.remote_decorator_config.spark_config, + s3_base_uri, + self.remote_decorator_config.s3_kms_key, + sagemaker_session, + ) + + channels = [ + Channel( + channel_name=RUNTIME_SCRIPTS_CHANNEL_NAME, + data_source=DataSource( + s3_data_source=S3DataSource( + s3_uri=bootstrap_scripts_s3uri, + s3_data_type="S3Prefix", + s3_data_distribution_type=S3_DATA_DISTRIBUTION_TYPE, + ) + ), + input_mode="File", + ) + ] + + if user_workspace_s3uri: + channels.append( + Channel( + channel_name=REMOTE_FUNCTION_WORKSPACE, + data_source=DataSource( + s3_data_source=S3DataSource( + s3_uri=s3_path_join(s3_base_uri, REMOTE_FUNCTION_WORKSPACE), + s3_data_type="S3Prefix", + s3_data_distribution_type=S3_DATA_DISTRIBUTION_TYPE, + ) + ), + input_mode="File", + ) + ) + + if config_file_s3_uri: + channels.append( + Channel( + channel_name=SPARK_CONF_CHANNEL_NAME, + data_source=DataSource( + 
s3_data_source=S3DataSource( + s3_uri=config_file_s3_uri, + s3_data_type="S3Prefix", + s3_data_distribution_type=S3_DATA_DISTRIBUTION_TYPE, + ) + ), + input_mode="File", + ) + ) + + return channels, { + SPARK_JAR_FILES_PATH: submit_jars_s3_paths, + SPARK_PY_FILES_PATH: submit_py_files_s3_paths, + SPARK_FILES_PATH: submit_files_s3_path, + } + + def _prepare_and_upload_callable( + self, func: Callable, s3_base_uri: str, sagemaker_session: Session + ) -> None: + """Prepares and uploads callable to S3""" + stored_function = StoredFunction( + sagemaker_session=sagemaker_session, + s3_base_uri=s3_base_uri, + s3_kms_key=self.remote_decorator_config.s3_kms_key, + ) + stored_function.save(func) + + def _prepare_and_upload_workspace( + self, + local_dependencies_path: str, + include_local_workdir: bool, + pre_execution_commands: List[str], + pre_execution_script_local_path: str, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None, + ) -> str: + """Upload the training step dependencies to S3 if present""" + return _prepare_and_upload_workspace( + local_dependencies_path=local_dependencies_path, + include_local_workdir=include_local_workdir, + pre_execution_commands=pre_execution_commands, + pre_execution_script_local_path=pre_execution_script_local_path, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + custom_file_filter=custom_file_filter, + ) + + def _prepare_and_upload_runtime_scripts( + self, + spark_config: SparkConfig, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + ) -> str: + """Copy runtime scripts to a folder and upload to S3""" + return _prepare_and_upload_runtime_scripts( + spark_config=spark_config, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + ) + + def _prepare_and_upload_spark_dependent_files( + self, + spark_config: SparkConfig, + s3_base_uri: str, + s3_kms_key: str, + sagemaker_session: Session, + ) -> Tuple: + """Upload the spark dependencies to S3 if present""" + if not spark_config: + return None, None, None, None + + return _prepare_and_upload_spark_dependent_files( + spark_config=spark_config, + s3_base_uri=s3_base_uri, + s3_kms_key=s3_kms_key, + sagemaker_session=sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_constants.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_constants.py new file mode 100644 index 0000000000..e010446904 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_constants.py @@ -0,0 +1,54 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
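A minimal sketch of how the ConfigUploader above could be driven. In practice it is wired up inside feature_scheduler.to_pipeline; the _JobSettings instance, the decorated callable, and the S3 URI below are assumed placeholders:

    from sagemaker.core.helper.session_helper import Session
    from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import (
        RuntimeEnvironmentManager,
    )

    uploader = ConfigUploader(
        remote_decorator_config=job_settings,        # a _JobSettings built from the decorator config
        runtime_env_manager=RuntimeEnvironmentManager(),
    )
    channels, spark_dependency_paths = uploader.prepare_step_input_channel_for_spark_mode(
        func=transform,                              # the @feature_processor-wrapped callable
        s3_base_uri="s3://example-bucket/fp-workspace",
        sagemaker_session=Session(),
    )
    # `channels` feeds the pipeline's training step; `spark_dependency_paths` maps
    # SPARK_JAR_FILES_PATH / SPARK_PY_FILES_PATH / SPARK_FILES_PATH to the uploaded S3 locations.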
+"""Module containing constants for feature_processor and feature_scheduler module.""" +from __future__ import absolute_import + +from sagemaker.core.workflow.parameters import Parameter, ParameterTypeEnum + +DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge" +DEFAULT_SCHEDULE_STATE = "ENABLED" +DEFAULT_TRIGGER_STATE = "ENABLED" +UNDERSCORE = "_" +RESOURCE_NOT_FOUND_EXCEPTION = "ResourceNotFoundException" +RESOURCE_NOT_FOUND = "ResourceNotFound" +EXECUTION_TIME_PIPELINE_PARAMETER = "scheduled_time" +VALIDATION_EXCEPTION = "ValidationException" +EVENT_BRIDGE_INVOCATION_TIME = "" +SCHEDULED_TIME_PIPELINE_PARAMETER = Parameter( + name=EXECUTION_TIME_PIPELINE_PARAMETER, parameter_type=ParameterTypeEnum.STRING +) +EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT = "%Y-%m-%dT%H:%M:%SZ" # 2023-01-01T07:00:00Z +NO_FLEXIBLE_TIME_WINDOW = dict(Mode="OFF") +PIPELINE_NAME_MAXIMUM_LENGTH = 80 +PIPELINE_CONTEXT_TYPE = "FeatureEngineeringPipeline" +SPARK_JAR_FILES_PATH = "submit_jars_s3_paths" +SPARK_PY_FILES_PATH = "submit_py_files_s3_paths" +SPARK_FILES_PATH = "submit_files_s3_path" +FEATURE_PROCESSOR_TAG_KEY = "sm-fs-fe:created-from" +FEATURE_PROCESSOR_TAG_VALUE = "fp-to-pipeline" +FEATURE_GROUP_ARN_REGEX_PATTERN = r"arn:(.*?):sagemaker:(.*?):(.*?):feature-group/(.*?)$" +PIPELINE_ARN_REGEX_PATTERN = r"arn:(.*?):sagemaker:(.*?):(.*?):pipeline/(.*?)$" +EVENTBRIDGE_RULE_ARN_REGEX_PATTERN = r"arn:(.*?):events:(.*?):(.*?):rule/(.*?)$" +SAGEMAKER_WHL_FILE_S3_PATH = "s3://ada-private-beta/sagemaker-2.151.1.dev0-py2.py3-none-any.whl" +S3_DATA_DISTRIBUTION_TYPE = "FullyReplicated" +PIPELINE_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-context-name" +PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-version-context-name" +TO_PIPELINE_RESERVED_TAG_KEYS = [ + FEATURE_PROCESSOR_TAG_KEY, + PIPELINE_CONTEXT_NAME_TAG_KEY, + PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY, +] +BASE_EVENT_PATTERN = { + "source": ["aws.sagemaker"], + "detail": {"currentPipelineExecutionStatus": [], "pipelineArn": []}, +} diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_data_source.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_data_source.py new file mode 100644 index 0000000000..a6c452267c --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_data_source.py @@ -0,0 +1,154 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes to define input data sources.""" +from __future__ import absolute_import + +from typing import Optional, Dict, Union, TypeVar, Generic +from abc import ABC, abstractmethod +from pyspark.sql import DataFrame, SparkSession + + +import attr + +T = TypeVar("T") + + +@attr.s +class BaseDataSource(Generic[T], ABC): + """Abstract base class for feature processor data sources. + + Provides a skeleton for customization requiring the overriding of the method to read data from + data source and return the specified type. 
+ """ + + @abstractmethod + def read_data(self, *args, **kwargs) -> T: + """Read data from data source and return the specified type. + + Args: + args: Arguments for reading the data. + kwargs: Keyword argument for reading the data. + Returns: + T: The specified abstraction of data source. + """ + + @property + @abstractmethod + def data_source_unique_id(self) -> str: + """The identifier for the customized feature processor data source. + + Returns: + str: The data source unique id. + """ + + @property + @abstractmethod + def data_source_name(self) -> str: + """The name for the customized feature processor data source. + + Returns: + str: The data source name. + """ + + +@attr.s +class PySparkDataSource(BaseDataSource[DataFrame], ABC): + """Abstract base class for feature processor data sources. + + Provides a skeleton for customization requiring the overriding of the method to read data from + data source and return the Spark DataFrame. + """ + + @abstractmethod + def read_data( + self, spark: SparkSession, params: Optional[Dict[str, Union[str, Dict]]] = None + ) -> DataFrame: + """Read data from data source and convert the data to Spark DataFrame. + + Args: + spark (SparkSession): The Spark session to read the data. + params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the + feature_processor decorator. + Returns: + DataFrame: The Spark DataFrame as an abstraction on the data source. + """ + + +@attr.s +class FeatureGroupDataSource: + """A Feature Group data source definition for a FeatureProcessor. + + Attributes: + name (str): The name or ARN of the Feature Group. + input_start_offset (Optional[str], optional): A duration specified as a string in the + format ' ' where 'no' is a number and 'unit' is a unit of time in ['hours', + 'days', 'weeks', 'months', 'years'] (plural and singular forms). Inputs contain data + with event times no earlier than input_start_offset in the past. Offsets are relative + to the function execution time. If the function is executed by a Schedule, then the + offset is relative to the scheduled start time. Defaults to None. + input_end_offset (Optional[str], optional): The 'end' (as opposed to start) counterpart for + the 'input_start_offset'. Inputs will contain records with event times no later than + 'input_end_offset' in the past. Defaults to None. + """ + + name: str = attr.ib() + input_start_offset: Optional[str] = attr.ib(default=None) + input_end_offset: Optional[str] = attr.ib(default=None) + + +@attr.s +class CSVDataSource: + """An CSV data source definition for a FeatureProcessor. + + Attributes: + s3_uri (str): S3 URI of the data source. + csv_header (bool): Whether to read the first line of the CSV file as column names. This + option is only valid when file_format is set to csv. By default the value of this + option is true, and all column types are assumed to be a string. + infer_schema (bool): Whether to infer the schema of the CSV data source. This option is only + valid when file_format is set to csv. If set to true, two passes of the data is required + to load and infer the schema. + """ + + s3_uri: str = attr.ib() + csv_header: bool = attr.ib(default=True) + csv_infer_schema: bool = attr.ib(default=False) + + +@attr.s +class ParquetDataSource: + """An parquet data source definition for a FeatureProcessor. + + Attributes: + s3_uri (str): S3 URI of the data source. 
+ """ + + s3_uri: str = attr.ib() + + +@attr.s +class IcebergTableDataSource: + """An iceberg table data source definition for FeatureProcessor + + Attributes: + warehouse_s3_uri (str): S3 URI of data warehouse. The value is usually + the URI where data is stored. + catalog (str): Name of the catalog. + database (str): Name of the database. + table (str): Name of the table. + """ + + warehouse_s3_uri: str = attr.ib() + catalog: str = attr.ib() + database: str = attr.ib() + table: str = attr.ib() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_enums.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_enums.py new file mode 100644 index 0000000000..b63ed3a65a --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_enums.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Module containing Enums for the feature_processor module.""" +from __future__ import absolute_import + +from enum import Enum + + +class FeatureProcessorMode(Enum): + """Enum of feature_processor modes.""" + + PYSPARK = "pyspark" # Execute a pyspark job. + PYTHON = "python" # Execute a regular python script. + + +class FeatureProcessorPipelineExecutionStatus(Enum): + """Enum of feature_processor pipeline execution status.""" + + EXECUTING = "Executing" + STOPPING = "Stopping" + STOPPED = "Stopped" + FAILED = "Failed" + SUCCEEDED = "Succeeded" diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_env.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_env.py new file mode 100644 index 0000000000..d4ccfb1197 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_env.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class that determines the current execution environment.""" +from __future__ import absolute_import + + +from typing import Dict, Optional +from datetime import datetime, timezone +import json +import logging +import os +import attr +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER, + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, +) + + +logger = logging.getLogger("sagemaker") + + +@attr.s +class EnvironmentHelper: + """Helper class to retrieve info from environment. + + Attributes: + current_time (datetime): The current datetime. 
+ """ + + current_time = attr.ib(default=datetime.now(timezone.utc)) + + def is_training_job(self) -> bool: + """Determine if the current execution environment is inside a SageMaker Training Job""" + return self.load_training_resource_config() is not None + + def get_instance_count(self) -> int: + """Determine the number of instances for the current execution environment.""" + resource_config = self.load_training_resource_config() + return len(resource_config["hosts"]) if resource_config else 1 + + def load_training_resource_config(self) -> Optional[Dict]: + """Load the contents of resourceconfig.json contents. + + Returns: + Optional[Dict]: None if not found. + """ + SM_TRAINING_CONFIG_FILE_PATH = "/opt/ml/input/config/resourceconfig.json" + try: + with open(SM_TRAINING_CONFIG_FILE_PATH, "r") as cfgfile: + resource_config = json.load(cfgfile) + logger.debug("Contents of %s: %s", SM_TRAINING_CONFIG_FILE_PATH, resource_config) + return resource_config + except FileNotFoundError: + return None + + def get_job_scheduled_time(self) -> str: + """Get the job scheduled time. + + Returns: + str: Timestamp when the job is scheduled. + """ + + scheduled_time = self.current_time.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT) + if self.is_training_job(): + envs = dict(os.environ) + return envs.get(EXECUTION_TIME_PIPELINE_PARAMETER, scheduled_time) + + return scheduled_time diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_rule_helper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_rule_helper.py new file mode 100644 index 0000000000..250e7d456f --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_rule_helper.py @@ -0,0 +1,305 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains classes for EventBridge Schedule management for a feature processor.""" +from __future__ import absolute_import + +import json +import logging +import re +from typing import Dict, List, Tuple, Optional, Any +import attr +from botocore.exceptions import ClientError +from botocore.paginate import PageIterator +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._feature_processor_pipeline_events import ( + FeatureProcessorPipelineEvents, +) +from sagemaker.mlops.feature_store.feature_processor._constants import ( + RESOURCE_NOT_FOUND_EXCEPTION, + PIPELINE_ARN_REGEX_PATTERN, + BASE_EVENT_PATTERN, +) +from sagemaker.mlops.feature_store.feature_processor._enums import ( + FeatureProcessorPipelineExecutionStatus, +) +from sagemaker.core.common_utils import TagsDict + +logger = logging.getLogger("sagemaker") + + +@attr.s +class EventBridgeRuleHelper: + """Contains helper methods for managing EventBridge rules for a feature processor.""" + + sagemaker_session: Session = attr.ib() + event_bridge_rule_client = attr.ib() + + def put_rule( + self, + source_pipeline_events: List[FeatureProcessorPipelineEvents], + target_pipeline: str, + event_pattern: str, + state: str, + ) -> str: + """Creates an EventBridge Rule for a given target pipeline. + + Args: + source_pipeline_events: The list of pipeline events that trigger the EventBridge Rule. + target_pipeline: The name of the pipeline that is triggered by the EventBridge Rule. + event_pattern: The EventBridge EventPattern that triggers the EventBridge Rule. + If specified, will override source_pipeline_events. + state: Indicates whether the rule is enabled or disabled. + + Returns: + The Amazon Resource Name (ARN) of the rule. + """ + self._validate_feature_processor_pipeline_events(source_pipeline_events) + rule_name = target_pipeline + _event_patterns = ( + event_pattern + or self._generate_event_pattern_from_feature_processor_pipeline_events( + source_pipeline_events + ) + ) + rule_arn = self.event_bridge_rule_client.put_rule( + Name=rule_name, EventPattern=_event_patterns, State=state + )["RuleArn"] + return rule_arn + + def put_target( + self, + rule_name: str, + target_pipeline: str, + target_pipeline_parameters: Dict[str, str], + role_arn: str, + ) -> None: + """Attach target pipeline to an event based trigger. + + Args: + rule_name: The name of the EventBridge Rule. + target_pipeline: The name of the pipeline that is triggered by the EventBridge Rule. + target_pipeline_parameters: The list of parameters to start execution of a pipeline. + role_arn: The Amazon Resource Name (ARN) of the IAM role associated with the rule. + """ + target_pipeline_arn_and_name = self._generate_pipeline_arn_and_name(target_pipeline) + target_pipeline_name = target_pipeline_arn_and_name["pipeline_name"] + target_pipeline_arn = target_pipeline_arn_and_name["pipeline_arn"] + target_request_dict = { + "Id": target_pipeline_name, + "Arn": target_pipeline_arn, + "RoleArn": role_arn, + } + if target_pipeline_parameters: + target_request_dict["SageMakerPipelineParameters"] = { + "PipelineParameterList": target_pipeline_parameters + } + put_targets_response = self.event_bridge_rule_client.put_targets( + Rule=rule_name, + Targets=[target_request_dict], + ) + if put_targets_response["FailedEntryCount"] != 0: + error_msg = put_targets_response["FailedEntries"][0]["ErrorMessage"] + raise Exception(f"Failed to add target pipeline to rule. 
Failure reason: {error_msg}") + + def delete_rule(self, rule_name: str) -> None: + """Deletes an EventBridge Rule of a given pipeline if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + """ + self.event_bridge_rule_client.delete_rule(Name=rule_name) + + def remove_targets(self, rule_name: str, ids: List[str]) -> None: + """Deletes an EventBridge Targets of a given rule if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + ids: The ids of the EventBridge Target. + """ + self.event_bridge_rule_client.remove_targets(Rule=rule_name, Ids=ids) + + def list_targets_by_rule(self, rule_name: str) -> PageIterator: + """List EventBridge Targets of a given rule. + + Args: + rule_name: The name of the EventBridge Rule. + + Returns: + The page iterator of list_targets_by_rule call. + """ + return self.event_bridge_rule_client.get_paginator("list_targets_by_rule").paginate( + Rule=rule_name + ) + + def describe_rule(self, rule_name: str) -> Optional[Dict[str, Any]]: + """Describe the EventBridge Rule ARN corresponding to a sagemaker pipeline + + Args: + rule_name: The name of the EventBridge Rule. + Returns: + Optional[Dict[str, str]] : Describe EventBridge Rule response if exists. + """ + try: + event_bridge_rule_response = self.event_bridge_rule_client.describe_rule(Name=rule_name) + return event_bridge_rule_response + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + logger.info("No EventBridge Rule found for pipeline %s.", rule_name) + return None + raise e + + def enable_rule(self, rule_name: str) -> None: + """Enables an EventBridge Rule of a given pipeline if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + """ + self.event_bridge_rule_client.enable_rule(Name=rule_name) + logger.info("Enabled EventBridge Rule for pipeline %s.", rule_name) + + def disable_rule(self, rule_name: str) -> None: + """Disables an EventBridge Rule of a given pipeline if there is one. + + Args: + rule_name: The name of the EventBridge Rule. + """ + self.event_bridge_rule_client.disable_rule(Name=rule_name) + logger.info("Disabled EventBridge Rule for pipeline %s.", rule_name) + + def add_tags(self, rule_arn: str, tags: List[TagsDict]) -> None: + """Adds tags to the EventBridge Rule. + + Args: + rule_arn: The ARN of the EventBridge Rule. + tags: List of tags to be added. + """ + self.event_bridge_rule_client.tag_resource(ResourceARN=rule_arn, Tags=tags) + + def _generate_event_pattern_from_feature_processor_pipeline_events( + self, pipeline_events: List[FeatureProcessorPipelineEvents] + ) -> str: + """Generates the event pattern json string from the pipeline events. + + Args: + pipeline_events: List of pipeline events. + Returns: + str: The event pattern json string. + + Raises: + ValueError: If pipeline events contain duplicate pipeline names. 
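+
+        Example (illustrative; the pipeline ARN and status are assumed values). A single
+        FeatureProcessorPipelineEvents entry for pipeline "upstream" with desired status
+        SUCCEEDED produces:
+
+            {
+                "detail-type": ["SageMaker Model Building Pipeline Execution Status Change"],
+                "source": ["aws.sagemaker"],
+                "detail": {
+                    "currentPipelineExecutionStatus": ["Succeeded"],
+                    "pipelineArn": ["arn:aws:sagemaker:us-west-2:123456789012:pipeline/upstream"]
+                }
+            }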
+ """ + + result_event_pattern = { + "detail-type": ["SageMaker Model Building Pipeline Execution Status Change"], + } + filters = [] + desired_status_to_pipeline_names_map = ( + self._aggregate_pipeline_events_with_same_desired_status(pipeline_events) + ) + for desired_status in desired_status_to_pipeline_names_map: + pipeline_arns = [ + self._generate_pipeline_arn_and_name(pipeline_name)["pipeline_arn"] + for pipeline_name in desired_status_to_pipeline_names_map[desired_status] + ] + curr_filter = BASE_EVENT_PATTERN.copy() + curr_filter["detail"]["pipelineArn"] = pipeline_arns + curr_filter["detail"]["currentPipelineExecutionStatus"] = [ + status_enum.value for status_enum in desired_status + ] + filters.append(curr_filter) + if len(filters) > 1: + result_event_pattern["$or"] = filters + else: + result_event_pattern.update(filters[0]) + return json.dumps(result_event_pattern) + + def _validate_feature_processor_pipeline_events( + self, pipeline_events: List[FeatureProcessorPipelineEvents] + ) -> None: + """Validates the pipeline events. + + Args: + pipeline_events: List of pipeline events. + Raises: + ValueError: If pipeline events contain duplicate pipeline names. + """ + + unique_pipelines = {event.pipeline_name for event in pipeline_events} + potential_infinite_loop = [] + if len(unique_pipelines) != len(pipeline_events): + raise ValueError("Pipeline names in pipeline_events must be unique.") + + for event in pipeline_events: + if FeatureProcessorPipelineExecutionStatus.EXECUTING in event.pipeline_execution_status: + potential_infinite_loop.append(event.pipeline_name) + if potential_infinite_loop: + logger.warning( + "Potential infinite loop detected for pipelines %s. " + "Setting pipeline_execution_status to EXECUTING might cause infinite loop. " + "Please consider a terminal status instead.", + potential_infinite_loop, + ) + + def _aggregate_pipeline_events_with_same_desired_status( + self, pipeline_events: List[FeatureProcessorPipelineEvents] + ) -> Dict[Tuple, List[str]]: + """Aggregate pipeline events with same desired status. + + e.g. + { + (FeatureProcessorPipelineExecutionStatus.FAILED, + FeatureProcessorPipelineExecutionStatus.STOPPED): + ["pipeline_name_1", "pipeline_name_2"], + (FeatureProcessorPipelineExecutionStatus.STOPPED, + FeatureProcessorPipelineExecutionStatus.STOPPED): + ["pipeline_name_3"], + } + Args: + pipeline_events: List of pipeline events. + Returns: + Dict[Tuple, List[str]]: A dictionary of desired status keys and corresponding pipeline + names. + """ + events_by_desired_status = {} + + for event in pipeline_events: + sorted_execution_status = sorted(event.pipeline_execution_status, key=lambda x: x.value) + desired_status_keys = tuple(sorted_execution_status) + + if desired_status_keys not in events_by_desired_status: + events_by_desired_status[desired_status_keys] = [] + events_by_desired_status[desired_status_keys].append(event.pipeline_name) + + return events_by_desired_status + + def _generate_pipeline_arn_and_name(self, pipeline_uri: str) -> Dict[str, str]: + """Generate pipeline arn and pipeline name from pipeline uri. + + Args: + pipeline_uri: The name or arn of the pipeline. + Returns: + Dict[str, str]: The arn and name of the pipeline. 
+ """ + match = re.match(PIPELINE_ARN_REGEX_PATTERN, pipeline_uri) + pipeline_arn = "" + pipeline_name = "" + if not match: + pipeline_name = pipeline_uri + describe_pipeline_response = self.sagemaker_session.sagemaker_client.describe_pipeline( + PipelineName=pipeline_name + ) + pipeline_arn = describe_pipeline_response["PipelineArn"] + else: + pipeline_arn = pipeline_uri + pipeline_name = match.group(4) + return dict(pipeline_arn=pipeline_arn, pipeline_name=pipeline_name) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_scheduler_helper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_scheduler_helper.py new file mode 100644 index 0000000000..f454a217e2 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_event_bridge_scheduler_helper.py @@ -0,0 +1,118 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes for EventBridge Schedule management for a feature processor.""" +from __future__ import absolute_import +import logging +from datetime import datetime +from typing import Dict, Optional, Any +import attr +from botocore.exceptions import ClientError +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER, + EVENT_BRIDGE_INVOCATION_TIME, + NO_FLEXIBLE_TIME_WINDOW, + RESOURCE_NOT_FOUND_EXCEPTION, +) + +logger = logging.getLogger("sagemaker") + + +@attr.s +class EventBridgeSchedulerHelper: + """Contains helper methods for scheduling events to EventBridge""" + + sagemaker_session: Session = attr.ib() + event_bridge_scheduler_client = attr.ib() + + def upsert_schedule( + self, + schedule_name: str, + pipeline_arn: str, + schedule_expression: str, + state: str, + start_date: datetime, + role: str, + ) -> Dict: + """Creates or updates a Schedule for the given pipeline_arn and schedule_expression. + + Args: + schedule_name: The name of the schedule. + pipeline_arn: The ARN of the sagemaker pipeline that needs to scheduled. + schedule_expression: The schedule expression. + state: Specifies whether the schedule is enabled or disabled. Can only + be ENABLED or DISABLED. + start_date: The date, in UTC, after which the schedule can begin invoking its target. + role: The RoleArn used to execute the scheduled events. + + Returns: + schedule_arn: The arn of the schedule. 
+ """ + pipeline_parameter = dict( + PipelineParameterList=[ + dict( + Name=EXECUTION_TIME_PIPELINE_PARAMETER, + Value=EVENT_BRIDGE_INVOCATION_TIME, + ) + ] + ) + create_or_update_schedule_request_dict = dict( + Name=schedule_name, + ScheduleExpression=schedule_expression, + FlexibleTimeWindow=NO_FLEXIBLE_TIME_WINDOW, + Target=dict( + Arn=pipeline_arn, + SageMakerPipelineParameters=pipeline_parameter, + RoleArn=role, + ), + State=state, + StartDate=start_date, + ) + try: + return self.event_bridge_scheduler_client.update_schedule( + **create_or_update_schedule_request_dict + ) + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + return self.event_bridge_scheduler_client.create_schedule( + **create_or_update_schedule_request_dict + ) + raise e + + def delete_schedule(self, schedule_name: str) -> None: + """Deletes an EventBridge Schedule of a given pipeline if there is one. + + Args: + schedule_name: The name of the EventBridge Schedule. + """ + logger.info("Deleting EventBridge Schedule for pipeline %s.", schedule_name) + self.event_bridge_scheduler_client.delete_schedule(Name=schedule_name) + + def describe_schedule(self, schedule_name) -> Optional[Dict[str, Any]]: + """Describe the EventBridge Schedule ARN corresponding to a sagemaker pipeline + + Args: + schedule_name: The name of the EventBridge Schedule. + Returns: + Optional[Dict[str, str]] : Describe EventBridge Schedule response if exists. + """ + try: + event_bridge_scheduler_response = self.event_bridge_scheduler_client.get_schedule( + Name=schedule_name + ) + return event_bridge_scheduler_response + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + logger.info("No EventBridge Schedule found for pipeline %s.", schedule_name) + return None + raise e diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_exceptions.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_exceptions.py new file mode 100644 index 0000000000..0b21d10ab9 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_exceptions.py @@ -0,0 +1,18 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module stores exceptions related to the feature_store.feature_processor module.""" +from __future__ import absolute_import + + +class IngestionError(Exception): + """Exception raised to indicate that ingestion did not complete successfully.""" diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_factory.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_factory.py new file mode 100644 index 0000000000..f205c32665 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_factory.py @@ -0,0 +1,167 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains static factory classes to instantiate complex objects for the FeatureProcessor.""" +from __future__ import absolute_import + +from typing import Dict +from pyspark.sql import DataFrame + +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._input_loader import ( + SparkDataFrameInputLoader, +) +from sagemaker.mlops.feature_store.feature_processor._params_loader import ( + ParamsLoader, + SystemParamsLoader, +) +from sagemaker.mlops.feature_store.feature_processor._spark_factory import ( + FeatureStoreManagerFactory, + SparkSessionFactory, +) +from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import SparkArgProvider +from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import ( + SparkOutputReceiver, +) +from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper +from sagemaker.mlops.feature_store.feature_processor._validation import ( + FeatureProcessorArgValidator, + InputValidator, + SparkUDFSignatureValidator, + InputOffsetValidator, + BaseDataSourceValidator, + ValidatorChain, +) + + +class ValidatorFactory: + """Static factory to handle ValidationChain instantiation.""" + + @staticmethod + def get_validation_chain(fp_config: FeatureProcessorConfig) -> ValidatorChain: + """Instantiate a ValidationChain""" + base_validators = [ + InputValidator(), + FeatureProcessorArgValidator(), + InputOffsetValidator(), + BaseDataSourceValidator(), + ] + + mode = fp_config.mode + if FeatureProcessorMode.PYSPARK == mode: + base_validators.append(SparkUDFSignatureValidator()) + return ValidatorChain(validators=base_validators) + + raise ValueError(f"FeatureProcessorMode {mode} is not supported.") + + +class UDFWrapperFactory: + """Static factory to handle UDFWrapper instantiation at runtime.""" + + @staticmethod + def get_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper: + """Instantiate a UDFWrapper based on the FeatureProcessingMode. + + Args: + fp_config (FeatureProcessorConfig): the configuration values for the + feature_processor decorator. + + Raises: + ValueError: if an unsupported FeatureProcessorMode is provided in fp_config. + + Returns: + UDFWrapper: An instance of UDFWrapper to decorate the UDF. + """ + mode = fp_config.mode + + if FeatureProcessorMode.PYSPARK == mode: + return UDFWrapperFactory._get_spark_udf_wrapper(fp_config) + + raise ValueError(f"FeatureProcessorMode {mode} is not supported.") + + @staticmethod + def _get_spark_udf_wrapper(fp_config: FeatureProcessorConfig) -> UDFWrapper[DataFrame]: + """Instantiate a new UDFWrapper for PySpark functions. + + Args: + fp_config (FeatureProcessorConfig): the configuration values for the feature_processor + decorator. 
+ """ + spark_session_factory = UDFWrapperFactory._get_spark_session_factory(fp_config.spark_config) + feature_store_manager_factory = UDFWrapperFactory._get_feature_store_manager_factory() + + output_manager = UDFWrapperFactory._get_spark_output_receiver(feature_store_manager_factory) + arg_provider = UDFWrapperFactory._get_spark_arg_provider(spark_session_factory) + + return UDFWrapper[DataFrame](arg_provider, output_manager) + + @staticmethod + def _get_spark_arg_provider( + spark_session_factory: SparkSessionFactory, + ) -> SparkArgProvider: + """Instantiate a new SparkArgProvider for PySpark functions. + + Args: + spark_session_factory (SparkSessionFactory): A factory to provide a reference to the + SparkSession initialized for the feature_processor wrapped function. The factory + lazily loads the SparkSession, i.e. defers to function execution time. + + Returns: + SparkArgProvider: An instance that generates arguments to provide to the + feature_processor wrapped function. + """ + environment_helper = EnvironmentHelper() + + system_parameters_arg_provider = SystemParamsLoader(environment_helper) + params_arg_provider = ParamsLoader(system_parameters_arg_provider) + input_loader = SparkDataFrameInputLoader(spark_session_factory, environment_helper) + + return SparkArgProvider(params_arg_provider, input_loader, spark_session_factory) + + @staticmethod + def _get_spark_output_receiver( + feature_store_manager_factory: FeatureStoreManagerFactory, + ) -> SparkOutputReceiver: + """Instantiate a new SparkOutputManager for PySpark functions. + + Args: + feature_store_manager_factory (FeatureStoreManagerFactory): A factory to provide + that provides a FeatureStoreManager that handles data ingestion to a Feature Group. + The factory lazily loads the FeatureStoreManager. + + Returns: + SparkOutputReceiver: An instance that handles outputs of the wrapped function. + """ + return SparkOutputReceiver(feature_store_manager_factory) + + @staticmethod + def _get_spark_session_factory(spark_config: Dict[str, str]) -> SparkSessionFactory: + """Instantiate a new SparkSessionFactory + + Args: + spark_config (Dict[str, str]): The Spark configuration that will be passed to the + initialization of Spark session. + + Returns: + SparkSessionFactory: A Spark session factory instance. + """ + environment_helper = EnvironmentHelper() + return SparkSessionFactory(environment_helper, spark_config) + + @staticmethod + def _get_feature_store_manager_factory() -> FeatureStoreManagerFactory: + """Instantiate a new FeatureStoreManagerFactory""" + return FeatureStoreManagerFactory() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_config.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_config.py new file mode 100644 index 0000000000..f5d4dd91f3 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_config.py @@ -0,0 +1,72 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +"""Contains data classes for the FeatureProcessor.""" +from __future__ import absolute_import + +from typing import Dict, List, Optional, Sequence, Union + +import attr + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode + + +@attr.s(frozen=True) +class FeatureProcessorConfig: + """Immutable data class containing the arguments for a FeatureProcessor. + + This class is used throughout sagemaker.mlops.feature_store.feature_processor module. Documentation + for each field can be be found in the feature_processor decorator. + + Defaults are defined as literals in the feature_processor decorator's parameters for usability + (i.e. literals in docs). Defaults, or any business logic, should not be added to this class. + It only serves as an immutable data class. + """ + + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ] = attr.ib() + output: str = attr.ib() + mode: FeatureProcessorMode = attr.ib() + target_stores: Optional[List[str]] = attr.ib() + parameters: Optional[Dict[str, Union[str, Dict]]] = attr.ib() + enable_ingestion: bool = attr.ib() + spark_config: Dict[str, str] = attr.ib() + + @staticmethod + def create( + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ], + output: str, + mode: FeatureProcessorMode, + target_stores: Optional[List[str]], + parameters: Optional[Dict[str, Union[str, Dict]]], + enable_ingestion: bool, + spark_config: Dict[str, str], + ) -> "FeatureProcessorConfig": + """Static initializer.""" + return FeatureProcessorConfig( + inputs=inputs, + output=output, + mode=mode, + target_stores=target_stores, + parameters=parameters, + enable_ingestion=enable_ingestion, + spark_config=spark_config, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_pipeline_events.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_pipeline_events.py new file mode 100644 index 0000000000..4ce9fb1b76 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_feature_processor_pipeline_events.py @@ -0,0 +1,29 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains data classes for the Feature Processor Pipeline Events.""" +from __future__ import absolute_import + +from typing import List +import attr +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorPipelineExecutionStatus + + +@attr.s(frozen=True) +class FeatureProcessorPipelineEvents: + """Immutable data class containing the execution events for a FeatureProcessor pipeline. + + This class is used for creating event based triggers for feature processor pipelines. 
+ """ + + pipeline_name: str = attr.ib() + pipeline_execution_status: List[FeatureProcessorPipelineExecutionStatus] = attr.ib() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py new file mode 100644 index 0000000000..3e82262858 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py @@ -0,0 +1,366 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes that loads user specified input sources (e.g. Feature Groups, S3 URIs, etc).""" +from __future__ import absolute_import + +import logging +import re +from abc import ABC, abstractmethod +from typing import Generic, Optional, TypeVar, Union + +import attr +from pyspark.sql import DataFrame + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._constants import FEATURE_GROUP_ARN_REGEX_PATTERN +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + IcebergTableDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory +from sagemaker.mlops.feature_store.feature_processor._input_offset_parser import ( + InputOffsetParser, +) +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper + +T = TypeVar("T") + +logger = logging.getLogger("sagemaker") + + +class InputLoader(Generic[T], ABC): + """Loads the contents of a Feature Group's offline store or contents at an S3 URI.""" + + @abstractmethod + def load_from_feature_group(self, feature_group_data_source: FeatureGroupDataSource) -> T: + """Load the data from a Feature Group's offline store. + + Args: + feature_group_data_source (FeatureGroupDataSource): the feature group source. + + Returns: + T: The contents of the offline store as an instance of type T. + """ + + @abstractmethod + def load_from_s3(self, s3_data_source: Union[CSVDataSource, ParquetDataSource]) -> T: + """Load the contents from an S3 based data source. + + Args: + s3_data_source (Union[CSVDataSource, ParquetDataSource]): a data source that is based + in S3. + + Returns: + T: The contents stored at the data source as an instance of type T. + """ + + +@attr.s +class SparkDataFrameInputLoader(InputLoader[DataFrame]): + """InputLoader that reads data in as a Spark DataFrame.""" + + spark_session_factory: SparkSessionFactory = attr.ib() + environment_helper: EnvironmentHelper = attr.ib() + sagemaker_session: Optional[Session] = attr.ib(default=None) + + _supported_table_format = ["Iceberg", "Glue", None] + + def load_from_feature_group( + self, feature_group_data_source: FeatureGroupDataSource + ) -> DataFrame: + """Load the contents of a Feature Group's offline store as a DataFrame. + + Args: + feature_group_data_source (FeatureGroupDataSource): the Feature Group source. 
+ + Raises: + ValueError: If the Feature Group does not have an Offline Store. + ValueError: If the Feature Group's Table Type is not supported by the feature_processor. + + Returns: + DataFrame: A Spark DataFrame containing the contents of the Feature Group's + offline store. + """ + sagemaker_session: Session = self.sagemaker_session or Session() + + feature_group_name = feature_group_data_source.name + feature_group = sagemaker_session.sagemaker_client.describe_feature_group( + FeatureGroupName=self._parse_name_from_arn(feature_group_name) + ) + logger.debug( + "Called describe_feature_group with %s and received: %s", + feature_group_name, + feature_group, + ) + + if "OfflineStoreConfig" not in feature_group: + raise ValueError( + f"Input Feature Groups must have an enabled Offline Store." + f" Feature Group: {feature_group_name} does not have an Offline Store enabled." + ) + + offline_store_uri = feature_group["OfflineStoreConfig"]["S3StorageConfig"][ + "ResolvedOutputS3Uri" + ] + + table_format = feature_group["OfflineStoreConfig"].get("TableFormat", None) + + if table_format not in self._supported_table_format: + raise ValueError( + f"Feature group with table format {table_format} is not supported. " + f"The table format should be one of {self._supported_table_format}." + ) + + start_offset = feature_group_data_source.input_start_offset + end_offset = feature_group_data_source.input_end_offset + + if table_format == "Iceberg": + data_catalog_config = feature_group["OfflineStoreConfig"]["DataCatalogConfig"] + return self.load_from_iceberg_table( + IcebergTableDataSource( + offline_store_uri, + data_catalog_config["Catalog"], + data_catalog_config["Database"], + data_catalog_config["TableName"], + ), + feature_group["EventTimeFeatureName"], + start_offset, + end_offset, + ) + + return self.load_from_date_partitioned_s3( + ParquetDataSource(offline_store_uri), start_offset, end_offset + ) + + def load_from_date_partitioned_s3( + self, + s3_data_source: ParquetDataSource, + input_start_offset: str, + input_end_offset: str, + ) -> DataFrame: + """Load the contents from a Feature Group's partitioned offline S3 as a DataFrame. + + Args: + s3_data_source (ParquetDataSource): + A data source that is based in S3. + input_start_offset (str): Start offset that is used to calculate the input start date. + input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + DataFrame: Contents of the data loaded from S3. + """ + + spark_session = self.spark_session_factory.spark_session + s3a_uri = s3_data_source.s3_uri.replace("s3://", "s3a://") + filter_condition = self._get_s3_partitions_offset_filter_condition( + input_start_offset, input_end_offset + ) + + logger.info( + "Loading data from %s with filtering condition %s.", + s3a_uri, + filter_condition, + ) + input_df = spark_session.read.parquet(s3a_uri) + if filter_condition: + input_df = input_df.filter(filter_condition) + + return input_df + + def load_from_s3(self, s3_data_source: Union[CSVDataSource, ParquetDataSource]) -> DataFrame: + """Load the contents from an S3 based data source as a DataFrame. + + Args: + s3_data_source (Union[CSVDataSource, ParquetDataSource]): + A data source that is based in S3. + + Raises: + ValueError: If an invalid DataSource is provided. + + Returns: + DataFrame: Contents of the data loaded from S3. 
+ """ + spark_session = self.spark_session_factory.spark_session + s3a_uri = s3_data_source.s3_uri.replace("s3://", "s3a://") + + if isinstance(s3_data_source, CSVDataSource): + # TODO: Accept `schema` parameter. (Inferring schema requires a pass through every row) + logger.info("Loading data from %s.", s3a_uri) + return spark_session.read.csv( + s3a_uri, + header=s3_data_source.csv_header, + inferSchema=s3_data_source.csv_infer_schema, + ) + + if isinstance(s3_data_source, ParquetDataSource): + logger.info("Loading data from %s.", s3a_uri) + return spark_session.read.parquet(s3a_uri) + + raise ValueError("An invalid data source was provided.") + + def load_from_iceberg_table( + self, + iceberg_table_data_source: IcebergTableDataSource, + event_time_feature_name: str, + input_start_offset: str, + input_end_offset: str, + ) -> DataFrame: + """Load the contents from an Iceberg table as a DataFrame. + + Args: + iceberg_table_data_source (IcebergTableDataSource): An Iceberg Table source. + event_time_feature_name (str): Event time feature's name of feature group. + input_start_offset (str): Start offset that is used to calculate the input start date. + input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + DataFrame: Contents of the Iceberg Table as a Spark DataFrame. + """ + catalog = iceberg_table_data_source.catalog.lower() + database = iceberg_table_data_source.database.lower() + table = iceberg_table_data_source.table.lower() + iceberg_table = f"{catalog}.{database}.{table}" + + spark_session = self.spark_session_factory.get_spark_session_with_iceberg_config( + iceberg_table_data_source.warehouse_s3_uri, catalog + ) + + filter_condition = self._get_iceberg_offset_filter_condition( + event_time_feature_name, + input_start_offset, + input_end_offset, + ) + + iceberg_df = spark_session.table(iceberg_table) + + if filter_condition: + logger.info( + "The filter condition for iceberg feature group is %s.", + filter_condition, + ) + iceberg_df = iceberg_df.filter(filter_condition) + + return iceberg_df + + def _get_iceberg_offset_filter_condition( + self, + event_time_feature_name: str, + input_start_offset: str, + input_end_offset: str, + ): + """Load the contents from an Iceberg table as a DataFrame. + + Args: + iceberg_table_data_source (IcebergTableDataSource): An Iceberg Table source. + input_start_offset (str): Start offset that is used to calculate the input start date. + input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + DataFrame: Contents of the Iceberg Table as a Spark DataFrame. + """ + if input_start_offset is None and input_end_offset is None: + return None + + offset_parser = InputOffsetParser(self.environment_helper.get_job_scheduled_time()) + start_offset_time = offset_parser.get_iso_format_offset_date(input_start_offset) + end_offset_time = offset_parser.get_iso_format_offset_date(input_end_offset) + + start_condition = ( + f"{event_time_feature_name} >= '{start_offset_time}'" if input_start_offset else None + ) + end_condition = ( + f"{event_time_feature_name} < '{end_offset_time}'" if input_end_offset else None + ) + + conditions = filter(None, [start_condition, end_condition]) + return " AND ".join(conditions) + + def _get_s3_partitions_offset_filter_condition( + self, input_start_offset: str, input_end_offset: str + ) -> str: + """Get s3 partitions filter condition based on input offsets. + + Args: + input_start_offset (str): Start offset that is used to calculate the input start date. 
+ input_end_offset (str): End offset that is used to calculate the input end date. + + Returns: + str: A SQL string that defines the condition of time range filter. + """ + if input_start_offset is None and input_end_offset is None: + return None + + offset_parser = InputOffsetParser(self.environment_helper.get_job_scheduled_time()) + ( + start_year, + start_month, + start_day, + start_hour, + ) = offset_parser.get_offset_date_year_month_day_hour(input_start_offset) + ( + end_year, + end_month, + end_day, + end_hour, + ) = offset_parser.get_offset_date_year_month_day_hour(input_end_offset) + + # Include all records that event time is between start_year and end_year + start_year_include_condition = f"year >= '{start_year}'" if input_start_offset else None + end_year_include_condition = f"year <= '{end_year}'" if input_end_offset else None + year_include_condition = " AND ".join( + filter(None, [start_year_include_condition, end_year_include_condition]) + ) + + # Exclude all records that the event time is earlier than the start or later than the end + start_offset_exclude_condition = ( + f"(year = '{start_year}' AND month < '{start_month}') " + f"OR (year = '{start_year}' AND month = '{start_month}' AND day < '{start_day}') " + f"OR (year = '{start_year}' AND month = '{start_month}' AND day = '{start_day}' " + f"AND hour < '{start_hour}')" + if input_start_offset + else None + ) + end_offset_exclude_condition = ( + f"(year = '{end_year}' AND month > '{end_month}') " + f"OR (year = '{end_year}' AND month = '{end_month}' AND day > '{end_day}') " + f"OR (year = '{end_year}' AND month = '{end_month}' AND day = '{end_day}' " + f"AND hour >= '{end_hour}')" + if input_end_offset + else None + ) + offset_exclude_condition = " OR ".join( + filter(None, [start_offset_exclude_condition, end_offset_exclude_condition]) + ) + + filter_condition = f"({year_include_condition}) AND NOT ({offset_exclude_condition})" + + logger.info("The filter condition for hive feature group is %s.", filter_condition) + + return filter_condition + + def _parse_name_from_arn(self, fg_uri: str) -> str: + """Parse a Feature Group's name from an arn. + + Args: + fg_uri (str): a string identifier of the Feature Group. + + Returns: + str: the name of the feature group. + """ + match = re.match(FEATURE_GROUP_ARN_REGEX_PATTERN, fg_uri) + if match: + feature_group_name = match.group(4) + return feature_group_name + return fg_uri diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_offset_parser.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_offset_parser.py new file mode 100644 index 0000000000..89d816af49 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_offset_parser.py @@ -0,0 +1,129 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains class that parse the input data start and end offset""" +from __future__ import absolute_import + +import re +from typing import Optional, Tuple, Union +from datetime import datetime, timezone +from dateutil.relativedelta import relativedelta +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, +) + +UNIT_RE = r"(\d+?)\s+([a-z]+?)s?" +VALID_UNITS = ["hour", "day", "week", "month", "year"] + + +class InputOffsetParser: + """Contains methods to parse the input offset to different formats. + + Args: + now (datetime): + The point of time that the parser should calculate offset against. + """ + + def __init__(self, now: Union[datetime, str] = None) -> None: + if now is None: + self.now = datetime.now(timezone.utc) + elif isinstance(now, datetime): + self.now = now + else: + self.now = datetime.strptime(now, EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT) + + def get_iso_format_offset_date(self, offset: Optional[str]) -> str: + """Get the iso format of target date based on offset diff. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Returns: + str: ISO-8061 formatted string of the offset date. + """ + if offset is None: + return None + + offset_datetime = self.get_offset_datetime(offset) + return offset_datetime.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT) + + def get_offset_datetime(self, offset: Optional[str]) -> datetime: + """Get the datetime format of target date based on offset diff. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Returns: + datetime: datetime instance of the offset date. + """ + if offset is None: + return None + + offset_td = InputOffsetParser.parse_offset_to_timedelta(offset) + + return self.now + offset_td + + def get_offset_date_year_month_day_hour( + self, offset: Optional[str] + ) -> Tuple[str, str, str, str]: + """Get the year, month, day and hour based on offset diff. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Returns: + Tuple[str, str, str, str]: A tuple that consists of extracted year, month, day, hour from offset date. + """ + if offset is None: + return (None, None, None, None) + + offset_dt = self.get_offset_datetime(offset) + return ( + offset_dt.strftime("%Y"), + offset_dt.strftime("%m"), + offset_dt.strftime("%d"), + offset_dt.strftime("%H"), + ) + + @staticmethod + def parse_offset_to_timedelta(offset: Optional[str]) -> relativedelta: + """Parse the offset to time delta. + + Args: + offset (Optional[str]): Offset that is used for target date calcluation. + + Raises: + ValueError: If an offset is provided in a unrecognizable format. + ValueError: If an invalid offset unit is provided. + + Returns: + reletivedelta: Time delta representation of the time offset. + """ + if offset is None: + return None + + unit_match = re.fullmatch(UNIT_RE, offset) + + if not unit_match: + raise ValueError( + f"[{offset}] is not in a valid offset format. " + "Please pass a valid offset e.g '1 day'." + ) + + multiple, unit = unit_match.groups() + + if unit not in VALID_UNITS: + raise ValueError(f"[{unit}] is not a valid offset unit. 
Supported units: {VALID_UNITS}") + + shift_args = {f"{unit}s": -int(multiple)} + + return relativedelta(**shift_args) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_params_loader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_params_loader.py new file mode 100644 index 0000000000..f5be546e86 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_params_loader.py @@ -0,0 +1,83 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes for loading the 'params' argument for the UDF.""" +from __future__ import absolute_import + +from typing import Dict, Union + +import attr + +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + + +@attr.s +class SystemParamsLoader: + """Provides the fields for the params['system'] namespace. + + These are the parameters that the feature_processor automatically loads from various SageMaker + resources. + """ + + _SYSTEM_PARAMS_KEY = "system" + + environment_helper: EnvironmentHelper = attr.ib() + + def get_system_args(self) -> Dict[str, Union[str, Dict]]: + """Generates the system generated parameters for the feature_processor wrapped function. + + Args: + fp_config (FeatureProcessorConfig): The configuration values for the + feature_processor decorator. + + Returns: + Dict[str, Union[str, Dict]]: The system parameters. + """ + + return { + self._SYSTEM_PARAMS_KEY: { + "scheduled_time": self.environment_helper.get_job_scheduled_time(), + } + } + + +@attr.s +class ParamsLoader: + """Provides 'params' argument for the FeatureProcessor.""" + + _PARAMS_KEY = "params" + + system_parameters_arg_provider: SystemParamsLoader = attr.ib() + + def get_parameter_args( + self, + fp_config: FeatureProcessorConfig, + ) -> Dict[str, Union[str, Dict]]: + """Loads the 'params' argument for the FeatureProcessor. + + Args: + fp_config (FeatureProcessorConfig): The configuration values for the + feature_processor decorator. + + Returns: + Dict[str, Union[str, Dict]]: A dictionary that contains both user provided + parameters (feature_processor argument) and system parameters. + """ + return { + self._PARAMS_KEY: { + **(fp_config.parameters or {}), + **self.system_parameters_arg_provider.get_system_args(), + } + } diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py new file mode 100644 index 0000000000..d304185e85 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_spark_factory.py @@ -0,0 +1,202 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains factory classes for instantiating Spark objects.""" +from __future__ import absolute_import + +from functools import lru_cache +from typing import List, Tuple, Dict + +import feature_store_pyspark +import feature_store_pyspark.FeatureStoreManager as fsm +from pyspark.conf import SparkConf +from pyspark.context import SparkContext +from pyspark.sql import SparkSession + +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper + +SPARK_APP_NAME = "FeatureProcessor" + + +class SparkSessionFactory: + """Lazy loading, memoizing, instantiation of SparkSessions. + + Useful when you want to defer SparkSession instantiation and provide access to the same + instance throughout the application. + """ + + def __init__( + self, environment_helper: EnvironmentHelper, spark_config: Dict[str, str] = None + ) -> None: + """Initialize the SparkSessionFactory. + + Args: + environment_helper (EnvironmentHelper): A helper class to determine the current + execution. + spark_config (Dict[str, str]): The Spark configuration that will be passed to the + initialization of Spark session. + """ + self.environment_helper = environment_helper + self.spark_config = spark_config + + @property + @lru_cache() + def spark_session(self) -> SparkSession: + """Instantiate a new SparkSession or return the existing one.""" + is_training_job = self.environment_helper.is_training_job() + instance_count = self.environment_helper.get_instance_count() + + spark_configs = self._get_spark_configs(is_training_job) + spark_conf = SparkConf().setAll(spark_configs).setAppName(SPARK_APP_NAME) + + if instance_count == 1: + spark_conf.setMaster("local[*]") + + sc = SparkContext.getOrCreate(conf=spark_conf) + + jsc = sc._jsc # Java Spark Context (JVM SparkContext) + for cfg in self._get_jsc_hadoop_configs(): + jsc.hadoopConfiguration().set(cfg[0], cfg[1]) + + return SparkSession(sparkContext=sc) + + def _get_spark_configs(self, is_training_job) -> List[Tuple[str, str]]: + """Generate Spark Configurations optimized for feature_processing functionality. + + Args: + is_training_job (bool): a boolean indicating whether the current execution environment + is a training job or not. + + Returns: + List[Tuple[str, str]]: Spark configurations. 
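+
+        Example (illustrative only; the override below is a placeholder and is not a
+        default applied by this factory):
+
+        .. code-block:: python
+
+            factory = SparkSessionFactory(
+                EnvironmentHelper(),  # assumed to be constructible without arguments
+                spark_config={"spark.sql.shuffle.partitions": "64"},  # appended to the defaults
+            )
+            spark = factory.spark_session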
+ """ + spark_configs = [ + ( + "spark.hadoop.fs.s3a.aws.credentials.provider", + ",".join( + [ + "com.amazonaws.auth.ContainerCredentialsProvider", + "com.amazonaws.auth.profile.ProfileCredentialsProvider", + "com.amazonaws.auth.DefaultAWSCredentialsProviderChain", + ] + ), + ), + # spark-3.3.1#recommended-settings-for-writing-to-object-stores - https://tinyurl.com/54rkhef6 + ("spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version", "2"), + ( + "spark.hadoop.mapreduce.fileoutputcommitter.cleanup-failures.ignored", + "true", + ), + ("spark.hadoop.parquet.enable.summary-metadata", "false"), + # spark-3.3.1#parquet-io-settings https://tinyurl.com/59a7uhwu + ("spark.sql.parquet.mergeSchema", "false"), + ("spark.sql.parquet.filterPushdown", "true"), + ("spark.sql.hive.metastorePartitionPruning", "true"), + # hadoop-aws#performance - https://tinyurl.com/mutxj96f + ("spark.hadoop.fs.s3a.threads.max", "500"), + ("spark.hadoop.fs.s3a.connection.maximum", "500"), + ("spark.hadoop.fs.s3a.experimental.input.fadvise", "normal"), + ("spark.hadoop.fs.s3a.block.size", "128M"), + ("spark.hadoop.fs.s3a.fast.upload.buffer", "disk"), + ("spark.hadoop.fs.trash.interval", "0"), + ("spark.port.maxRetries", "50"), + ] + + if self.spark_config: + spark_configs.extend(self.spark_config.items()) + + if not is_training_job: + fp_spark_jars = feature_store_pyspark.classpath_jars() + fp_spark_packages = [ + "org.apache.hadoop:hadoop-aws:3.3.1", + "org.apache.hadoop:hadoop-common:3.3.1", + ] + + if self.spark_config and "spark.jars" in self.spark_config: + fp_spark_jars.append(self.spark_config.get("spark.jars")) + + if self.spark_config and "spark.jars.packages" in self.spark_config: + fp_spark_packages.append(self.spark_config.get("spark.jars.packages")) + + spark_configs.extend( + ( + ("spark.jars", ",".join(fp_spark_jars)), + ( + "spark.jars.packages", + ",".join(fp_spark_packages), + ), + ) + ) + + return spark_configs + + def _get_jsc_hadoop_configs(self) -> List[Tuple[str, str]]: + """JVM SparkContext Hadoop configurations.""" + # Skip generation of _SUCCESS files to speed up writes. + return [("mapreduce.fileoutputcommitter.marksuccessfuljobs", "false")] + + def _get_iceberg_configs(self, warehouse_s3_uri: str, catalog: str) -> List[Tuple[str, str]]: + """Spark configurations for reading and writing data from Iceberg Table sources. + + Args: + warehouse_s3_uri (str): The S3 URI of the warehouse. + catalog (str): The catalog. + + Returns: + List[Tuple[str, str]]: the Spark configurations. + """ + catalog = catalog.lower() + return [ + (f"spark.sql.catalog.{catalog}", "smfs.shaded.org.apache.iceberg.spark.SparkCatalog"), + (f"spark.sql.catalog.{catalog}.warehouse", warehouse_s3_uri), + ( + f"spark.sql.catalog.{catalog}.catalog-impl", + "smfs.shaded.org.apache.iceberg.aws.glue.GlueCatalog", + ), + ( + f"spark.sql.catalog.{catalog}.io-impl", + "smfs.shaded.org.apache.iceberg.aws.s3.S3FileIO", + ), + (f"spark.sql.catalog.{catalog}.glue.skip-name-validation", "true"), + ] + + def get_spark_session_with_iceberg_config(self, warehouse_s3_uri, catalog) -> SparkSession: + """Get an instance of the SparkSession with Iceberg settings configured. + + Args: + warehouse_s3_uri (str): The S3 URI of the warehouse. + catalog (str): The catalog. + + Returns: + SparkSession: A SparkSession ready to support reading and writing data from an Iceberg + Table. 
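+
+        Example (a sketch; the warehouse URI, catalog, and table identifiers are placeholders):
+
+        .. code-block:: python
+
+            # factory: an existing SparkSessionFactory instance
+            spark = factory.get_spark_session_with_iceberg_config(
+                warehouse_s3_uri="s3://my-feature-store-bucket/warehouse/",
+                catalog="AwsDataCatalog",
+            )
+            df = spark.table("awsdatacatalog.my_database.my_table")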
+ """ + conf = self.spark_session.conf + + for cfg in self._get_iceberg_configs(warehouse_s3_uri, catalog): + conf.set(cfg[0], cfg[1]) + + return self.spark_session + + +class FeatureStoreManagerFactory: + """Lazy loading, memoizing, instantiation of FeatureStoreManagers. + + Useful when you want to defer FeatureStoreManagers instantiation and provide access to the same + instance throughout the application. + """ + + @property + @lru_cache() + def feature_store_manager(self) -> fsm.FeatureStoreManager: + """Instansiate a new FeatureStoreManager.""" + return fsm.FeatureStoreManager() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py new file mode 100644 index 0000000000..bd21e804eb --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_arg_provider.py @@ -0,0 +1,239 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes for loading arguments for the parameters defined in the UDF.""" +from __future__ import absolute_import + +from abc import ABC, abstractmethod +from inspect import signature +from typing import Any, Callable, Dict, Generic, List, OrderedDict, TypeVar, Union, Optional + +import attr +from pyspark.sql import DataFrame, SparkSession + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, + PySparkDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._input_loader import ( + SparkDataFrameInputLoader, +) +from sagemaker.mlops.feature_store.feature_processor._params_loader import ParamsLoader +from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory + +T = TypeVar("T") + + +@attr.s +class UDFArgProvider(Generic[T], ABC): + """Base class for arguments providers for the UDF. + + Args: + Generic (T): The type of the auto-loaded data values. + """ + + @abstractmethod + def provide_input_args( + self, udf: Callable[..., T], fp_config: FeatureProcessorConfig + ) -> OrderedDict[str, T]: + """Provides a dict of (input name, auto-loaded data) using the feature_processor parameters. + + The input name is the udfs parameter name, and the data source is the one defined at the + same index (as the input name) in fp_config.inputs. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Returns: + OrderedDict[str, T]: The loaded data sources, in the same order as fp_config.inputs. + """ + + @abstractmethod + def provide_params_arg( + self, udf: Callable[..., T], fp_config: FeatureProcessorConfig + ) -> Dict[str, Dict]: + """Provides the 'params' argument that is provided to the UDF. 
+ + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Returns: + Dict[str, Dict]: A combination of user defined parameters (in fp_config) and system + provided parameters. + """ + + @abstractmethod + def provide_additional_kwargs(self, udf: Callable[..., T]) -> Dict[str, Any]: + """Provides any additional arguments to be provided to the UDF, dependent on the mode. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + + Returns: + Dict[str, Any]: additional kwargs for the user function. + """ + + +@attr.s +class SparkArgProvider(UDFArgProvider[DataFrame]): + """Provides arguments to Spark UDFs.""" + + PARAMS_ARG_NAME = "params" + SPARK_SESSION_ARG_NAME = "spark" + + params_loader: ParamsLoader = attr.ib() + input_loader: SparkDataFrameInputLoader = attr.ib() + spark_session_factory: SparkSessionFactory = attr.ib() + + def provide_input_args( + self, udf: Callable[..., DataFrame], fp_config: FeatureProcessorConfig + ) -> OrderedDict[str, DataFrame]: + """Provide a DataFrame for each requested input. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises: + ValueError: If the signature of the UDF does not match the fp_config.inputs. + ValueError: If there are no inputs provided to the user defined function. + + Returns: + OrderedDict[str, DataFrame]: The loaded data sources, in the same order as + fp_config.inputs. + """ + udf_parameter_names = list(signature(udf).parameters.keys()) + udf_input_names = self._get_input_parameters(udf_parameter_names) + udf_params = self.params_loader.get_parameter_args(fp_config).get( + self.PARAMS_ARG_NAME, None + ) + + if len(udf_input_names) == 0: + raise ValueError("Expected at least one input to the user defined function.") + + if len(udf_input_names) != len(fp_config.inputs): + raise ValueError( + f"The signature of the user defined function does not match the list of inputs" + f" requested. Expected {len(fp_config.inputs)} parameter(s)." + ) + + return OrderedDict( + (input_name, self._load_data_frame(data_source=input_uri, params=udf_params)) + for (input_name, input_uri) in zip(udf_input_names, fp_config.inputs) + ) + + def provide_params_arg( + self, udf: Callable[..., DataFrame], fp_config: FeatureProcessorConfig + ) -> Dict[str, Union[str, Dict]]: + """Provide params for the UDF. If the udf has a parameter named 'params'. + + Args: + udf (Callable[..., T]): the feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + """ + return ( + self.params_loader.get_parameter_args(fp_config) + if self._has_param(udf, self.PARAMS_ARG_NAME) + else {} + ) + + def provide_additional_kwargs(self, udf: Callable[..., DataFrame]) -> Dict[str, SparkSession]: + """Provide the Spark session. If the udf has a parameter named 'spark'. + + Args: + udf (Callable[..., T]): the feature_processor wrapped user function. + """ + return ( + {self.SPARK_SESSION_ARG_NAME: self.spark_session_factory.spark_session} + if self._has_param(udf, self.SPARK_SESSION_ARG_NAME) + else {} + ) + + def _get_input_parameters(self, udf_parameter_names: List[str]) -> List[str]: + """Parses the parameter names from the UDF that correspond to the input data sources. 
+ + This function assumes that the udf signature's `params` and `spark` parameters are at the + end, in any order, if provided. + + Args: + udf_parameter_names (List[str]): The full list of parameters names in the UDF. + + Returns: + List[str]: A subset of parameter names corresponding to the input data sources. + """ + inputs_end_index = len(udf_parameter_names) - 1 + + # Reduce range based on the position of optional kwargs of the UDF. + if self.PARAMS_ARG_NAME in udf_parameter_names: + inputs_end_index = udf_parameter_names.index(self.PARAMS_ARG_NAME) - 1 + + if self.SPARK_SESSION_ARG_NAME in udf_parameter_names: + inputs_end_index = min( + inputs_end_index, + udf_parameter_names.index(self.SPARK_SESSION_ARG_NAME) - 1, + ) + + return udf_parameter_names[: inputs_end_index + 1] + + def _load_data_frame( + self, + data_source: Union[ + FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource + ], + params: Optional[Dict[str, Union[str, Dict]]] = None, + ) -> DataFrame: + """Given a data source definition, load the data as a Spark DataFrame. + + Args: + data_source (Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, + BaseDataSource]): A user specified data source from the feature_processor + decorator's parameters. + params (Optional[Dict[str, Union[str, Dict]]]): Parameters provided to the + feature_processor decorator. + + Returns: + DataFrame: The contents of the data source as a Spark DataFrame. + """ + if isinstance(data_source, (CSVDataSource, ParquetDataSource)): + return self.input_loader.load_from_s3(data_source) + + if isinstance(data_source, FeatureGroupDataSource): + return self.input_loader.load_from_feature_group(data_source) + + if isinstance(data_source, PySparkDataSource): + spark_session = self.spark_session_factory.spark_session + return data_source.read_data(spark=spark_session, params=params) + + if isinstance(data_source, BaseDataSource): + return data_source.read_data(params=params) + + raise ValueError(f"Unknown data source type: {type(data_source)}") + + def _has_param(self, udf: Callable, name: str) -> bool: + """Determine if a function has a parameter with a given name. + + Args: + udf (Callable): the user defined function. + name (str): the name of the parameter. + + Returns: + bool: True if the udf contains a parameter with the name. + """ + return name in list(signature(udf).parameters.keys()) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_output_receiver.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_output_receiver.py new file mode 100644 index 0000000000..a037e837c2 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_output_receiver.py @@ -0,0 +1,98 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains classes for handling UDF outputs""" +from __future__ import absolute_import + +import logging +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +import attr +from py4j.protocol import Py4JJavaError +from pyspark.sql import DataFrame + +from sagemaker.mlops.feature_store.feature_processor import IngestionError +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._spark_factory import ( + FeatureStoreManagerFactory, +) + +T = TypeVar("T") + +logger = logging.getLogger("sagemaker") + + +class UDFOutputReceiver(Generic[T], ABC): + """Base class for handling outputs of the UDF.""" + + @abstractmethod + def ingest_udf_output(self, output: T, fp_config: FeatureProcessorConfig) -> None: + """Ingests data to the output feature group. + + Args: + output (T): The output of the feature_processor wrapped function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + """ + + +@attr.s +class SparkOutputReceiver(UDFOutputReceiver[DataFrame]): + """Handles the Spark DataFrame the output from the UDF""" + + feature_store_manager_factory: FeatureStoreManagerFactory = attr.ib() + + def ingest_udf_output(self, output: DataFrame, fp_config: FeatureProcessorConfig) -> None: + """Ingests UDF to the output Feature Group. + + Args: + output (T): The output of the feature_processor wrapped function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises: + Py4JError: If there is a problem with Py4J, including client code errors. + IngestionError: If any rows are not ingested successfully then a sample of the records, + with failure reasons, is logged. + """ + if fp_config.enable_ingestion is False: + logging.info("Ingestion is disabled. Skipping ingestion.") + return + + logger.info( + "Ingesting transformed data to %s with target_stores: %s", + fp_config.output, + fp_config.target_stores, + ) + + feature_store_manager = self.feature_store_manager_factory.feature_store_manager + try: + feature_store_manager.ingest_data( + input_data_frame=output, + feature_group_arn=fp_config.output, + target_stores=fp_config.target_stores, + ) + except Py4JJavaError as e: + if e.java_exception.getClass().getSimpleName() == "StreamIngestionFailureException": + logger.warning( + "Ingestion did not complete successfully. Failed records and error messages" + " have been printed to the console." + ) + feature_store_manager.get_failed_stream_ingestion_data_frame().show( + n=20, truncate=False + ) + raise IngestionError(e.java_exception) + + raise e + + logger.info("Ingestion to %s complete.", fp_config.output) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_wrapper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_wrapper.py new file mode 100644 index 0000000000..95b07de7c1 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_udf_wrapper.py @@ -0,0 +1,88 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module provides a wrapper for user provided functions.""" +from __future__ import absolute_import + +import functools +from typing import Any, Callable, Dict, Generic, Tuple, TypeVar + +import attr + +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import UDFArgProvider +from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import ( + UDFOutputReceiver, +) + +T = TypeVar("T") + + +@attr.s +class UDFWrapper(Generic[T]): + """Class that wraps a user provided function.""" + + udf_arg_provider: UDFArgProvider[T] = attr.ib() + udf_output_receiver: UDFOutputReceiver[T] = attr.ib() + + def wrap(self, udf: Callable[..., T], fp_config: FeatureProcessorConfig) -> Callable[..., None]: + """Wrap the provided UDF with the logic defined by the FeatureProcessorConfig. + + General functionality of the wrapper function includes but is not limited to loading data + sources and ingesting output data to a Feature Group. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Returns: + Callable[..., None]: the user provided function wrapped with feature_processor logic. + """ + + @functools.wraps(udf) + def wrapper() -> None: + udf_args, udf_kwargs = self._prepare_udf_args( + udf=udf, + fp_config=fp_config, + ) + + output = udf(*udf_args, **udf_kwargs) + + self.udf_output_receiver.ingest_udf_output(output, fp_config) + + return wrapper + + def _prepare_udf_args( + self, + udf: Callable[..., T], + fp_config: FeatureProcessorConfig, + ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + """Generate the arguments for the user defined function, provided by the wrapper function. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Returns: + Tuple[Tuple[Any, ...], Dict[str, Any]]: A tuple positional arguments and keyword + arguments for the UDF. + """ + args = () + kwargs = { + **self.udf_arg_provider.provide_input_args(udf, fp_config), + **self.udf_arg_provider.provide_params_arg(udf, fp_config), + **self.udf_arg_provider.provide_additional_kwargs(udf), + } + + return (args, kwargs) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_validation.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_validation.py new file mode 100644 index 0000000000..307838be0c --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_validation.py @@ -0,0 +1,210 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Module that contains validators and a validation chain""" +from __future__ import absolute_import + +import inspect +import re +from abc import ABC, abstractmethod +from typing import Any, Callable, List + +import attr + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + FeatureGroupDataSource, + BaseDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._input_offset_parser import ( + InputOffsetParser, +) + + +@attr.s +class Validator(ABC): + """Base class for all validators. Errors are raised if validation fails.""" + + @abstractmethod + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validates FeatureProcessorConfig and a UDF.""" + + +@attr.s +class ValidatorChain: + """Executes a series of validators.""" + + validators: List[Validator] = attr.ib() + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validates a value using the list of validators. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises: + ValueError: If there are any validation errors raised by the validators in this chain. + """ + for validator in self.validators: + validator.validate(udf, fp_config) + + +class FeatureProcessorArgValidator(Validator): + """A validator for arguments provided to FeatureProcessor.""" + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Temporary validator for unsupported feature_processor parameters. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + """ + # TODO: Validate target_stores values. + + +class InputValidator(Validator): + """A validator for the 'input' parameter.""" + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validate the arguments provided to the decorator's input parameter. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises: + ValueError: If no inputs are provided. + """ + + inputs = fp_config.inputs + if inputs is None or len(inputs) == 0: + raise ValueError("At least one input is required.") + + +class SparkUDFSignatureValidator(Validator): + """A validator for PySpark UDF signatures.""" + + def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None: + """Validate the signature of the UDF based on the configurations provided to the decorator. + + Args: + udf (Callable[..., T]): The feature_processor wrapped user function. + fp_config (FeatureProcessorConfig): The configuration for the feature_processor. + + Raises (ValueError): raises ValueError when any of the following scenario happen: + 1. No input provided to feature_processor. + 2. Number of provided parameters does not match with that of provided inputs. + 3. Required parameters are not provided in the right order. 
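+
+        Example (illustrative; a signature that passes validation for two requested inputs,
+        with the optional 'params' and 'spark' parameters in the trailing positions):
+
+        .. code-block:: python
+
+            @feature_processor(inputs=[source_a, source_b], output=output_fg_arn)  # placeholders
+            def transform(df_a, df_b, params, spark):
+                ...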
+        """
+        parameters = list(inspect.signature(udf).parameters.keys())
+        input_parameters = self._get_input_params(udf)
+        if len(input_parameters) < 1:
+            raise ValueError("feature_processor expects at least 1 input parameter.")
+
+        # Validate count of input parameters against requested inputs.
+        num_data_sources = len(fp_config.inputs)
+        if len(input_parameters) != num_data_sources:
+            raise ValueError(
+                f"feature_processor expected a function with ({num_data_sources}) parameter(s)"
+                f" before any optional 'params' or 'spark' parameters for the ({num_data_sources})"
+                f" requested data source(s)."
+            )
+
+        # Validate position of non-input parameters.
+        if "params" in parameters and parameters[-1] != "params" and parameters[-2] != "params":
+            raise ValueError(
+                "feature_processor expected the 'params' parameter to be the last or second last"
+                " parameter after input parameters."
+            )
+
+        if "spark" in parameters and parameters[-1] != "spark" and parameters[-2] != "spark":
+            raise ValueError(
+                "feature_processor expected the 'spark' parameter to be the last or second last"
+                " parameter after input parameters."
+            )
+
+    def _get_input_params(self, udf: Callable[..., Any]) -> List[str]:
+        """Get the parameters that correspond to the inputs for a UDF.
+
+        Args:
+            udf (Callable[..., Any]): the user provided function.
+        """
+        parameters = list(inspect.signature(udf).parameters.keys())
+
+        # Remove non-input parameter names.
+        if "params" in parameters:
+            parameters.remove("params")
+        if "spark" in parameters:
+            parameters.remove("spark")
+
+        return parameters
+
+
+class InputOffsetValidator(Validator):
+    """A validator for input offsets."""
+
+    def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None:
+        """Validate the start and end input offset provided to the decorator.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+
+        Raises (ValueError): raises ValueError when input_start_offset is later than
+            input_end_offset.
+        """
+
+        for config_input in fp_config.inputs:
+            if isinstance(config_input, FeatureGroupDataSource):
+                input_start_offset = config_input.input_start_offset
+                input_end_offset = config_input.input_end_offset
+                start_td = InputOffsetParser.parse_offset_to_timedelta(input_start_offset)
+                end_td = InputOffsetParser.parse_offset_to_timedelta(input_end_offset)
+                if start_td and end_td and start_td > end_td:
+                    raise ValueError("input_start_offset should always be before input_end_offset.")
+
+
+class BaseDataSourceValidator(Validator):
+    """A validator for BaseDataSource."""
+
+    def validate(self, udf: Callable[..., Any], fp_config: FeatureProcessorConfig) -> None:
+        """Validate the BaseDataSource provided to the decorator.
+
+        Args:
+            udf (Callable[..., T]): The feature_processor wrapped user function.
+            fp_config (FeatureProcessorConfig): The configuration for the feature_processor.
+
+        Raises (ValueError): raises ValueError when data_source_unique_id or data_source_name
+            of the input data source is not valid.
+        """
+
+        for config_input in fp_config.inputs:
+            if isinstance(config_input, BaseDataSource):
+                source_name = config_input.data_source_name
+                source_id = config_input.data_source_unique_id
+
+                source_name_pattern = r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,119}$"
+                source_id_pattern = r"^.{1,2048}$"
+
+                if not re.match(source_name_pattern, source_name):
+                    raise ValueError(
+                        f"data_source_name of input does not match pattern '{source_name_pattern}'."
+ ) + + if not re.match(source_id_pattern, source_id): + raise ValueError( + f"data_source_unique_id of input does not match " + f"pattern '{source_id_pattern}'." + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_processor.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_processor.py new file mode 100644 index 0000000000..31593a3f1c --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_processor.py @@ -0,0 +1,129 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Feature Processor decorator for feature transformation functions.""" +from __future__ import absolute_import + +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +from sagemaker.mlops.feature_store.feature_processor import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._factory import ( + UDFWrapperFactory, + ValidatorFactory, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + + +def feature_processor( + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ], + output: str, + target_stores: Optional[List[str]] = None, + parameters: Optional[Dict[str, Union[str, Dict]]] = None, + enable_ingestion: bool = True, + spark_config: Dict[str, str] = None, +) -> Callable: + """Decorator to facilitate feature engineering for Feature Groups. + + If the decorated function is executed without arguments then the decorated function's arguments + are automatically loaded from the input data sources. Outputs are ingested to the output Feature + Group. If arguments are provided to this function, then arguments are not automatically loaded + (for testing). + + Decorated functions must conform to the expected signature. Parameters: one parameter of type + pyspark.sql.DataFrame for each DataSource in 'inputs'; followed by the optional parameters with + names and types in [params: Dict[str, Any], spark: SparkSession]. Outputs: a single return + value of type pyspark.sql.DataFrame. The function can have any name. + + **Example:** + + .. code-block:: python + + @feature_processor( + inputs=[FeatureGroupDataSource("input-fg"), CSVDataSource("s3://bucket/prefix)], + output='arn:aws:sagemaker:us-west-2:123456789012:feature-group/output-fg' + ) + def transform( + input_feature_group: DataFrame, input_csv: DataFrame, params: Dict[str, Any], + spark: SparkSession + ) -> DataFrame: + return ... + + **More concisely:** + + .. code-block:: python + + @feature_processor( + inputs=[FeatureGroupDataSource("input-fg"), CSVDataSource("s3://bucket/prefix)], + output='arn:aws:sagemaker:us-west-2:123456789012:feature-group/output-fg' + ) + def transform(input_feature_group, input_csv): + return ... 
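+
+    **With parameters and input offsets** (a sketch; the offset, parameter values, and the
+    FeatureGroupDataSource keyword are illustrative, based on the data source attributes
+    used by this module):
+
+    .. code-block:: python
+
+        @feature_processor(
+            inputs=[FeatureGroupDataSource("input-fg", input_start_offset="1 day")],
+            output='arn:aws:sagemaker:us-west-2:123456789012:feature-group/output-fg',
+            parameters={"country": "US"},
+        )
+        def transform(input_df, params, spark):
+            scheduled_time = params["system"]["scheduled_time"]  # system-provided parameter
+            return input_df.filter(input_df.country == params["country"])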
+ + Args: + inputs (Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource,\ + BaseDataSource]]): A list of data sources. + output (str): A Feature Group ARN to write results of this function to. + target_stores (Optional[list[str]], optional): A list containing at least one of + 'OnlineStore' or 'OfflineStore'. If unspecified, data will be ingested to the enabled + stores of the output feature group. Defaults to None. + parameters (Optional[Dict[str, Union[str, Dict]]], optional): Parameters to be provided to + the decorated function, available as the 'params' argument. Useful for parameterized + functions. The params argument also contains the set of system provided parameters + under the key 'system'. E.g. 'scheduled_time': a timestamp representing the time that + the execution was scheduled to execute at, if triggered by a Scheduler, otherwise, the + current time. + enable_ingestion (bool, optional): A boolean indicating whether the decorated function's + return value is ingested to the 'output' Feature Group. This flag is useful during the + development phase to ensure that data is not used until the function is ready. It also + useful for users that want to manage their own data ingestion. Defaults to True. + spark_config (Dict[str, str]): A dict contains the key-value paris for Spark configurations. + + Raises: + IngestionError: If any rows are not ingested successfully then a sample of the records, + with failure reasons, is logged. + + Returns: + Callable: The decorated function. + """ + + def decorator(udf: Callable[..., Any]) -> Callable: + fp_config = FeatureProcessorConfig.create( + inputs=inputs, + output=output, + mode=FeatureProcessorMode.PYSPARK, + target_stores=target_stores, + parameters=parameters, + enable_ingestion=enable_ingestion, + spark_config=spark_config, + ) + + validator_chain = ValidatorFactory.get_validation_chain(fp_config) + udf_wrapper = UDFWrapperFactory.get_udf_wrapper(fp_config) + + validator_chain.validate(udf=udf, fp_config=fp_config) + wrapped_function = udf_wrapper.wrap(udf=udf, fp_config=fp_config) + + wrapped_function.feature_processor_config = fp_config + + return wrapped_function + + return decorator diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py new file mode 100644 index 0000000000..b7ae647ef7 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py @@ -0,0 +1,1105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Feature Processor schedule APIs.""" +from __future__ import absolute_import +import logging +import json +import re +from datetime import datetime +from typing import Callable, List, Optional, Dict, Sequence, Union, Any, Tuple + +import pytz +from botocore.exceptions import ClientError + +from sagemaker.mlops.feature_store.feature_processor._config_uploader import ConfigUploader +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import ( + EventBridgeRuleHelper, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_pipeline_events import ( + FeatureProcessorPipelineEvents, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_group_lineage_context_name, + _get_feature_group_pipeline_lineage_context_name, + _get_feature_group_pipeline_version_lineage_context_name, + _get_feature_processor_pipeline_lineage_context_name, + _get_feature_processor_pipeline_version_lineage_context_name, +) +from sagemaker.core.lineage import context +from sagemaker.core.lineage._utils import get_resource_name_from_arn +from sagemaker.core.resources import FeatureGroup +from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( + RuntimeEnvironmentManager, +) + +from sagemaker.core.remote_function.spark_config import SparkConfig +from sagemaker.core.network import SUBNETS_KEY, SECURITY_GROUP_IDS_KEY + +from sagemaker.mlops.feature_store.feature_processor._constants import ( + EXECUTION_TIME_PIPELINE_PARAMETER, + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, + RESOURCE_NOT_FOUND_EXCEPTION, + SPARK_JAR_FILES_PATH, + SPARK_PY_FILES_PATH, + SPARK_FILES_PATH, + FEATURE_PROCESSOR_TAG_KEY, + FEATURE_PROCESSOR_TAG_VALUE, + PIPELINE_CONTEXT_TYPE, + DEFAULT_SCHEDULE_STATE, + SCHEDULED_TIME_PIPELINE_PARAMETER, + PIPELINE_CONTEXT_NAME_TAG_KEY, + PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY, + PIPELINE_NAME_MAXIMUM_LENGTH, + RESOURCE_NOT_FOUND, + FEATURE_GROUP_ARN_REGEX_PATTERN, + TO_PIPELINE_RESERVED_TAG_KEYS, + DEFAULT_TRIGGER_STATE, + EVENTBRIDGE_RULE_ARN_REGEX_PATTERN, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + +from sagemaker.core.s3 import s3_path_join + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import ( + EventBridgeSchedulerHelper, +) +from sagemaker.mlops.workflow.pipeline import Pipeline +from sagemaker.mlops.workflow.retry import ( + StepRetryPolicy, + StepExceptionTypeEnum, + SageMakerJobStepRetryPolicy, + SageMakerJobExceptionTypeEnum, +) + +from sagemaker.mlops.workflow.steps import TrainingStep + +from sagemaker.train.model_trainer import ModelTrainer +from sagemaker.train.configs import Compute, Networking, StoppingCondition, SourceCode, Tag +from sagemaker.core.shapes import OutputDataConfig +from sagemaker.core.workflow.pipeline_context import PipelineSession + +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage import ( + FeatureProcessorLineageHandler, + TransformationCode, +) + +from sagemaker.core.remote_function.job import ( + _JobSettings, + JOBS_CONTAINER_ENTRYPOINT, + SPARK_APP_SCRIPT_PATH, +) + +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + 
FeatureGroupDataSource, + ParquetDataSource, +) + +logger = logging.getLogger("sagemaker") + + +def to_pipeline( + pipeline_name: str, + step: Callable, + role: Optional[str] = None, + transformation_code: Optional[TransformationCode] = None, + max_retries: Optional[int] = None, + tags: Optional[List[Tuple[str, str]]] = None, + sagemaker_session: Optional[Session] = None, +) -> str: + """Creates a sagemaker pipeline that takes in a callable as a training step. + + To configure training step used in sagemaker pipeline, input argument step needs to be wrapped + by remote decorator in module sagemaker.remote_function. If not wrapped by remote decorator, + default configurations in sagemaker.remote_function.job._JobSettings will be used to create + training step. + + Args: + pipeline_name (str): The name of the pipeline. + step (Callable): A user provided function wrapped by feature_processor and optionally + wrapped by remote_decorator. + role (Optional[str]): The Amazon Resource Name (ARN) of the role used by the pipeline to + access and create resources. If not specified, it will default to the credentials + provided by the AWS configuration chain. + transformation_code (Optional[str]): The data source for a reference to the transformation + code for Lineage tracking. This code is not used for actual transformation. + max_retries (Optional[int]): The number of times to retry sagemaker pipeline step. + If not specified, sagemaker pipline step will not retry. + tags (List[Tuple[str, str]): A list of tags attached to the pipeline and all corresponding + lineage resources that support tags. If not specified, no custom tags will be attached. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + str: SageMaker Pipeline ARN. 
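+
+    Example (a sketch; the pipeline name, role ARN, and decorated function are placeholders):
+
+    .. code-block:: python
+
+        pipeline_arn = to_pipeline(
+            pipeline_name="my-feature-processor-pipeline",
+            step=transform,  # a feature_processor (and optionally remote) decorated function
+            role="arn:aws:iam::123456789012:role/MySageMakerExecutionRole",  # placeholder
+        )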
+ """ + + _validate_input_for_to_pipeline_api(pipeline_name, step) + if tags: + _validate_tags_for_to_pipeline_api(tags) + + _sagemaker_session = sagemaker_session or Session() + + _validate_lineage_resources_for_to_pipeline_api( + step.feature_processor_config, _sagemaker_session + ) + + remote_decorator_config = _get_remote_decorator_config_from_input( + wrapped_func=step, sagemaker_session=_sagemaker_session + ) + _role = role or get_execution_role(_sagemaker_session) + + runtime_env_manager = RuntimeEnvironmentManager() + client_python_version = runtime_env_manager._current_python_version() + config_uploader = ConfigUploader(remote_decorator_config, runtime_env_manager) + + s3_base_uri = s3_path_join(remote_decorator_config.s3_root_uri, pipeline_name) + + ( + input_data_config, + spark_dependency_paths, + ) = config_uploader.prepare_step_input_channel_for_spark_mode( + func=getattr(step, "wrapped_func", step), + s3_base_uri=s3_base_uri, + sagemaker_session=_sagemaker_session, + ) + + pipeline_session = PipelineSession( + boto_session=_sagemaker_session.boto_session, + default_bucket=_sagemaker_session.default_bucket(), + default_bucket_prefix=_sagemaker_session.default_bucket_prefix, + ) + logger.info("Created PipelineSession for pipeline %s", pipeline_name) + + model_trainer = _prepare_model_trainer_from_remote_decorator_config( + remote_decorator_config=remote_decorator_config, + s3_base_uri=s3_base_uri, + client_python_version=client_python_version, + spark_dependency_paths=spark_dependency_paths, + pipeline_session=pipeline_session, + role=_role, + ) + + step_args = model_trainer.train(input_data_config=input_data_config) + logger.info("Obtained step_args from ModelTrainer.train() for pipeline %s", pipeline_name) + + step_name = "-".join([pipeline_name, "feature-processor"]) + training_step_request_dict = dict( + name=step_name, + step_args=step_args, + ) + logger.info("Created TrainingStep '%s' with step_args", step_name) + + if max_retries: + training_step_request_dict["retry_policies"] = [ + StepRetryPolicy( + exception_types=[ + StepExceptionTypeEnum.SERVICE_FAULT, + StepExceptionTypeEnum.THROTTLING, + ], + max_attempts=max_retries, + ), + SageMakerJobStepRetryPolicy( + exception_types=[ + SageMakerJobExceptionTypeEnum.INTERNAL_ERROR, + SageMakerJobExceptionTypeEnum.CAPACITY_ERROR, + SageMakerJobExceptionTypeEnum.RESOURCE_LIMIT, + ], + max_attempts=max_retries, + ), + ] + + pipeline_request_dict = dict( + name=pipeline_name, + steps=[TrainingStep(**training_step_request_dict)], + sagemaker_session=_sagemaker_session, + parameters=[SCHEDULED_TIME_PIPELINE_PARAMETER], + ) + pipeline_tags = [dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE)] + if tags: + pipeline_tags.extend([dict(Key=k, Value=v) for k, v in tags]) + + pipeline = Pipeline(**pipeline_request_dict) + logger.info("Creating/Updating sagemaker pipeline %s", pipeline_name) + pipeline.upsert( + role_arn=_role, + tags=pipeline_tags, + ) + logger.info("Created sagemaker pipeline %s", pipeline_name) + + describe_pipeline_response = pipeline.describe() + pipeline_arn = describe_pipeline_response["PipelineArn"] + tags_propagate_to_lineage_resources = _get_tags_from_pipeline_to_propagate_to_lineage_resources( + pipeline_arn, _sagemaker_session + ) + + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=pipeline_name, + pipeline_arn=pipeline_arn, + pipeline=describe_pipeline_response, + inputs=_get_feature_processor_inputs(wrapped_func=step), + 
output=_get_feature_processor_outputs(wrapped_func=step), + transformation_code=transformation_code, + sagemaker_session=_sagemaker_session, + ) + lineage_handler.create_lineage(tags_propagate_to_lineage_resources) + lineage_handler.upsert_tags_for_lineage_resources(tags_propagate_to_lineage_resources) + + pipeline_lineage_names: Dict[str, str] = lineage_handler.get_pipeline_lineage_names() + + if pipeline_lineage_names is None: + raise RuntimeError("Failed to retrieve pipeline lineage. Pipeline Lineage does not exist") + + pipeline.upsert( + role_arn=_role, + tags=[ + { + "Key": PIPELINE_CONTEXT_NAME_TAG_KEY, + "Value": pipeline_lineage_names["pipeline_context_name"], + }, + { + "Key": PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY, + "Value": pipeline_lineage_names["pipeline_version_context_name"], + }, + ], + ) + return pipeline_arn + + +def schedule( + pipeline_name: str, + schedule_expression: str, + role_arn: Optional[str] = None, + state: Optional[str] = DEFAULT_SCHEDULE_STATE, + start_date: Optional[datetime] = None, + sagemaker_session: Optional[Session] = None, +) -> str: + """Creates an EventBridge Schedule that schedules executions of a sagemaker pipeline. + + The pipeline created will also have a pipeline parameter `scheduled-time` indicating when the + pipeline is scheduled to run. + + Args: + pipeline_name (str): The SageMaker Pipeline name that will be scheduled. + schedule_expression (str): The expression that defines when the schedule runs. It supports + at expression, rate expression and cron expression. See the + `CreateSchedule API + `_ + for more details. + state (str): Specifies whether the schedule is enabled or disabled. Valid values are + ENABLED and DISABLED. See the `State request parameter + `_ + for more details. If not specified, it will default to ENABLED. + start_date (Optional[datetime]): The date, in UTC, after which the schedule can begin + invoking its target. Depending on the schedule’s recurrence expression, invocations + might occur on, or after, the StartDate you specify. + role_arn (Optional[str]): The Amazon Resource Name (ARN) of the IAM role that EventBridge + Scheduler will assume for this target when the schedule is invoked. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + str: The EventBridge Schedule ARN. 
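+
+    Example:
+        A minimal sketch, assuming the pipeline ``"fraud-feature-pipeline"`` was previously
+        created with ``to_pipeline`` (names and values are illustrative only)::
+
+            schedule_arn = schedule(
+                pipeline_name="fraud-feature-pipeline",
+                schedule_expression="rate(1 hour)",
+                state="ENABLED",
+            )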
+ """ + + _sagemaker_session = sagemaker_session or Session() + _validate_pipeline_lineage_resources(pipeline_name, _sagemaker_session) + _start_date = start_date or datetime.now(tz=pytz.utc) + _role_arn = role_arn or get_execution_role(_sagemaker_session) + event_bridge_scheduler_helper = EventBridgeSchedulerHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("scheduler"), + ) + describe_pipeline_response = _sagemaker_session.sagemaker_client.describe_pipeline( + PipelineName=pipeline_name + ) + pipeline_arn = describe_pipeline_response["PipelineArn"] + tags_propagate_to_lineage_resources = _get_tags_from_pipeline_to_propagate_to_lineage_resources( + pipeline_arn, _sagemaker_session + ) + + logger.info("Creating/Updating EventBridge Schedule for pipeline %s.", pipeline_name) + event_bridge_schedule_arn = event_bridge_scheduler_helper.upsert_schedule( + schedule_name=pipeline_name, + pipeline_arn=pipeline_arn, + schedule_expression=schedule_expression, + state=state, + start_date=_start_date, + role=_role_arn, + ) + logger.info("Created/Updated EventBridge Schedule for pipeline %s.", pipeline_name) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=pipeline_name, + pipeline_arn=describe_pipeline_response["PipelineArn"], + pipeline=describe_pipeline_response, + sagemaker_session=_sagemaker_session, + ) + lineage_handler.create_schedule_lineage( + pipeline_name=pipeline_name, + schedule_arn=event_bridge_schedule_arn["ScheduleArn"], + schedule_expression=schedule_expression, + state=state, + start_date=_start_date, + tags=tags_propagate_to_lineage_resources, + ) + return event_bridge_schedule_arn["ScheduleArn"] + + +def put_trigger( + source_pipeline_events: List[FeatureProcessorPipelineEvents], + target_pipeline: str, + target_pipeline_parameters: Optional[Dict[str, str]] = None, + state: Optional[str] = DEFAULT_TRIGGER_STATE, + event_pattern: Optional[str] = None, + role_arn: Optional[str] = None, + sagemaker_session: Optional[Session] = None, +) -> str: + """Creates an event based trigger that triggers executions of a sagemaker pipeline. + + Args: + source_pipeline_events (List[FeatureProcessorPipelineEvents]): The list of + FeatureProcessorPipelineEvents that will trigger the target_pipeline. + target_pipeline (str): The name of the SageMaker Pipeline that will be triggered. + target_pipeline_parameters (Optional[Dict[str, str]]): The list of parameters to start + execution of a pipeline. + state (Optional[str]): Indicates whether the rule is enabled or disabled. + If not specified, it will default to ENABLED. + event_pattern (Optional[str]): The EventBridge EventPattern that triggers the + target_pipeline. If specified, will override source_pipeline_events. For more + information, see + https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-event-patterns.html + in the Amazon EventBridge User Guide. + role_arn (Optional[str]): The Amazon Resource Name (ARN) of the IAM role that EventBridge + Scheduler will assume for this target when the schedule is invoked. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + str: The EventBridge Rule ARN. 
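+
+    Example:
+        A minimal sketch, assuming ``upstream_events`` is a list of
+        ``FeatureProcessorPipelineEvents`` describing the source pipeline executions to react
+        to, and that both pipelines were created with ``to_pipeline`` (names are illustrative
+        only)::
+
+            rule_arn = put_trigger(
+                source_pipeline_events=upstream_events,
+                target_pipeline="fraud-feature-pipeline",
+                state="ENABLED",
+            )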
+ """ + _sagemaker_session = sagemaker_session or Session() + _role_arn = role_arn or get_execution_role(_sagemaker_session) + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + logger.info("Creating/Updating EventBridge Rule for pipeline %s.", target_pipeline) + rule_arn = event_bridge_rule_helper.put_rule( + source_pipeline_events=source_pipeline_events, + target_pipeline=target_pipeline, + event_pattern=event_pattern, + state=state, + ) + rule_name = _parse_name_from_arn(rule_arn, EVENTBRIDGE_RULE_ARN_REGEX_PATTERN) + logger.info("Created/Updated EventBridge Rule for pipeline %s.", target_pipeline) + + logger.info("Attaching pipeline %s to EventBridge Rule %s as target", target_pipeline, rule_arn) + event_bridge_rule_helper.put_target( + rule_name=rule_name, + target_pipeline=target_pipeline, + target_pipeline_parameters=target_pipeline_parameters, + role_arn=_role_arn, + ) + logger.info("Attached pipeline %s to EventBridge Rule %s as target", target_pipeline, rule_arn) + + describe_pipeline_response = _sagemaker_session.sagemaker_client.describe_pipeline( + PipelineName=target_pipeline + ) + describe_rule_response = event_bridge_rule_helper.describe_rule(rule_name=rule_name) + pipeline_arn = describe_pipeline_response["PipelineArn"] + tags_propagate_to_lineage_resources = _get_tags_from_pipeline_to_propagate_to_lineage_resources( + pipeline_arn, _sagemaker_session + ) + + event_bridge_rule_helper.add_tags(rule_arn=rule_arn, tags=tags_propagate_to_lineage_resources) + + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=target_pipeline, + pipeline_arn=describe_pipeline_response["PipelineArn"], + pipeline=describe_pipeline_response, + sagemaker_session=_sagemaker_session, + ) + lineage_handler.create_trigger_lineage( + pipeline_name=target_pipeline, + trigger_arn=rule_arn, + state=state, + tags=tags_propagate_to_lineage_resources, + event_pattern=describe_rule_response["EventPattern"], + ) + return rule_arn + + +def enable_trigger( + pipeline_name: str, + sagemaker_session: Optional[Session] = None, +) -> None: + """Enable the EventBridge Rule that is associated with the pipeline. + + Args: + pipeline_name (str): The SageMaker Pipeline name that will be executed. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + _sagemaker_session = sagemaker_session or Session() + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + event_bridge_rule_helper.enable_rule(rule_name=pipeline_name) + logger.info("Enabled EventBridge Rule for pipeline %s.", pipeline_name) + + +def disable_trigger(pipeline_name: str, sagemaker_session: Optional[Session] = None) -> None: + """Disable the EventBridge Rule that is associated with the pipeline. + + Args: + pipeline_name (str): The SageMaker Pipeline name that will be executed. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
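+
+    Example:
+        A minimal sketch (the pipeline name is illustrative only)::
+
+            disable_trigger(pipeline_name="fraud-feature-pipeline")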
+    """
+    _sagemaker_session = sagemaker_session or Session()
+    event_bridge_rule_helper = EventBridgeRuleHelper(
+        _sagemaker_session,
+        _sagemaker_session.boto_session.client("events"),
+    )
+    event_bridge_rule_helper.disable_rule(rule_name=pipeline_name)
+    logger.info("Disabled EventBridge Rule for pipeline %s.", pipeline_name)
+
+
+def execute(
+    pipeline_name: str,
+    execution_time: Optional[datetime] = None,
+    sagemaker_session: Optional[Session] = None,
+) -> str:
+    """Starts an execution of a SageMaker Pipeline created by feature_processor.
+
+    Args:
+        pipeline_name (str): The SageMaker Pipeline name that will be executed.
+        execution_time (Optional[datetime]): The date, in UTC, that will be used as a sagemaker
+            pipeline parameter indicating the time at which the execution is scheduled to run.
+            If not specified, it will default to the current timestamp.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    Returns:
+        str: The pipeline execution ARN.
+    """
+    _sagemaker_session = sagemaker_session or Session()
+    _validate_pipeline_lineage_resources(pipeline_name, _sagemaker_session)
+    _execution_time = execution_time or datetime.now()
+    start_pipeline_execution_request = dict(
+        PipelineName=pipeline_name,
+        PipelineParameters=[
+            dict(
+                Name=EXECUTION_TIME_PIPELINE_PARAMETER,
+                Value=_execution_time.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT),
+            )
+        ],
+    )
+    logger.info("Starting an execution for pipeline %s", pipeline_name)
+    execution_response = _sagemaker_session.sagemaker_client.start_pipeline_execution(
+        **start_pipeline_execution_request
+    )
+    execution_arn = execution_response["PipelineExecutionArn"]
+    logger.info(
+        "Execution %s for pipeline %s started successfully.",
+        execution_arn,
+        pipeline_name,
+    )
+    return execution_arn
+
+
+def delete_schedule(pipeline_name: str, sagemaker_session: Optional[Session] = None) -> None:
+    """Delete the EventBridge Schedule corresponding to a SageMaker Pipeline, if there is one.
+
+    Args:
+        pipeline_name (str): The name of the SageMaker Pipeline whose schedule will be deleted.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+    """
+    _sagemaker_session = sagemaker_session or Session()
+    event_bridge_scheduler_helper = EventBridgeSchedulerHelper(
+        _sagemaker_session, _sagemaker_session.boto_session.client("scheduler")
+    )
+    try:
+        event_bridge_scheduler_helper.delete_schedule(pipeline_name)
+        logger.info("Deleted EventBridge Schedule for pipeline %s.", pipeline_name)
+    except ClientError as e:
+        if RESOURCE_NOT_FOUND_EXCEPTION != e.response["Error"]["Code"]:
+            raise e
+
+
+def delete_trigger(pipeline_name: str, sagemaker_session: Optional[Session] = None) -> None:
+    """Delete the EventBridge Rule corresponding to a SageMaker Pipeline, if there is one.
+
+    Args:
+        pipeline_name (str): The name of the SageMaker Pipeline whose rule will be deleted.
+        sagemaker_session (Optional[Session]): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
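+
+    Example:
+        A minimal sketch (the pipeline name is illustrative only)::
+
+            delete_trigger(pipeline_name="fraud-feature-pipeline")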
+ """ + _sagemaker_session = sagemaker_session or Session() + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + try: + target_ids = [] + for page in event_bridge_rule_helper.list_targets_by_rule(pipeline_name): + target_ids.extend([target["Id"] for target in page["Targets"]]) + event_bridge_rule_helper.remove_targets(rule_name=pipeline_name, ids=target_ids) + event_bridge_rule_helper.delete_rule(pipeline_name) + logger.info("Deleted EventBridge Rule for pipeline %s.", pipeline_name) + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION != e.response["Error"]["Code"]: + raise e + + +def describe( + pipeline_name: str, sagemaker_session: Optional[Session] = None +) -> Dict[str, Union[int, str]]: + """Describe feature processor and other related resources. + + This API will include details related to the feature processor including SageMaker Pipeline and + EventBridge Schedule. + + Args: + pipeline_name (str): Name of the pipeline. + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + Dict[str, Union[int, str]]: Return information for resources related to feature processor. + """ + + _sagemaker_session = sagemaker_session or Session() + describe_response_dict = {} + + try: + describe_pipeline_response = _sagemaker_session.sagemaker_client.describe_pipeline( + PipelineName=pipeline_name + ) + pipeline_definition = json.loads(describe_pipeline_response["PipelineDefinition"]) + pipeline_step = pipeline_definition["Steps"][0] + describe_response_dict = dict( + pipeline_arn=describe_pipeline_response["PipelineArn"], + pipeline_execution_role_arn=describe_pipeline_response["RoleArn"], + ) + + if "RetryPolicies" in pipeline_step: + describe_response_dict["max_retries"] = pipeline_step["RetryPolicies"][0]["MaxAttempts"] + except ClientError as e: + if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]: + logger.info("Pipeline %s does not exist.", pipeline_name) + + event_bridge_scheduler_helper = EventBridgeSchedulerHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("scheduler"), + ) + + event_bridge_schedule = event_bridge_scheduler_helper.describe_schedule(pipeline_name) + if event_bridge_schedule: + describe_response_dict.update( + dict( + schedule_arn=event_bridge_schedule["Arn"], + schedule_expression=event_bridge_schedule["ScheduleExpression"], + schedule_state=event_bridge_schedule["State"], + schedule_start_date=event_bridge_schedule["StartDate"].strftime( + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT + ), + schedule_role=event_bridge_schedule["Target"]["RoleArn"], + ) + ) + + event_bridge_rule_helper = EventBridgeRuleHelper( + _sagemaker_session, + _sagemaker_session.boto_session.client("events"), + ) + event_based_trigger = event_bridge_rule_helper.describe_rule(pipeline_name) + if event_based_trigger: + describe_response_dict.update( + dict( + trigger=event_based_trigger["Arn"], + event_pattern=event_based_trigger["EventPattern"], + trigger_state=event_based_trigger["State"], + ) + ) + + return describe_response_dict + + +def list_pipelines(sagemaker_session: Optional[Session] = None) -> List[Dict[str, Any]]: + """Lists all SageMaker Pipelines created by Feature Processor SDK. 
+ + Args: + sagemaker_session (Optional[Session]): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + List[Dict[str, Any]]: Return list of SageMaker Pipeline metadata created for + feature_processor. + """ + + _sagemaker_session = sagemaker_session or Session() + next_token = None + list_response = [] + pipeline_names_so_far = set([]) + while True: + list_contexts_request = dict(ContextType=PIPELINE_CONTEXT_TYPE) + if next_token: + list_contexts_request["NextToken"] = next_token + list_contexts_response = _sagemaker_session.sagemaker_client.list_contexts( + **list_contexts_request + ) + for _context in list_contexts_response["ContextSummaries"]: + pipeline_name = get_resource_name_from_arn(_context["Source"]["SourceUri"]) + if pipeline_name not in pipeline_names_so_far: + list_response.append(dict(pipeline_name=pipeline_name)) + pipeline_names_so_far.add(pipeline_name) + next_token = list_contexts_response.get("NextToken") + if not next_token: + break + + return list_response + + +def _validate_input_for_to_pipeline_api(pipeline_name: str, step: Callable) -> None: + """Validate input to to_pipeline API. + + The provided callable is considered valid if it's wrapped by feature_processor decorator + and uses pyspark mode. + + Args: + pipeline_name (str): The name of the pipeline. + step (Callable): A user provided function wrapped by feature_processor and optionally + wrapped by remote_decorator. + + Raises (ValueError): raises ValueError when any of the following scenario happen: + 1. pipeline name is longer than 80 characters. + 2. function is not annotated with either feature_processor or remote decorator. + 3. provides a mode other than pyspark. + """ + if len(pipeline_name) > PIPELINE_NAME_MAXIMUM_LENGTH: + raise ValueError( + "Pipeline name used by feature processor should be less than 80 " + "characters. Please choose another pipeline name." + ) + + if not hasattr(step, "feature_processor_config") or not step.feature_processor_config: + raise ValueError( + "Please wrap step parameter with feature_processor decorator" + " in order to use to_pipeline API." + ) + + if not hasattr(step, "job_settings") or not step.job_settings: + raise ValueError( + "Please wrap step parameter with remote decorator in order to use to_pipeline API." + ) + + if FeatureProcessorMode.PYSPARK != step.feature_processor_config.mode: + raise ValueError( + f"Mode {step.feature_processor_config.mode} is not supported by to_pipeline API." + ) + + +def _validate_tags_for_to_pipeline_api(tags: List[Tuple[str, str]]) -> None: + """Validate tags provided to to_pipeline API. + + Args: + tags (List[Tuple[str, str]]): A list of tags attached to the pipeline. + + Raises (ValueError): raises ValueError when any of the following scenario happen: + 1. reserved tag keys are provided to API. + """ + provided_tag_keys = [tag_key_value_pair[0] for tag_key_value_pair in tags] + for reserved_tag_key in TO_PIPELINE_RESERVED_TAG_KEYS: + if reserved_tag_key in provided_tag_keys: + raise ValueError( + f"{reserved_tag_key} is a reserved tag key for to_pipeline API. Please choose another tag." + ) + + +def _validate_lineage_resources_for_to_pipeline_api( + feature_processor_config: FeatureProcessorConfig, sagemaker_session: Session +) -> None: + """Validate existence of feature group lineage resources for to_pipeline API. 
+ + Args: + feature_processor_config (FeatureProcessorConfig): The configuration values for the + feature_processor decorator. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. + """ + inputs = feature_processor_config.inputs + output = feature_processor_config.output + for ds in inputs: + if isinstance(ds, FeatureGroupDataSource): + fg_name = _parse_name_from_arn(ds.name) + _validate_fg_lineage_resources(fg_name, sagemaker_session) + output_fg_name = _parse_name_from_arn(output) + _validate_fg_lineage_resources(output_fg_name, sagemaker_session) + + +def _validate_fg_lineage_resources(feature_group_name: str, sagemaker_session: Session) -> None: + """Validate existence of feature group lineage resources. + + Args: + feature_group_name (str): The name or arn of the feature group. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. + + Raises (ValueError): raises ValueError when lineage resources are not created for feature + groups. + """ + + # TODO: Add describe_feature_group to V3 sagemaker_session so we can use + # sagemaker_session.describe_feature_group() directly instead of FeatureGroup.get(). + feature_group = FeatureGroup.get( + feature_group_name=feature_group_name, session=sagemaker_session.boto_session + ) + feature_group_creation_time = feature_group.creation_time.strftime("%s") + feature_group_context = _get_feature_group_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + feature_group_pipeline_context = _get_feature_group_pipeline_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + feature_group_pipeline_version_context = ( + _get_feature_group_pipeline_version_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + ) + for context_name in [ + feature_group_context, + feature_group_pipeline_context, + feature_group_pipeline_version_context, + ]: + try: + logger.info("Verifying existence of context %s.", context_name) + context.Context.load(context_name=context_name, sagemaker_session=sagemaker_session) + except ClientError as e: + if RESOURCE_NOT_FOUND == e.response["Error"]["Code"]: + raise ValueError( + f"Lineage resource {context_name} has not yet been created for feature group" + f" {feature_group_name} or has already been deleted. Please try again later." + ) + raise e + + +def _validate_pipeline_lineage_resources(pipeline_name: str, sagemaker_session: Session) -> None: + """Validate existence of pipeline lineage resources. + + Args: + pipeline_name (str): The name of the pipeline. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. 
+ """ + pipeline = sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=pipeline_name) + pipeline_creation_time = pipeline["CreationTime"].strftime("%s") + pipeline_context_name = _get_feature_processor_pipeline_lineage_context_name( + pipeline_name=pipeline_name, pipeline_creation_time=pipeline_creation_time + ) + try: + pipeline_context = context.Context.load( + context_name=pipeline_context_name, sagemaker_session=sagemaker_session + ) + last_update_time = pipeline_context.properties["LastUpdateTime"] + pipeline_version_context_name = ( + _get_feature_processor_pipeline_version_lineage_context_name( + pipeline_name=pipeline_name, pipeline_last_update_time=last_update_time + ) + ) + context.Context.load( + context_name=pipeline_version_context_name, sagemaker_session=sagemaker_session + ) + except ClientError as e: + if RESOURCE_NOT_FOUND == e.response["Error"]["Code"]: + raise ValueError( + "Pipeline lineage resources have not been created yet or have already been deleted" + ". Please try again later." + ) + raise e + + +def _prepare_model_trainer_from_remote_decorator_config( + remote_decorator_config: _JobSettings, + s3_base_uri: str, + client_python_version: str, + spark_dependency_paths: Dict[str, Optional[str]], + pipeline_session: PipelineSession, + role: str, +) -> ModelTrainer: + """Prepares a ModelTrainer instance from remote decorator configuration. + + Args: + remote_decorator_config (_JobSettings): Configurations used for setting up + SageMaker Pipeline Step. + s3_base_uri (str): S3 URI used as destination for dependencies upload. + client_python_version (str): Python version used on client side. + spark_dependency_paths (Dict[str, Optional[str]]): A dictionary contains S3 paths spark + dependency files get uploaded to if present. + pipeline_session (PipelineSession): Pipeline-aware session that causes + ModelTrainer.train() to return step arguments instead of launching a job. + role (str): The IAM role ARN for the training job. + Returns: + ModelTrainer: A configured ModelTrainer instance. 
+ """ + logger.info("Mapping remote decorator config to ModelTrainer params") + + # Build environment dict from remote_decorator_config (strings only for Pydantic validation) + environment = dict(remote_decorator_config.environment_variables or {}) + + # Build command from container entry point and arguments + entry_point_and_args = _get_container_entry_point_and_arguments( + remote_decorator_config=remote_decorator_config, + s3_base_uri=s3_base_uri, + client_python_version=client_python_version, + spark_dependency_paths=spark_dependency_paths, + ) + joined_command = " ".join( + entry_point_and_args["container_entry_point"] + + entry_point_and_args["container_arguments"] + ) + source_code = SourceCode(command=joined_command) + logger.info("SourceCode command: %s", joined_command) + + # Create Compute config + compute = Compute( + instance_type=remote_decorator_config.instance_type, + instance_count=remote_decorator_config.instance_count, + volume_size_in_gb=remote_decorator_config.volume_size, + volume_kms_key_id=remote_decorator_config.volume_kms_key, + ) + logger.info( + "Compute: instance_type=%s, instance_count=%s, volume_size=%s", + remote_decorator_config.instance_type, + remote_decorator_config.instance_count, + remote_decorator_config.volume_size, + ) + + # Create Networking config if VPC config is present + networking = None + if remote_decorator_config.vpc_config: + networking = Networking( + subnets=remote_decorator_config.vpc_config[SUBNETS_KEY], + security_group_ids=remote_decorator_config.vpc_config[SECURITY_GROUP_IDS_KEY], + enable_inter_container_traffic_encryption=( + remote_decorator_config.encrypt_inter_container_traffic + ), + ) + logger.info( + "Networking: subnets=%s, security_groups=%s, encrypt=%s", + remote_decorator_config.vpc_config[SUBNETS_KEY], + remote_decorator_config.vpc_config[SECURITY_GROUP_IDS_KEY], + remote_decorator_config.encrypt_inter_container_traffic, + ) + + # Create StoppingCondition if max_runtime_in_seconds is configured + stopping_condition = None + if remote_decorator_config.max_runtime_in_seconds: + stopping_condition = StoppingCondition( + max_runtime_in_seconds=remote_decorator_config.max_runtime_in_seconds, + ) + + # Create OutputDataConfig + output_data_config = OutputDataConfig( + s3_output_path=s3_base_uri, + kms_key_id=remote_decorator_config.s3_kms_key, + ) + + # Convert tags from List[Tuple[str, str]] to List[Tag] + tags = None + if remote_decorator_config.tags: + tags = [Tag(key=k, value=v) for k, v in remote_decorator_config.tags] + logger.info("Tags count: %d", len(tags) if tags else 0) + + logger.info("Environment keys: %s", list(environment.keys())) + + model_trainer = ModelTrainer( + training_image=remote_decorator_config.image_uri, + role=role, + sagemaker_session=pipeline_session, + compute=compute, + networking=networking, + stopping_condition=stopping_condition, + output_data_config=output_data_config, + source_code=source_code, + training_input_mode="File", + environment=environment, + tags=tags, + ) + + # Inject SCHEDULED_TIME_PIPELINE_PARAMETER after construction to bypass Pydantic + # validation (Parameter is not a string). The @runnable_by_pipeline decorator resolves + # Parameter objects to strings during pipeline definition serialization. 
+ model_trainer.environment[EXECUTION_TIME_PIPELINE_PARAMETER] = SCHEDULED_TIME_PIPELINE_PARAMETER + + logger.info( + "Created ModelTrainer with image=%s, instance_type=%s, instance_count=%s", + remote_decorator_config.image_uri, + remote_decorator_config.instance_type, + remote_decorator_config.instance_count, + ) + + return model_trainer + + +def _get_container_entry_point_and_arguments( + remote_decorator_config: _JobSettings, + s3_base_uri: str, + client_python_version: str, + spark_dependency_paths: Dict[str, Optional[str]], +) -> Dict[str, List[str]]: + """Extracts the container entry point and container arguments from remote decorator configs + + Args: + remote_decorator_config (_JobSettings): Configurations used for setting up + SageMaker Pipeline Step. + s3_base_uri (str): S3 URI used as destination for dependencies upload. + client_python_version (str): Python version used on client side. + spark_dependency_paths (Dict[str, Optional[str]]): A dictionary contains S3 paths spark + dependency files get uploaded to if present. + Returns: + Dict[str, List[str]]: Request dictionary containing container entry point and + arguments setup. + """ + + spark_config = remote_decorator_config.spark_config + jobs_container_entrypoint = JOBS_CONTAINER_ENTRYPOINT.copy() + + if spark_dependency_paths[SPARK_JAR_FILES_PATH]: + jobs_container_entrypoint.extend(["--jars", spark_dependency_paths[SPARK_JAR_FILES_PATH]]) + + if spark_dependency_paths[SPARK_PY_FILES_PATH]: + jobs_container_entrypoint.extend( + ["--py-files", spark_dependency_paths[SPARK_PY_FILES_PATH]] + ) + + if spark_dependency_paths[SPARK_FILES_PATH]: + jobs_container_entrypoint.extend(["--files", spark_dependency_paths[SPARK_FILES_PATH]]) + + if spark_config and spark_config.spark_event_logs_uri: + jobs_container_entrypoint.extend( + ["--spark-event-logs-s3-uri", spark_config.spark_event_logs_uri] + ) + + if spark_config: + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + + container_args = ["--s3_base_uri", s3_base_uri] + container_args.extend(["--region", remote_decorator_config.sagemaker_session.boto_region_name]) + container_args.extend(["--client_python_version", client_python_version]) + + if remote_decorator_config.s3_kms_key: + container_args.extend(["--s3_kms_key", remote_decorator_config.s3_kms_key]) + + return dict( + container_entry_point=jobs_container_entrypoint, + container_arguments=container_args, + ) + + +def _get_remote_decorator_config_from_input( + wrapped_func: Callable, sagemaker_session: Session +) -> _JobSettings: + """Extracts the remote decorator configuration from the wrapped function and other inputs. + + Args: + wrapped_func (Callable): Wrapped user defined function. If it contains remote decorator + job settings, configs will be used to construct remote_decorator_config, otherwise + default job settings will be used. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + Returns: + _JobSettings: Configurations used for creating sagemaker pipeline step. + """ + remote_decorator_config = getattr( + wrapped_func, + "job_settings", + ) + # TODO: Remove this after GA + remote_decorator_config.sagemaker_session = sagemaker_session + + # TODO: This needs to be removed when new mode is introduced. 
if remote_decorator_config.spark_config is None:
+        remote_decorator_config.spark_config = SparkConfig()
+        remote_decorator_config.image_uri = _JobSettings._get_default_spark_image(sagemaker_session)
+
+    return remote_decorator_config
+
+
+def _get_feature_processor_inputs(
+    wrapped_func: Callable,
+) -> Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource]]:
+    """Retrieve Feature Processor Config Inputs."""
+    feature_processor_config: FeatureProcessorConfig = wrapped_func.feature_processor_config
+    return feature_processor_config.inputs
+
+
+def _get_feature_processor_outputs(
+    wrapped_func: Callable,
+) -> str:
+    """Retrieve Feature Processor Config Output."""
+    feature_processor_config: FeatureProcessorConfig = wrapped_func.feature_processor_config
+    return feature_processor_config.output
+
+
+def _parse_name_from_arn(
+    name_or_arn: str, regex_pattern: str = FEATURE_GROUP_ARN_REGEX_PATTERN
+) -> str:
+    """Parse the name from a string, if it's an ARN. Otherwise, return the string.
+
+    Args:
+        name_or_arn (str): The Feature Group Name or ARN.
+        regex_pattern (str): The regex pattern used to extract the name from an ARN.
+
+    Returns:
+        str: The Feature Group Name.
+    """
+    match = re.match(regex_pattern, name_or_arn)
+    if match:
+        name = match.group(4)
+        return name
+    return name_or_arn
+
+
+def _get_tags_from_pipeline_to_propagate_to_lineage_resources(
+    pipeline_arn: str, sagemaker_session: Session
+) -> List[Dict[str, str]]:
+    """Retrieve custom tags attached to the sagemaker pipeline.
+
+    Args:
+        pipeline_arn (str): SageMaker Pipeline Arn.
+        sagemaker_session (Session): Session object which manages interactions
+            with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+            function creates one using the default AWS configuration chain.
+
+    Returns:
+        List[Dict[str, str]]: List of custom tags to be propagated to lineage resources.
+    """
+    tags_in_pipeline = sagemaker_session.sagemaker_client.list_tags(ResourceArn=pipeline_arn)[
+        "Tags"
+    ]
+    return [d for d in tags_in_pipeline if d["Key"] not in TO_PIPELINE_RESERVED_TAG_KEYS]
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_contexts.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_contexts.py
new file mode 100644
index 0000000000..2b4f134f0a
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_contexts.py
@@ -0,0 +1,31 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to store Feature Group Contexts"""
+from __future__ import absolute_import
+import attr
+
+
+@attr.s
+class FeatureGroupContexts:
+    """A Feature Group Context data source.
+
+    Attributes:
+        feature_group_name (str): The name of the Feature Group.
+ feature_group_pipeline_context_arn (str): The ARN of the Feature Group Pipeline Context. + feature_group_pipeline_version_context_arn (str): + The ARN of the Feature Group Versions Context + """ + + name: str = attr.ib() + pipeline_context_arn: str = attr.ib() + pipeline_version_context_arn: str = attr.ib() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py new file mode 100644 index 0000000000..fd160eb470 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py @@ -0,0 +1,184 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle Feature Processor Lineage""" +from __future__ import absolute_import + +import re +from typing import Dict, Any +import logging + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._constants import FEATURE_GROUP_ARN_REGEX_PATTERN +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import ( + FeatureGroupContexts, +) +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + SAGEMAKER, + FEATURE_GROUP, + CREATION_TIME, +) +from sagemaker.core.lineage.context import Context + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_group_pipeline_lineage_context_name, + _get_feature_group_pipeline_version_lineage_context_name, +) + +logger = logging.getLogger(SAGEMAKER) + + +class FeatureGroupLineageEntityHandler: + """Class for handling Feature Group Lineage""" + + @staticmethod + def retrieve_feature_group_context_arns( + feature_group_name: str, sagemaker_session: Session + ) -> FeatureGroupContexts: + """Retrieve Feature Group Contexts. + + Arguments: + feature_group_name (str): The Feature Group Name. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + FeatureGroupContexts: The Feature Group Pipeline and Version Context. 
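+
+        Example:
+            A minimal sketch of a call (the feature group name is illustrative only)::
+
+                contexts = FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns(
+                    feature_group_name="my-feature-group",
+                    sagemaker_session=Session(),
+                )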
+ """ + feature_group = FeatureGroupLineageEntityHandler._describe_feature_group( + feature_group_name=FeatureGroupLineageEntityHandler.parse_name_from_arn( + feature_group_name + ), + sagemaker_session=sagemaker_session, + ) + feature_group_name = feature_group[FEATURE_GROUP] + feature_group_creation_time = feature_group[CREATION_TIME].strftime("%s") + feature_group_pipeline_context = ( + FeatureGroupLineageEntityHandler._load_feature_group_pipeline_context( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + sagemaker_session=sagemaker_session, + ) + ) + feature_group_pipeline_version_context = ( + FeatureGroupLineageEntityHandler._load_feature_group_pipeline_version_context( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + sagemaker_session=sagemaker_session, + ) + ) + return FeatureGroupContexts( + name=feature_group_name, + pipeline_context_arn=feature_group_pipeline_context.context_arn, + pipeline_version_context_arn=feature_group_pipeline_version_context.context_arn, + ) + + @staticmethod + def _describe_feature_group( + feature_group_name: str, sagemaker_session: Session + ) -> Dict[str, Any]: + """Retrieve the Feature Group. + + Arguments: + feature_group_name (str): The Feature Group Name. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Dict[str, Any]: The Feature Group details. + """ + feature_group = sagemaker_session.sagemaker_client.describe_feature_group( + FeatureGroupName=feature_group_name + ) + logger.debug( + "Called describe_feature_group with %s and received: %s", + feature_group_name, + feature_group, + ) + return feature_group + + @staticmethod + def _load_feature_group_pipeline_context( + feature_group_name: str, + feature_group_creation_time: str, + sagemaker_session: Session, + ) -> Context: + """Retrieve Feature Group Pipeline Context + + Arguments: + feature_group_name (str): The Feature Group Name. + feature_group_creation_time (str): The Feature Group Creation Time, + in long epoch seconds. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The Feature Group Pipeline Context. + """ + feature_group_pipeline_context = _get_feature_group_pipeline_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + return Context.load( + context_name=feature_group_pipeline_context, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def _load_feature_group_pipeline_version_context( + feature_group_name: str, + feature_group_creation_time: str, + sagemaker_session: Session, + ) -> Context: + """Retrieve Feature Group Pipeline Version Context + + Arguments: + feature_group_name (str): The Feature Group Name. + feature_group_creation_time (str): The Feature Group Creation Time, + in long epoch seconds. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The Feature Group Pipeline Version Context. 
+ """ + feature_group_pipeline_version_context = ( + _get_feature_group_pipeline_version_lineage_context_name( + feature_group_name=feature_group_name, + feature_group_creation_time=feature_group_creation_time, + ) + ) + return Context.load( + context_name=feature_group_pipeline_version_context, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def parse_name_from_arn(fg_uri: str) -> str: + """Parse the name from a string, if it's an ARN. Otherwise, return the string. + + Arguments: + fg_uri (str): The Feature Group Name or ARN. + + Returns: + str: The Feature Group Name. + """ + match = re.match(FEATURE_GROUP_ARN_REGEX_PATTERN, fg_uri) + if match: + feature_group_name = match.group(4) + return feature_group_name + return fg_uri diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py new file mode 100644 index 0000000000..d706b3b441 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage.py @@ -0,0 +1,759 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle Lineage Associations""" +from __future__ import absolute_import +import logging +from datetime import datetime +from typing import Optional, Iterator, List, Dict, Set, Sequence, Union +import attr +from botocore.exceptions import ClientError + +from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import ( + EventBridgeRuleHelper, +) +from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import ( + EventBridgeSchedulerHelper, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._lineage_association_handler import ( + LineageAssociationHandler, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_lineage_entity_handler import ( + FeatureGroupLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import ( + FeatureGroupContexts, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_lineage_entity_handler import ( + PipelineLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import ( + PipelineTrigger, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_version_lineage_entity_handler import ( + PipelineVersionLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._s3_lineage_entity_handler import ( + S3LineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import 
( + SAGEMAKER, + LAST_UPDATE_TIME, + PIPELINE_CONTEXT_NAME_KEY, + PIPELINE_CONTEXT_VERSION_NAME_KEY, + FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, + DATA_SET, + TRANSFORMATION_CODE, + CREATION_TIME, + RESOURCE_NOT_FOUND, + ERROR, + CODE, + LAST_MODIFIED_TIME, + TRANSFORMATION_CODE_STATUS_INACTIVE, + TRANSFORMATION_CODE_STATUS_ACTIVE, + CONTRIBUTED_TO, +) +from sagemaker.core.lineage.context import Context +from sagemaker.core.lineage.artifact import Artifact +from sagemaker.core.lineage.association import AssociationSummary +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, +) + +logger = logging.getLogger(SAGEMAKER) + + +@attr.s +class FeatureProcessorLineageHandler: + """Class to Create and Update FeatureProcessor Lineage Entities. + + Attributes: + pipeline_name (str): Pipeline Name. + pipeline_arn (str): The ARN of the Pipeline. + pipeline (str): The details of the Pipeline. + inputs (Sequence[Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, + BaseDataSource]]): The inputs to the Feature processor. + output (str): The output Feature Group. + transformation_code (TransformationCode): The Transformation Code for Feature Processor. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + + pipeline_name: str = attr.ib() + pipeline_arn: str = attr.ib() + pipeline: Dict = attr.ib() + sagemaker_session: Session = attr.ib() + inputs: Sequence[ + Union[FeatureGroupDataSource, CSVDataSource, ParquetDataSource, BaseDataSource] + ] = attr.ib(default=None) + output: str = attr.ib(default=None) + transformation_code: TransformationCode = attr.ib(default=None) + + def create_lineage(self, tags: Optional[List[Dict[str, str]]] = None) -> None: + """Create and Update Feature Processor Lineage""" + input_feature_group_contexts: List[FeatureGroupContexts] = ( + self._retrieve_input_feature_group_contexts() + ) + output_feature_group_contexts: FeatureGroupContexts = ( + self._retrieve_output_feature_group_contexts() + ) + input_raw_data_artifacts: List[Artifact] = self._retrieve_input_raw_data_artifacts() + transformation_code_artifact: Optional[Artifact] = ( + S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=self.transformation_code, + pipeline_last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + ) + if transformation_code_artifact is not None: + logger.info("Created Transformation Code Artifact: %s", transformation_code_artifact) + if tags: + transformation_code_artifact.set_tags(tags) # pylint: disable=E1101 + # Create the Pipeline Lineage for the first time + if not self._check_if_pipeline_lineage_exists(): + self._create_new_pipeline_lineage( + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + ) + else: + self._update_pipeline_lineage( + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + 
) + + def get_pipeline_lineage_names(self) -> Optional[Dict[str, str]]: + """Retrieve Pipeline Lineage Names. + + Returns: + Optional[Dict[str, str]]: Pipeline and Pipeline version lineage names. + """ + if not self._check_if_pipeline_lineage_exists(): + return None + pipeline_context: Context = self._get_pipeline_context() + current_pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + return { + PIPELINE_CONTEXT_NAME_KEY: pipeline_context.context_name, + PIPELINE_CONTEXT_VERSION_NAME_KEY: current_pipeline_version_context.context_name, + } + + def create_schedule_lineage( + self, + pipeline_name: str, + schedule_arn, + schedule_expression, + state, + start_date: datetime, + tags: Optional[List[Dict[str, str]]] = None, + ) -> None: + """Class to Create and Update FeatureProcessor Lineage Entities. + + Arguments: + pipeline_name (str): Pipeline Name. + schedule_arn (str): The ARN of the Schedule. + schedule_expression (str): The expression that defines when the schedule runs. + It supports at expression, rate expression and cron expression. + state (str):Specifies whether the schedule is enabled or disabled. Valid values are + ENABLED and DISABLED. See https://docs.aws.amazon.com/scheduler/latest/APIReference/ + API_CreateSchedule.html#scheduler-CreateSchedule-request-State for more details. + If not specified, it will default to ENABLED. + start_date (Optional[datetime]): The date, in UTC, after which the schedule can begin + invoking its target. Depending on the schedule’s recurrence expression, invocations + might occur on, or after, the StartDate you specify. + tags (Optional[List[Dict[str, str]]]): Custom tags to be attached to schedule + lineage resource. + """ + pipeline_context: Context = self._get_pipeline_context() + pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + pipeline_schedule: PipelineSchedule = PipelineSchedule( + schedule_name=pipeline_name, + schedule_arn=schedule_arn, + schedule_expression=schedule_expression, + pipeline_name=pipeline_name, + state=state, + start_date=start_date.strftime("%s"), + ) + schedule_artifact: Artifact = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=pipeline_schedule, + sagemaker_session=self.sagemaker_session, + ) + if tags: + schedule_artifact.set_tags(tags) + + LineageAssociationHandler.add_upstream_schedule_associations( + schedule_artifact=schedule_artifact, + pipeline_version_context_arn=pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + + def create_trigger_lineage( + self, + pipeline_name: str, + trigger_arn: str, + event_pattern: str, + state: str, + tags: Optional[List[Dict[str, str]]] = None, + ) -> None: + """Class to Create and Update FeatureProcessor Pipeline Trigger Lineage Entities. + + Arguments: + pipeline_name (str): Pipeline Name. + trigger_arn (str): The ARN of the EventBridge Rule. + event_pattern (str): The event pattern for the rule. + state (str): Specifies whether the trigger is enabled or disabled. Valid values are + ENABLED and DISABLED. If not specified, it will default to ENABLED. + tags (Optional[List[Dict[str, str]]]): Custom tags to be attached to trigger + lineage resource. 
+ """ + pipeline_context: Context = self._get_pipeline_context() + pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + pipeline_trigger: PipelineTrigger = PipelineTrigger( + trigger_name=pipeline_name, + trigger_arn=trigger_arn, + event_pattern=event_pattern, + pipeline_name=pipeline_name, + state=state, + ) + trigger_artifact: Artifact = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=pipeline_trigger, + sagemaker_session=self.sagemaker_session, + ) + if tags: + trigger_artifact.set_tags(tags) + + LineageAssociationHandler._add_association( + source_arn=trigger_artifact.artifact_arn, + destination_arn=pipeline_version_context.context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=self.sagemaker_session, + ) + + def upsert_tags_for_lineage_resources(self, tags: List[Dict[str, str]]) -> None: + """Add or update tags for lineage resources using tags attached to sagemaker pipeline as + + source of truth. + + Args: + tags (List[Dict[str, str]]): Custom tags to be attached to lineage resources. + """ + if not tags: + return + pipeline_context: Context = self._get_pipeline_context() + current_pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + input_raw_data_artifacts: List[Artifact] = self._retrieve_input_raw_data_artifacts() + pipeline_context.set_tags(tags) + current_pipeline_version_context.set_tags(tags) + for input_raw_data_artifact in input_raw_data_artifacts: + input_raw_data_artifact.set_tags(tags) + + event_bridge_scheduler_helper = EventBridgeSchedulerHelper( + self.sagemaker_session, + self.sagemaker_session.boto_session.client("scheduler"), + ) + event_bridge_schedule = event_bridge_scheduler_helper.describe_schedule(self.pipeline_name) + + event_bridge_rule_helper = EventBridgeRuleHelper( + self.sagemaker_session, + self.sagemaker_session.boto_session.client("events"), + ) + event_bridge_rule = event_bridge_rule_helper.describe_rule(self.pipeline_name) + + if event_bridge_schedule: + schedule_artifact_summary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=event_bridge_schedule["Arn"], + sagemaker_session=self.sagemaker_session, + ) + if schedule_artifact_summary is not None: + pipeline_schedule_artifact: Artifact = ( + S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=schedule_artifact_summary.artifact_arn, + sagemaker_session=self.sagemaker_session, + ) + ) + pipeline_schedule_artifact.set_tags(tags) + + if event_bridge_rule: + rule_artifact_summary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=event_bridge_rule["Arn"], + sagemaker_session=self.sagemaker_session, + ) + if rule_artifact_summary: + pipeline_trigger_artifact: Artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=rule_artifact_summary.artifact_arn, + sagemaker_session=self.sagemaker_session, + ) + pipeline_trigger_artifact.set_tags(tags) + + def _create_new_pipeline_lineage( + self, + input_feature_group_contexts: List[FeatureGroupContexts], + input_raw_data_artifacts: List[Artifact], + output_feature_group_contexts: FeatureGroupContexts, + transformation_code_artifact: Optional[Artifact], + ) -> None: + """Create pipeline lineage resources.""" + + pipeline_context = self._create_pipeline_lineage_for_new_pipeline() + pipeline_version_context = self._create_pipeline_version_lineage() + self._add_associations_for_pipeline( + # pylint: 
disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_versions_context_arn=pipeline_version_context.context_arn, + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + ) + LineageAssociationHandler.add_pipeline_and_pipeline_version_association( + # pylint: disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_version_context_arn=pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + + def _update_pipeline_lineage( + self, + input_feature_group_contexts: List[FeatureGroupContexts], + input_raw_data_artifacts: List[Artifact], + output_feature_group_contexts: FeatureGroupContexts, + transformation_code_artifact: Optional[Artifact], + ) -> None: + """Update pipeline lineage resources.""" + + # If pipeline lineage exists then determine whether to create a new version. + pipeline_context: Context = self._get_pipeline_context() + current_pipeline_version_context: Context = self._get_pipeline_version_context( + last_update_time=pipeline_context.properties[LAST_UPDATE_TIME] + ) + upstream_feature_group_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, + sagemaker_session=self.sagemaker_session, + ) + ) + + upstream_raw_data_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=DATA_SET, + sagemaker_session=self.sagemaker_session, + ) + ) + + upstream_transformation_code: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_upstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + source_type=TRANSFORMATION_CODE, + sagemaker_session=self.sagemaker_session, + ) + ) + + downstream_feature_group_associations: Iterator[AssociationSummary] = ( + LineageAssociationHandler.list_downstream_associations( + # pylint: disable=no-member + entity_arn=current_pipeline_version_context.context_arn, + destination_type=FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE, + sagemaker_session=self.sagemaker_session, + ) + ) + + is_upstream_feature_group_equal: bool = self._compare_upstream_feature_groups( + upstream_feature_group_associations=upstream_feature_group_associations, + input_feature_group_contexts=input_feature_group_contexts, + ) + is_downstream_feature_group_equal: bool = self._compare_downstream_feature_groups( + downstream_feature_group_associations=downstream_feature_group_associations, + output_feature_group_contexts=output_feature_group_contexts, + ) + is_upstream_raw_data_equal: bool = self._compare_upstream_raw_data( + upstream_raw_data_associations=upstream_raw_data_associations, + input_raw_data_artifacts=input_raw_data_artifacts, + ) + + self._update_last_transformation_code( + upstream_transformation_code_associations=upstream_transformation_code + ) + if ( + not is_upstream_feature_group_equal + or not is_downstream_feature_group_equal + or not is_upstream_raw_data_equal + ): + if not is_upstream_raw_data_equal: + logger.info("Raw data inputs have changed from 
the last pipeline configuration.") + if not is_upstream_feature_group_equal: + logger.info( + "Feature group inputs have changed from the last pipeline configuration." + ) + if not is_downstream_feature_group_equal: + logger.info( + "Feature Group output has changed from the last pipeline configuration." + ) + pipeline_context.properties["LastUpdateTime"] = self.pipeline[ + "LastModifiedTime" + ].strftime("%s") + PipelineLineageEntityHandler.update_pipeline_context(pipeline_context=pipeline_context) + new_pipeline_version_context: Context = self._create_pipeline_version_lineage() + self._add_associations_for_pipeline( + # pylint: disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_versions_context_arn=new_pipeline_version_context.context_arn, + input_feature_group_contexts=input_feature_group_contexts, + input_raw_data_artifacts=input_raw_data_artifacts, + output_feature_group_contexts=output_feature_group_contexts, + transformation_code_artifact=transformation_code_artifact, + ) + LineageAssociationHandler.add_pipeline_and_pipeline_version_association( + # pylint: disable=no-member + pipeline_context_arn=pipeline_context.context_arn, + # pylint: disable=no-member + pipeline_version_context_arn=new_pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + elif transformation_code_artifact is not None: + # We will append the new transformation code artifact + # to the existing pipeline version. + LineageAssociationHandler.add_upstream_transformation_code_associations( + transformation_code_artifact=transformation_code_artifact, + # pylint: disable=no-member + pipeline_version_context_arn=current_pipeline_version_context.context_arn, + sagemaker_session=self.sagemaker_session, + ) + + def _retrieve_input_raw_data_artifacts(self) -> List[Artifact]: + """Retrieve input Raw Data Artifacts. + + Returns: + List[Artifact]: List of Raw Data Artifacts. + """ + raw_data_artifacts: List[Artifact] = list() + raw_data_uri_set: Set[str] = set() + + for data_source in self.inputs: + if isinstance(data_source, (CSVDataSource, ParquetDataSource, BaseDataSource)): + data_source_uri = ( + data_source.s3_uri + if isinstance(data_source, (CSVDataSource, ParquetDataSource)) + else data_source.data_source_unique_id + ) + if data_source_uri not in raw_data_uri_set: + raw_data_uri_set.add(data_source_uri) + raw_data_artifacts.append( + S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=data_source, + sagemaker_session=self.sagemaker_session, + ) + ) + + return raw_data_artifacts + + def _compare_upstream_raw_data( + self, + upstream_raw_data_associations: Iterator[AssociationSummary], + input_raw_data_artifacts: List[Artifact], + ) -> bool: + """Compare the existing and the new upstream Raw Data. + + Arguments: + upstream_raw_data_associations (Iterator[AssociationSummary]): + Upstream existing raw data associations for the pipeline. + input_raw_data_artifacts (List[Artifact]): + New Upstream raw data for the pipeline. + + Returns: + bool: Boolean if old and new upstream is same. 
+        """
+        raw_data_association_set = {
+            raw_data_association.source_arn
+            for raw_data_association in upstream_raw_data_associations
+        }
+        if len(raw_data_association_set) != len(input_raw_data_artifacts):
+            return False
+        for raw_data in input_raw_data_artifacts:
+            if raw_data.artifact_arn not in raw_data_association_set:
+                return False
+        return True
+
+    def _compare_downstream_feature_groups(
+        self,
+        downstream_feature_group_associations: Iterator[AssociationSummary],
+        output_feature_group_contexts: FeatureGroupContexts,
+    ) -> bool:
+        """Compare the existing and the new downstream Feature Groups.
+
+        Arguments:
+            downstream_feature_group_associations (Iterator[AssociationSummary]):
+                The existing downstream Feature Group associations for the pipeline.
+            output_feature_group_contexts (FeatureGroupContexts):
+                The new downstream Feature Group for the pipeline.
+
+        Returns:
+            bool: True if the existing and new downstream Feature Group are the same,
+                False otherwise.
+        """
+        feature_group_association_set = set()
+        for feature_group_association in downstream_feature_group_associations:
+            feature_group_association_set.add(feature_group_association.destination_arn)
+        if len(feature_group_association_set) != 1:
+            raise ValueError(
+                f"There should only be one Feature Group as output, "
+                f"instead we got {len(feature_group_association_set)}, "
+                f"with Feature Group version contexts: {feature_group_association_set}"
+            )
+        return (
+            output_feature_group_contexts.pipeline_version_context_arn
+            in feature_group_association_set
+        )
+
+    def _compare_upstream_feature_groups(
+        self,
+        upstream_feature_group_associations: Iterator[AssociationSummary],
+        input_feature_group_contexts: List[FeatureGroupContexts],
+    ) -> bool:
+        """Compare the existing and the new upstream Feature Groups.
+
+        Arguments:
+            upstream_feature_group_associations (Iterator[AssociationSummary]):
+                The existing upstream Feature Group associations for the pipeline.
+            input_feature_group_contexts (List[FeatureGroupContexts]):
+                The new upstream Feature Groups for the pipeline.
+
+        Returns:
+            bool: True if the existing and new upstream Feature Groups are the same,
+                False otherwise.
+        """
+        feature_group_association_set = set()
+        for feature_group_association in upstream_feature_group_associations:
+            feature_group_association_set.add(feature_group_association.source_arn)
+        if len(feature_group_association_set) != len(input_feature_group_contexts):
+            return False
+        for feature_group in input_feature_group_contexts:
+            if feature_group.pipeline_version_context_arn not in feature_group_association_set:
+                return False
+        return True
+
+    def _update_last_transformation_code(
+        self, upstream_transformation_code_associations: Iterator[AssociationSummary]
+    ) -> None:
+        """Deactivate the transformation code artifact of the previous pipeline version.
+
+        If an upstream transformation code association exists and its artifact is still
+        Active, mark it Inactive and stamp its exclusive_end_date with the pipeline's
+        last modified time.
+
+        Arguments:
+            upstream_transformation_code_associations (Iterator[AssociationSummary]):
+                The existing upstream transformation code associations for the pipeline.
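+
+        Example (an illustrative sketch of the property transition this method applies
+        to the previous artifact; the epoch timestamps below are hypothetical):
+
+            # Before: the artifact associated with the previous pipeline version.
+            #   {"state": "Active", "inclusive_start_date": "1700000000"}
+            # After: the artifact is closed out in favor of the new version's artifact.
+            #   {"state": "Inactive", "inclusive_start_date": "1700000000",
+            #    "exclusive_end_date": "1700003600"}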
+ """ + upstream_transformation_code = next(upstream_transformation_code_associations, None) + if upstream_transformation_code is None: + return + + last_transformation_code_artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=upstream_transformation_code.source_arn, + sagemaker_session=self.sagemaker_session, + ) + logger.info( + "Retrieved previous transformation code artifact: %s", last_transformation_code_artifact + ) + if ( + last_transformation_code_artifact.properties["state"] + == TRANSFORMATION_CODE_STATUS_ACTIVE + ): + last_transformation_code_artifact.properties["state"] = ( + TRANSFORMATION_CODE_STATUS_INACTIVE + ) + last_transformation_code_artifact.properties["exclusive_end_date"] = self.pipeline[ + LAST_MODIFIED_TIME + ].strftime("%s") + S3LineageEntityHandler.update_transformation_code_artifact( + transformation_code_artifact=last_transformation_code_artifact + ) + logger.info("Updated the last transformation artifact") + + def _get_pipeline_context(self) -> Context: + """Retrieve Pipeline Context. + + Returns: + Context: The Pipeline Context. + """ + return PipelineLineageEntityHandler.load_pipeline_context( + pipeline_name=self.pipeline_name, + creation_time=self.pipeline[CREATION_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + + def _get_pipeline_version_context(self, last_update_time: str) -> Context: + """Retrieve Pipeline Version Context. + + Returns: + Context: The Pipeline Version Context. + """ + return PipelineVersionLineageEntityHandler.load_pipeline_version_context( + pipeline_name=self.pipeline_name, + last_update_time=last_update_time, + sagemaker_session=self.sagemaker_session, + ) + + def _check_if_pipeline_lineage_exists(self) -> bool: + """Check if Pipeline Lineage exists. + + Returns: + bool: Check if pipeline lineage exists. + """ + try: + PipelineLineageEntityHandler.load_pipeline_context( + pipeline_name=self.pipeline_name, + creation_time=self.pipeline[CREATION_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + return True + except ClientError as e: + if e.response[ERROR][CODE] == RESOURCE_NOT_FOUND: + return False + raise e + + def _retrieve_input_feature_group_contexts(self) -> List[FeatureGroupContexts]: + """Retrieve input Feature Groups' Context ARNs. + + Returns: + List[FeatureGroupContexts]: List of Input Feature Groups for the pipeline. + """ + feature_group_contexts: List[FeatureGroupContexts] = list() + feature_group_input_set: Set[str] = set() + for data_source in self.inputs: + if isinstance(data_source, FeatureGroupDataSource): + feature_group_name: str = FeatureGroupLineageEntityHandler.parse_name_from_arn( + data_source.name + ) + if feature_group_name not in feature_group_input_set: + feature_group_input_set.add(feature_group_name) + feature_group_contexts.append( + FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns( + feature_group_name=data_source.name, + sagemaker_session=self.sagemaker_session, + ) + ) + return feature_group_contexts + + def _retrieve_output_feature_group_contexts(self) -> FeatureGroupContexts: + """Retrieve output Feature Group's Context ARNs. + + Returns: + FeatureGroupContexts: The output Feature Group for the pipeline. + """ + return FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns( + feature_group_name=self.output, sagemaker_session=self.sagemaker_session + ) + + def _create_pipeline_lineage_for_new_pipeline(self) -> Context: + """Create Pipeline Context for a new pipeline. 
+ + Returns: + Context: The Pipeline Context. + """ + return PipelineLineageEntityHandler.create_pipeline_context( + pipeline_name=self.pipeline_name, + pipeline_arn=self.pipeline_arn, + creation_time=self.pipeline[CREATION_TIME].strftime("%s"), + last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + + def _create_pipeline_version_lineage(self) -> Context: + """Create a new Pipeline Version Context. + + Returns: + Context: The Pipeline Versions Context. + """ + return PipelineVersionLineageEntityHandler.create_pipeline_version_context( + pipeline_name=self.pipeline_name, + pipeline_arn=self.pipeline_arn, + last_update_time=self.pipeline[LAST_MODIFIED_TIME].strftime("%s"), + sagemaker_session=self.sagemaker_session, + ) + + def _add_associations_for_pipeline( + self, + pipeline_context_arn: str, + pipeline_versions_context_arn: str, + input_feature_group_contexts: List[FeatureGroupContexts], + input_raw_data_artifacts: List[Artifact], + output_feature_group_contexts: FeatureGroupContexts, + transformation_code_artifact: Optional[Artifact] = None, + ) -> None: + """Add Feature Processor Lineage Associations for the Pipeline + + Arguments: + pipeline_context_arn (str): The pipeline Context ARN. + pipeline_versions_context_arn (str): The pipeline Version Context ARN. + input_feature_group_contexts (List[FeatureGroupContexts]): List of input FeatureGroups. + input_raw_data_artifacts (List[Artifact]): List of input raw data. + output_feature_group_contexts (FeatureGroupContexts): Output Feature Group + transformation_code_artifact (Optional[Artifact]): The transformation Code. + """ + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=input_feature_group_contexts, + pipeline_context_arn=pipeline_context_arn, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) + + LineageAssociationHandler.add_downstream_feature_group_data_associations( + feature_group_output=output_feature_group_contexts, + pipeline_context_arn=pipeline_context_arn, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) + + LineageAssociationHandler.add_upstream_raw_data_associations( + raw_data_inputs=input_raw_data_artifacts, + pipeline_context_arn=pipeline_context_arn, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) + + if transformation_code_artifact is not None: + LineageAssociationHandler.add_upstream_transformation_code_associations( + transformation_code_artifact=transformation_code_artifact, + pipeline_version_context_arn=pipeline_versions_context_arn, + sagemaker_session=self.sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage_name_helper.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage_name_helper.py new file mode 100644 index 0000000000..1a4e9ed04f --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_processor_lineage_name_helper.py @@ -0,0 +1,101 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle lineage resource name generation.""" +from __future__ import absolute_import + +FEATURE_PROCESSOR_CREATED_PREFIX = "sm-fs-fe" +FEATURE_PROCESSOR_CREATED_TRIGGER_PREFIX = "sm-fs-fe-trigger" +FEATURE_GROUP_PIPELINE_CONTEXT_SUFFIX = "feature-group-pipeline" +FEATURE_GROUP_PIPELINE_CONTEXT_VERSION_SUFFIX = "feature-group-pipeline-version" +FEATURE_PROCESSOR_PIPELINE_CONTEXT_SUFFIX = "fep" +FEATURE_PROCESSOR_PIPELINE_VERSION_CONTEXT_SUFFIX = "fep-ver" + + +def _get_feature_processor_lineage_context_name( + resource_name: str, + resource_creation_time: str, + lineage_context_prefix: str = None, + lineage_context_suffix: str = None, +) -> str: + """Generic naming generation function for lineage resources used by feature_processor.""" + context_name_base = [f"{resource_name}-{resource_creation_time}"] + if lineage_context_prefix: + context_name_base.insert(0, lineage_context_prefix) + if lineage_context_suffix: + context_name_base.append(lineage_context_suffix) + return "-".join(context_name_base) + + +def _get_feature_group_lineage_context_name( + feature_group_name: str, feature_group_creation_time: str +) -> str: + """Generate context name for feature group contexts.""" + return _get_feature_processor_lineage_context_name( + resource_name=feature_group_name, resource_creation_time=feature_group_creation_time + ) + + +def _get_feature_group_pipeline_lineage_context_name( + feature_group_name: str, feature_group_creation_time: str +) -> str: + """Generate context name for feature group pipeline.""" + return _get_feature_processor_lineage_context_name( + resource_name=feature_group_name, + resource_creation_time=feature_group_creation_time, + lineage_context_suffix=FEATURE_GROUP_PIPELINE_CONTEXT_SUFFIX, + ) + + +def _get_feature_group_pipeline_version_lineage_context_name( + feature_group_name: str, feature_group_creation_time: str +) -> str: + """Generate context name for feature group pipeline version.""" + return _get_feature_processor_lineage_context_name( + resource_name=feature_group_name, + resource_creation_time=feature_group_creation_time, + lineage_context_suffix=FEATURE_GROUP_PIPELINE_CONTEXT_VERSION_SUFFIX, + ) + + +def _get_feature_processor_pipeline_lineage_context_name( + pipeline_name: str, pipeline_creation_time: str +) -> str: + """Generate context name for feature processor pipeline.""" + return _get_feature_processor_lineage_context_name( + resource_name=pipeline_name, + resource_creation_time=pipeline_creation_time, + lineage_context_prefix=FEATURE_PROCESSOR_CREATED_PREFIX, + lineage_context_suffix=FEATURE_PROCESSOR_PIPELINE_CONTEXT_SUFFIX, + ) + + +def _get_feature_processor_pipeline_version_lineage_context_name( + pipeline_name: str, pipeline_last_update_time: str +) -> str: + """Generate context name for feature processor pipeline version.""" + return _get_feature_processor_lineage_context_name( + resource_name=pipeline_name, + resource_creation_time=pipeline_last_update_time, + lineage_context_prefix=FEATURE_PROCESSOR_CREATED_PREFIX, + lineage_context_suffix=FEATURE_PROCESSOR_PIPELINE_VERSION_CONTEXT_SUFFIX, + ) + + +def _get_feature_processor_schedule_lineage_artifact_name(schedule_name: 
str) -> str: + """Generate artifact name for feature processor pipeline schedule.""" + return "-".join([FEATURE_PROCESSOR_CREATED_PREFIX, schedule_name]) + + +def _get_feature_processor_trigger_lineage_artifact_name(trigger_name: str) -> str: + """Generate artifact name for feature processor pipeline trigger.""" + return "-".join([FEATURE_PROCESSOR_CREATED_TRIGGER_PREFIX, trigger_name]) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_lineage_association_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_lineage_association_handler.py new file mode 100644 index 0000000000..0413b5d7c1 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_lineage_association_handler.py @@ -0,0 +1,300 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle Lineage Associations""" +from __future__ import absolute_import +import logging +from typing import List, Optional, Iterator +from botocore.exceptions import ClientError + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import ( + FeatureGroupContexts, +) +from sagemaker.mlops.feature_store.feature_processor._constants import VALIDATION_EXCEPTION +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + CONTRIBUTED_TO, + ERROR, + CODE, + SAGEMAKER, + ASSOCIATED_WITH, +) +from sagemaker.core.lineage.artifact import Artifact +from sagemaker.core.lineage.association import Association, AssociationSummary + +logger = logging.getLogger(SAGEMAKER) + + +class LineageAssociationHandler: + """Class to handler the FeatureProcessor Lineage Associations""" + + @staticmethod + def add_upstream_feature_group_data_associations( + feature_group_inputs: List[FeatureGroupContexts], + pipeline_context_arn: str, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Upstream Feature Group Lineage Associations. + + Arguments: + feature_group_inputs (List[FeatureGroupContexts]): The input Feature Group List. + pipeline_context_arn (str): The pipeline context arn. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
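+
+        Example (an illustrative sketch; the session, context ARNs, and the previously
+        resolved fg_contexts value are hypothetical):
+
+            from sagemaker.core.helper.session_helper import Session
+
+            session = Session()
+            # fg_contexts: a FeatureGroupContexts resolved for the input feature group,
+            # e.g. via FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns.
+            LineageAssociationHandler.add_upstream_feature_group_data_associations(
+                feature_group_inputs=[fg_contexts],
+                pipeline_context_arn="arn:aws:sagemaker:us-west-2:123456789012:context/my-pipeline-fep",
+                pipeline_version_context_arn="arn:aws:sagemaker:us-west-2:123456789012:context/my-pipeline-fep-ver",
+                sagemaker_session=session,
+            )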
+ """ + for feature_group in feature_group_inputs: + LineageAssociationHandler._add_association( + source_arn=feature_group.pipeline_context_arn, + destination_arn=pipeline_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + LineageAssociationHandler._add_association( + source_arn=feature_group.pipeline_version_context_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_upstream_raw_data_associations( + raw_data_inputs: List[Artifact], + pipeline_context_arn: str, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Upstream Raw Data Lineage Associations. + + Arguments: + raw_data_inputs (List[Artifact]): The input raw data List. + pipeline_context_arn (str): The pipeline context arn. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + for raw_data_artifact in raw_data_inputs: + LineageAssociationHandler._add_association( + source_arn=raw_data_artifact.artifact_arn, + destination_arn=pipeline_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + LineageAssociationHandler._add_association( + source_arn=raw_data_artifact.artifact_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_upstream_transformation_code_associations( + transformation_code_artifact: Artifact, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Upstream Transformation Code Lineage Associations. + + Arguments: + transformation_code_artifact (Artifact): The transformation Code Artifact. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + LineageAssociationHandler._add_association( + source_arn=transformation_code_artifact.artifact_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_upstream_schedule_associations( + schedule_artifact: Artifact, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Upstream Schedule Lineage Associations. + + Arguments: + schedule_artifact (Artifact): The schedule Artifact. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
+ """ + LineageAssociationHandler._add_association( + source_arn=schedule_artifact.artifact_arn, + destination_arn=pipeline_version_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_downstream_feature_group_data_associations( + feature_group_output: FeatureGroupContexts, + pipeline_context_arn: str, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Downstream Feature Group Lineage Associations. + + Arguments: + feature_group_output (FeatureGroupContexts): The output Feature Group. + pipeline_context_arn (str): The pipeline context arn. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + LineageAssociationHandler._add_association( + source_arn=pipeline_context_arn, + destination_arn=feature_group_output.pipeline_context_arn, + association_type=CONTRIBUTED_TO, + sagemaker_session=sagemaker_session, + ) + LineageAssociationHandler._add_association( + source_arn=pipeline_version_context_arn, + destination_arn=feature_group_output.pipeline_version_context_arn, + association_type="ContributedTo", + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def add_pipeline_and_pipeline_version_association( + pipeline_context_arn: str, + pipeline_version_context_arn: str, + sagemaker_session: Session, + ) -> None: + """Add the FeatureProcessor Lineage Association + + between the Pipeline and the Pipeline Versions. + + Arguments: + pipeline_context_arn (str): The pipeline context arn. + pipeline_version_context_arn (str): The pipeline version context arn. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + LineageAssociationHandler._add_association( + source_arn=pipeline_context_arn, + destination_arn=pipeline_version_context_arn, + association_type=ASSOCIATED_WITH, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def list_upstream_associations( + entity_arn: str, source_type: str, sagemaker_session: Session + ) -> Iterator[AssociationSummary]: + """List Upstream Lineage Associations. + + Arguments: + entity_arn (str): The Lineage Entity ARN. + source_type (str): The Source Type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + return LineageAssociationHandler._list_association( + destination_arn=entity_arn, + source_type=source_type, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def list_downstream_associations( + entity_arn: str, destination_type: str, sagemaker_session: Session + ) -> Iterator[AssociationSummary]: + """List Downstream Lineage Associations. + + Arguments: + entity_arn (str): The Lineage Entity ARN. + destination_type (str): The Destination Type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. 
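+
+        Example (an illustrative sketch; the context ARN below is hypothetical, and the
+        destination type reuses the FeatureGroupPipelineVersion context type defined in
+        this module's constants):
+
+            from sagemaker.core.helper.session_helper import Session
+
+            session = Session()
+            associations = LineageAssociationHandler.list_downstream_associations(
+                entity_arn="arn:aws:sagemaker:us-west-2:123456789012:context/my-pipeline-fep-ver",
+                destination_type="FeatureGroupPipelineVersion",
+                sagemaker_session=session,
+            )
+            for summary in associations:
+                print(summary.destination_arn)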
+ """ + return LineageAssociationHandler._list_association( + source_arn=entity_arn, + destination_type=destination_type, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def _add_association( + source_arn: str, + destination_arn: str, + association_type: str, + sagemaker_session: Session, + ) -> None: + """Add Lineage Association. + + Arguments: + source_arn (str): The source ARN. + destination_arn (str): The destination ARN. + association_type (str): The association type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + try: + logger.info( + "Adding association with source_arn: " + "%s, destination_arn: %s and association_type: %s.", + source_arn, + destination_arn, + association_type, + ) + Association.create( + source_arn=source_arn, + destination_arn=destination_arn, + association_type=association_type, + sagemaker_session=sagemaker_session, + ) + except ClientError as e: + if e.response[ERROR][CODE] == VALIDATION_EXCEPTION: + logger.info("Association already exists") + else: + raise e + + @staticmethod + def _list_association( + sagemaker_session: Session, + source_arn: Optional[str] = None, + source_type: Optional[str] = None, + destination_arn: Optional[str] = None, + destination_type: Optional[str] = None, + ) -> Iterator[AssociationSummary]: + """List Lineage Associations. + + Arguments: + source_arn (str): The source ARN. + source_type (str): The source type. + destination_arn (str): The destination ARN. + destination_type (str): The destination type. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + """ + return Association.list( + source_arn=source_arn, + source_type=source_type, + destination_arn=destination_arn, + destination_type=destination_type, + sagemaker_session=sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_lineage_entity_handler.py new file mode 100644 index 0000000000..3bf80e9d95 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_lineage_entity_handler.py @@ -0,0 +1,105 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains class to handle Pipeline Lineage""" +from __future__ import absolute_import +import logging + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + SAGEMAKER, + PIPELINE_NAME_KEY, + PIPELINE_CREATION_TIME_KEY, + LAST_UPDATE_TIME_KEY, +) +from sagemaker.core.lineage.context import Context + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_processor_pipeline_lineage_context_name, +) +from sagemaker.core.lineage import context + +logger = logging.getLogger(SAGEMAKER) + + +class PipelineLineageEntityHandler: + """Class for handling FeatureProcessor Pipeline Lineage""" + + @staticmethod + def create_pipeline_context( + pipeline_name: str, + pipeline_arn: str, + creation_time: str, + last_update_time: str, + sagemaker_session: Session, + ) -> Context: + """Create the FeatureProcessor Pipeline context. + + Arguments: + pipeline_name (str): The pipeline name. + pipeline_arn (str): The pipeline ARN. + creation_time (str): The pipeline creation time. + last_update_time (str): The pipeline last update time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline context. + """ + return context.Context.create( + context_name=_get_feature_processor_pipeline_lineage_context_name( + pipeline_name, creation_time + ), + context_type="FeatureEngineeringPipeline", + source_uri=pipeline_arn, + source_type=creation_time, + properties={ + PIPELINE_NAME_KEY: pipeline_name, + PIPELINE_CREATION_TIME_KEY: creation_time, + LAST_UPDATE_TIME_KEY: last_update_time, + }, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def load_pipeline_context( + pipeline_name: str, creation_time: str, sagemaker_session: Session + ) -> Context: + """Load the FeatureProcessor Pipeline context. + + Arguments: + pipeline_name (str): The pipeline name. + creation_time (str): The pipeline creation time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline context. + """ + return Context.load( + context_name=_get_feature_processor_pipeline_lineage_context_name( + pipeline_name, creation_time + ), + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def update_pipeline_context(pipeline_context: Context) -> None: + """Update the FeatureProcessor Pipeline context + + Arguments: + pipeline_context (Context): The pipeline context. + """ + pipeline_context.save() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_schedule.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_schedule.py new file mode 100644 index 0000000000..08f10fb8fb --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_schedule.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to store the Pipeline Schedule"""
+from __future__ import absolute_import
+import attr
+
+
+@attr.s
+class PipelineSchedule:
+    """A Schedule definition for FeatureProcessor Lineage.
+
+    Attributes:
+        schedule_name (str): Schedule Name.
+        schedule_arn (str): The ARN of the Schedule.
+        schedule_expression (str): The expression that defines when the schedule runs. It
+            supports at, rate, and cron expressions. See https://docs.aws.amazon.com/
+            scheduler/latest/APIReference/API_CreateSchedule.html#scheduler-CreateSchedule-request
+            -ScheduleExpression for more details.
+        pipeline_name (str): The SageMaker Pipeline name that will be scheduled.
+        state (str): Specifies whether the schedule is enabled or disabled. Valid values are
+            ENABLED and DISABLED. See https://docs.aws.amazon.com/scheduler/latest/APIReference/
+            API_CreateSchedule.html#scheduler-CreateSchedule-request-State for more details.
+            If not specified, it will default to DISABLED.
+        start_date (str): The date, in UTC, after which the schedule can begin invoking its
+            target. Depending on the schedule's recurrence expression, invocations might
+            occur on, or after, the StartDate you specify.
+    """
+
+    schedule_name: str = attr.ib()
+    schedule_arn: str = attr.ib()
+    schedule_expression: str = attr.ib()
+    pipeline_name: str = attr.ib()
+    state: str = attr.ib()
+    start_date: str = attr.ib()
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_trigger.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_trigger.py
new file mode 100644
index 0000000000..e58003f396
--- /dev/null
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_trigger.py
@@ -0,0 +1,36 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains class to store the Pipeline Trigger"""
+from __future__ import absolute_import
+import attr
+
+
+@attr.s
+class PipelineTrigger:
+    """An event-based trigger definition for FeatureProcessor Lineage.
+
+    Attributes:
+        trigger_name (str): Trigger Name.
+        trigger_arn (str): The ARN of the Trigger.
+        event_pattern (str): The event pattern. For more information, see Amazon EventBridge
+            event patterns in the Amazon EventBridge User Guide.
+        pipeline_name (str): The SageMaker Pipeline name that will be triggered.
+        state (str): Specifies whether the trigger is enabled or disabled. Valid values are
+            ENABLED and DISABLED.
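+
+    Example (an illustrative sketch; the name, ARN, and event pattern below are
+    hypothetical values):
+
+        trigger = PipelineTrigger(
+            trigger_name="my-fp-trigger",
+            trigger_arn="arn:aws:events:us-west-2:123456789012:rule/my-fp-trigger",
+            event_pattern='{"source": ["aws.sagemaker"]}',
+            pipeline_name="my-feature-processor-pipeline",
+            state="ENABLED",
+        )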
+ """ + + trigger_name: str = attr.ib() + trigger_arn: str = attr.ib() + event_pattern: str = attr.ib() + pipeline_name: str = attr.ib() + state: str = attr.ib() diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_version_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_version_lineage_entity_handler.py new file mode 100644 index 0000000000..5d0b4c979b --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_pipeline_version_lineage_entity_handler.py @@ -0,0 +1,92 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle Pipeline Version Lineage""" +from __future__ import absolute_import +import logging + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + SAGEMAKER, + PIPELINE_VERSION_CONTEXT_TYPE, + PIPELINE_NAME_KEY, + LAST_UPDATE_TIME_KEY, +) +from sagemaker.core.lineage.context import Context + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_processor_pipeline_version_lineage_context_name, +) + +logger = logging.getLogger(SAGEMAKER) + + +class PipelineVersionLineageEntityHandler: + """Class for handling FeatureProcessor Pipeline Version Lineage""" + + @staticmethod + def create_pipeline_version_context( + pipeline_name: str, + pipeline_arn: str, + last_update_time: str, + sagemaker_session: Session, + ) -> Context: + """Create the FeatureProcessor Pipeline Version context. + + Arguments: + pipeline_name (str): The pipeline name. + pipeline_arn (str): The pipeline ARN. + last_update_time (str): The pipeline last update time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline version context. + """ + return Context.create( + context_name=_get_feature_processor_pipeline_version_lineage_context_name( + pipeline_name, last_update_time + ), + context_type=f"{PIPELINE_VERSION_CONTEXT_TYPE}-{pipeline_name}", + source_uri=pipeline_arn, + source_type=last_update_time, + properties={ + PIPELINE_NAME_KEY: pipeline_name, + LAST_UPDATE_TIME_KEY: last_update_time, + }, + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def load_pipeline_version_context( + pipeline_name: str, last_update_time: str, sagemaker_session: Session + ) -> Context: + """Load the FeatureProcessor Pipeline Version context. + + Arguments: + pipeline_name (str): The pipeline name. + last_update_time (str): The pipeline last update time. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. 
If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Context: The pipeline version context. + """ + return Context.load( + context_name=_get_feature_processor_pipeline_version_lineage_context_name( + pipeline_name, last_update_time + ), + sagemaker_session=sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py new file mode 100644 index 0000000000..78a0f18c7c --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_s3_lineage_entity_handler.py @@ -0,0 +1,316 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to handle S3 Lineage""" +from __future__ import absolute_import +import logging +from typing import Union, Optional, List + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor import ( + CSVDataSource, + ParquetDataSource, + BaseDataSource, +) + +# pylint: disable=C0301 +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage_name_helper import ( + _get_feature_processor_schedule_lineage_artifact_name, + _get_feature_processor_trigger_lineage_artifact_name, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import PipelineTrigger +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + TRANSFORMATION_CODE_STATUS_ACTIVE, + FEP_LINEAGE_PREFIX, + TRANSFORMATION_CODE_ARTIFACT_NAME, +) +from sagemaker.core.lineage.artifact import Artifact, ArtifactSummary + +logger = logging.getLogger("sagemaker") + + +class S3LineageEntityHandler: + """Class for handling FeatureProcessor S3 Artifact Lineage""" + + @staticmethod + def retrieve_raw_data_artifact( + raw_data: Union[CSVDataSource, ParquetDataSource, BaseDataSource], + sagemaker_session: Session, + ) -> Artifact: + """Load or create the FeatureProcessor Pipeline's raw data Artifact. + + Arguments: + raw_data (Union[CSVDataSource, ParquetDataSource, BaseDataSource]): The raw data to be + retrieved. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The raw data artifact. 
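+
+        Example (an illustrative sketch; the bucket and key are hypothetical, and
+        CSVDataSource is assumed to be constructed from its S3 URI):
+
+            from sagemaker.core.helper.session_helper import Session
+            from sagemaker.mlops.feature_store.feature_processor import CSVDataSource
+
+            session = Session()
+            artifact = S3LineageEntityHandler.retrieve_raw_data_artifact(
+                raw_data=CSVDataSource(s3_uri="s3://my-bucket/raw/orders.csv"),
+                sagemaker_session=session,
+            )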
+        """
+        raw_data_uri = (
+            raw_data.s3_uri
+            if isinstance(raw_data, (CSVDataSource, ParquetDataSource))
+            else raw_data.data_source_unique_id
+        )
+        raw_data_artifact_name = (
+            "sm-fs-fe-raw-data"
+            if isinstance(raw_data, (CSVDataSource, ParquetDataSource))
+            else raw_data.data_source_name
+        )
+
+        load_artifact: ArtifactSummary = S3LineageEntityHandler._load_artifact_from_s3_uri(
+            s3_uri=raw_data_uri, sagemaker_session=sagemaker_session
+        )
+        if load_artifact is not None:
+            return S3LineageEntityHandler.load_artifact_from_arn(
+                artifact_arn=load_artifact.artifact_arn,
+                sagemaker_session=sagemaker_session,
+            )
+
+        return S3LineageEntityHandler._create_artifact(
+            s3_uri=raw_data_uri,
+            artifact_type="DataSet",
+            artifact_name=raw_data_artifact_name,
+            sagemaker_session=sagemaker_session,
+        )
+
+    @staticmethod
+    def update_transformation_code_artifact(
+        transformation_code_artifact: Artifact,
+    ) -> None:
+        """Update Pipeline's transformation code Artifact.
+
+        Arguments:
+            transformation_code_artifact (Artifact): The transformation code Artifact
+                to be updated.
+        """
+        transformation_code_artifact.save()
+
+    @staticmethod
+    def create_transformation_code_artifact(
+        transformation_code: TransformationCode,
+        pipeline_last_update_time: str,
+        sagemaker_session: Session,
+    ) -> Optional[Artifact]:
+        """Create the FeatureProcessor Pipeline's transformation code Artifact.
+
+        Arguments:
+            transformation_code (TransformationCode): The transformation code to be
+                recorded as an artifact.
+            pipeline_last_update_time (str): The last update time of the pipeline.
+            sagemaker_session (Session): Session object which manages interactions
+                with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+                function creates one using the default AWS configuration chain.
+
+        Returns:
+            Artifact: The transformation code artifact.
+        """
+        if transformation_code is None:
+            return None
+
+        properties = dict(
+            state=TRANSFORMATION_CODE_STATUS_ACTIVE,
+            inclusive_start_date=pipeline_last_update_time,
+        )
+        if transformation_code.name is not None:
+            properties["name"] = transformation_code.name
+        if transformation_code.author is not None:
+            properties["author"] = transformation_code.author
+
+        return S3LineageEntityHandler._create_artifact(
+            s3_uri=transformation_code.s3_uri,
+            source_types=[dict(SourceIdType="Custom", Value=pipeline_last_update_time)],
+            properties=properties,
+            artifact_type="TransformationCode",
+            artifact_name=f"{FEP_LINEAGE_PREFIX}-"
+            f"{TRANSFORMATION_CODE_ARTIFACT_NAME}-"
+            f"{pipeline_last_update_time}",
+            sagemaker_session=sagemaker_session,
+        )
+
+    @staticmethod
+    def retrieve_pipeline_schedule_artifact(
+        pipeline_schedule: PipelineSchedule,
+        sagemaker_session: Session,
+    ) -> Optional[Artifact]:
+        """Load or create the FeatureProcessor Pipeline's schedule Artifact.
+
+        Arguments:
+            pipeline_schedule (PipelineSchedule): The Pipeline Schedule details.
+            sagemaker_session (Session): Session object which manages interactions
+                with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
+                function creates one using the default AWS configuration chain.
+
+        Returns:
+            Artifact: The Schedule Artifact.
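+
+        Example (an illustrative sketch; the schedule values below are hypothetical):
+
+            from sagemaker.core.helper.session_helper import Session
+
+            session = Session()
+            schedule = PipelineSchedule(
+                schedule_name="my-fp-schedule",
+                schedule_arn="arn:aws:scheduler:us-west-2:123456789012:schedule/default/my-fp-schedule",
+                schedule_expression="rate(1 day)",
+                pipeline_name="my-feature-processor-pipeline",
+                state="ENABLED",
+                start_date="1700000000",
+            )
+            artifact = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact(
+                pipeline_schedule=schedule,
+                sagemaker_session=session,
+            )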
+ """ + if pipeline_schedule is None: + return None + load_artifact: ArtifactSummary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=pipeline_schedule.schedule_arn, + sagemaker_session=sagemaker_session, + ) + if load_artifact is not None: + pipeline_schedule_artifact: Artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=load_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ) + pipeline_schedule_artifact.properties["pipeline_name"] = pipeline_schedule.pipeline_name + pipeline_schedule_artifact.properties["schedule_expression"] = ( + pipeline_schedule.schedule_expression + ) + pipeline_schedule_artifact.properties["state"] = pipeline_schedule.state + pipeline_schedule_artifact.properties["start_date"] = pipeline_schedule.start_date + pipeline_schedule_artifact.save() + return pipeline_schedule_artifact + + return S3LineageEntityHandler._create_artifact( + s3_uri=pipeline_schedule.schedule_arn, + artifact_type="PipelineSchedule", + artifact_name=_get_feature_processor_schedule_lineage_artifact_name( + schedule_name=pipeline_schedule.schedule_name + ), + properties=dict( + pipeline_name=pipeline_schedule.pipeline_name, + schedule_expression=pipeline_schedule.schedule_expression, + state=pipeline_schedule.state, + start_date=pipeline_schedule.start_date, + ), + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def retrieve_pipeline_trigger_artifact( + pipeline_trigger: PipelineTrigger, + sagemaker_session: Session, + ) -> Optional[Artifact]: + """Load or create the FeatureProcessor Pipeline's trigger Artifact + + Arguments: + pipeline_trigger (PipelineTrigger): Class to hold the Pipeline Trigger details + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The Trigger Artifact. + """ + if pipeline_trigger is None: + return None + load_artifact: ArtifactSummary = S3LineageEntityHandler._load_artifact_from_s3_uri( + s3_uri=pipeline_trigger.trigger_arn, + sagemaker_session=sagemaker_session, + ) + if load_artifact is not None: + pipeline_trigger_artifact: Artifact = S3LineageEntityHandler.load_artifact_from_arn( + artifact_arn=load_artifact.artifact_arn, + sagemaker_session=sagemaker_session, + ) + pipeline_trigger_artifact.properties["pipeline_name"] = pipeline_trigger.pipeline_name + pipeline_trigger_artifact.properties["event_pattern"] = pipeline_trigger.event_pattern + pipeline_trigger_artifact.properties["state"] = pipeline_trigger.state + pipeline_trigger_artifact.save() + return pipeline_trigger_artifact + + return S3LineageEntityHandler._create_artifact( + s3_uri=pipeline_trigger.trigger_arn, + artifact_type="PipelineTrigger", + artifact_name=_get_feature_processor_trigger_lineage_artifact_name( + trigger_name=pipeline_trigger.trigger_name + ), + properties=dict( + pipeline_name=pipeline_trigger.pipeline_name, + event_pattern=pipeline_trigger.event_pattern, + state=pipeline_trigger.state, + ), + sagemaker_session=sagemaker_session, + ) + + @staticmethod + def load_artifact_from_arn(artifact_arn: str, sagemaker_session: Session) -> Artifact: + """Load Lineage Artifacts from ARN. + + Arguments: + artifact_arn (str): The Artifact ARN. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. 
If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The Artifact for the provided ARN. + """ + return Artifact.load(artifact_arn=artifact_arn, sagemaker_session=sagemaker_session) + + @staticmethod + def _load_artifact_from_s3_uri( + s3_uri: str, sagemaker_session: Session + ) -> Optional[ArtifactSummary]: + """Load FeatureProcessor S3 Lineage Artifacts. + + Arguments: + s3_uri (str): The s3 uri of the Artifact. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + ArtifactSummary: The Artifact Summary for the provided S3 URI. + """ + artifacts = Artifact.list(source_uri=s3_uri, sagemaker_session=sagemaker_session) + for artifact_summary in artifacts: + # We want to make sure that source_type is empty. + # Since SDK will not set it while creating artifacts. + if ( + artifact_summary.source.source_types is None + or len(artifact_summary.source.source_types) == 0 + ): + return artifact_summary + return None + + @staticmethod + def _create_artifact( + s3_uri: str, + artifact_type: str, + sagemaker_session: Session, + properties: Optional[dict] = None, + artifact_name: Optional[str] = None, + source_types: Optional[List[dict]] = None, + ) -> Artifact: + """Create Lineage Artifacts. + + Arguments: + s3_uri (str): The s3 uri of the Artifact. + artifact_type (str): The Artifact type. + properties (Optional[dict]): The properties of the Artifact. + artifact_name (Optional[str]): The name of the Artifact. + sagemaker_session (Session): Session object which manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + function creates one using the default AWS configuration chain. + + Returns: + Artifact: The new Artifact. + """ + return Artifact.create( + source_uri=s3_uri, + source_types=source_types, + artifact_type=artifact_type, + artifact_name=artifact_name, + properties=properties, + sagemaker_session=sagemaker_session, + ) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_transformation_code.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_transformation_code.py new file mode 100644 index 0000000000..70ce48d910 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_transformation_code.py @@ -0,0 +1,31 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains class to store Transformation Code""" +from __future__ import absolute_import +from typing import Optional +import attr + + +@attr.s +class TransformationCode: + """A Transformation Code definition for FeatureProcessor Lineage. + + Attributes: + s3_uri (str): The S3 URI of the code. + name (Optional[str]): The name of the code Artifact object. + author (Optional[str]): The author of the code. 
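+
+    Example (an illustrative sketch; the URI, name, and author are hypothetical):
+
+        transformation_code = TransformationCode(
+            s3_uri="s3://my-bucket/code/transform.py",
+            name="transform-v1",
+            author="data-science-team",
+        )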
+ """ + + s3_uri: str = attr.ib() + name: Optional[str] = attr.ib(default=None) + author: Optional[str] = attr.ib(default=None) diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/constants.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/constants.py new file mode 100644 index 0000000000..25f4b04716 --- /dev/null +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/constants.py @@ -0,0 +1,43 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Module containing constants for feature_processor and feature_scheduler module.""" +from __future__ import absolute_import + +FEATURE_GROUP_PIPELINE_VERSION_CONTEXT_TYPE = "FeatureGroupPipelineVersion" +PIPELINE_CONTEXT_TYPE = "FeatureEngineeringPipeline" +PIPELINE_VERSION_CONTEXT_TYPE = "FeatureEngineeringPipelineVersion" +PIPELINE_CONTEXT_NAME_SUFFIX = "fep" +PIPELINE_VERSION_CONTEXT_NAME_SUFFIX = "fep-ver" +FEP_LINEAGE_PREFIX = "sm-fs-fe" +DATA_SET = "DataSet" +TRANSFORMATION_CODE = "TransformationCode" +LAST_UPDATE_TIME = "LastUpdateTime" +LAST_MODIFIED_TIME = "LastModifiedTime" +CREATION_TIME = "CreationTime" +RESOURCE_NOT_FOUND = "ResourceNotFound" +ERROR = "Error" +CODE = "Code" +SAGEMAKER = "sagemaker" +CONTRIBUTED_TO = "ContributedTo" +ASSOCIATED_WITH = "AssociatedWith" +FEATURE_GROUP = "FeatureGroupName" +FEATURE_GROUP_PIPELINE_SUFFIX = "feature-group-pipeline" +FEATURE_GROUP_PIPELINE_VERSION_SUFFIX = "feature-group-pipeline-version" +PIPELINE_CONTEXT_NAME_KEY = "pipeline_context_name" +PIPELINE_CONTEXT_VERSION_NAME_KEY = "pipeline_version_context_name" +PIPELINE_NAME_KEY = "PipelineName" +PIPELINE_CREATION_TIME_KEY = "PipelineCreationTime" +LAST_UPDATE_TIME_KEY = "LastUpdateTime" +TRANSFORMATION_CODE_STATUS_ACTIVE = "Active" +TRANSFORMATION_CODE_STATUS_INACTIVE = "Inactive" +TRANSFORMATION_CODE_ARTIFACT_NAME = "transformation-code" From 7c2663dca99eea982293ed4aade8f87d0f01e65b Mon Sep 17 00:00:00 2001 From: BassemHalim Date: Mon, 9 Feb 2026 16:01:00 -0800 Subject: [PATCH 7/8] test(feature_store): Add Feature Processor unit tests --- .../sagemaker/core/helper/session_helper.py | 220 ++ .../tests/unit/session/test_session_helper.py | 545 +++ .../feature_processor/_input_loader.py | 4 +- .../feature_processor/feature_scheduler.py | 9 +- .../_feature_group_lineage_entity_handler.py | 4 +- .../lineage/test_constants.py | 401 +++ ...st_feature_group_lineage_entity_handler.py | 62 + .../lineage/test_feature_processor_lineage.py | 2966 +++++++++++++++++ .../test_lineage_association_handler.py | 224 ++ .../test_pipeline_lineage_entity_handler.py | 74 + .../lineage/test_pipeline_trigger.py | 33 + ...pipeline_version_lineage_entity_handler.py | 67 + .../lineage/test_s3_lineage_entity_handler.py | 434 +++ .../feature_processor/test_config_uploader.py | 317 ++ .../feature_processor/test_data_helpers.py | 166 + .../feature_processor/test_data_source.py | 34 + .../feature_processor/test_env.py | 122 + .../test_event_bridge_rule_helper.py | 301 
++ .../test_event_bridge_scheduler_helper.py | 96 + .../feature_processor/test_factory.py | 75 + .../test_feature_processor.py | 122 + .../test_feature_processor_config.py | 46 + .../test_feature_processor_pipeline_events.py | 30 + .../test_feature_scheduler.py | 1057 ++++++ .../feature_processor/test_input_loader.py | 320 ++ .../test_input_offset_parser.py | 143 + .../feature_processor/test_params_loader.py | 86 + .../test_spark_session_factory.py | 175 + .../test_udf_arg_provider.py | 280 ++ .../test_udf_output_receiver.py | 106 + .../feature_processor/test_udf_wrapper.py | 85 + .../feature_processor/test_validation.py | 192 ++ 32 files changed, 8784 insertions(+), 12 deletions(-) create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_processor_lineage.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_lineage_association_handler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_lineage_entity_handler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_trigger.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_version_lineage_entity_handler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_config_uploader.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_helpers.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_source.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_env.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_rule_helper.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_scheduler_helper.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_factory.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_config.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_pipeline_events.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_offset_parser.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_params_loader.py create mode 100644 
sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_arg_provider.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_output_receiver.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_wrapper.py create mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_validation.py diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index 41957e30a2..b4c327ec09 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -85,6 +85,10 @@ TAGS, SESSION_DEFAULT_S3_BUCKET_PATH, SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH, + FEATURE_GROUP, + FEATURE_GROUP_ROLE_ARN_PATH, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, ) # Setting LOGGER for backward compatibility, in case users import it... @@ -1607,6 +1611,222 @@ def delete_endpoint_config(self, endpoint_config_name): logger.info("Deleting endpoint configuration with name: %s", endpoint_config_name) self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) + def delete_feature_group(self, feature_group_name): + """Delete an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Amazon SageMaker Feature Group to delete. + """ + logger.info("Deleting feature group with name: %s", feature_group_name) + self.sagemaker_client.delete_feature_group(FeatureGroupName=feature_group_name) + + def create_feature_group( + self, + feature_group_name, + record_identifier_name, + event_time_feature_name, + feature_definitions, + role_arn=None, + online_store_config=None, + offline_store_config=None, + throughput_config=None, + description=None, + tags=None, + ): + """Create an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Feature Group. + record_identifier_name (str): Name of the record identifier feature. + event_time_feature_name (str): Name of the event time feature. + feature_definitions (list): List of feature definitions. + role_arn (str): ARN of the role used to execute the API (default: None). + Resolved from SageMaker Config if not provided. + online_store_config (dict): Online store configuration (default: None). + offline_store_config (dict): Offline store configuration (default: None). + throughput_config (dict): Throughput configuration (default: None). + description (str): Description of the Feature Group (default: None). + tags (Optional[Tags]): Tags for labeling the Feature Group (default: None). + + Returns: + dict: Response from the CreateFeatureGroup API. 
+ """ + tags = format_tags(tags) + tags = _append_project_tags(tags) + tags = self._append_sagemaker_config_tags( + tags, "{}.{}.{}".format(SAGEMAKER, FEATURE_GROUP, TAGS) + ) + + role_arn = resolve_value_from_config( + role_arn, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=self + ) + + inferred_online_store_config = update_nested_dictionary_with_values_from_config( + online_store_config, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + sagemaker_session=self, + ) + if inferred_online_store_config is not None: + # OnlineStore should be handled differently because if you set KmsKeyId, then you + # need to set EnableOnlineStore key as well + inferred_online_store_config["EnableOnlineStore"] = True + + inferred_offline_store_config = update_nested_dictionary_with_values_from_config( + offline_store_config, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, + sagemaker_session=self, + ) + + kwargs = dict( + FeatureGroupName=feature_group_name, + RecordIdentifierFeatureName=record_identifier_name, + EventTimeFeatureName=event_time_feature_name, + FeatureDefinitions=feature_definitions, + RoleArn=role_arn, + ) + update_args( + kwargs, + OnlineStoreConfig=inferred_online_store_config, + OfflineStoreConfig=inferred_offline_store_config, + ThroughputConfig=throughput_config, + Description=description, + Tags=tags, + ) + + logger.info("Creating feature group with name: %s", feature_group_name) + return self.sagemaker_client.create_feature_group(**kwargs) + + def describe_feature_group(self, feature_group_name, next_token=None): + """Describe an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Amazon SageMaker Feature Group to describe. + next_token (str): A token for paginated results (default: None). + + Returns: + dict: Response from the DescribeFeatureGroup API. + """ + args = {"FeatureGroupName": feature_group_name} + update_args(args, NextToken=next_token) + return self.sagemaker_client.describe_feature_group(**args) + + def update_feature_group( + self, + feature_group_name, + feature_additions=None, + online_store_config=None, + throughput_config=None, + ): + """Update an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Amazon SageMaker Feature Group to update. + feature_additions (list): List of feature definitions to add (default: None). + online_store_config (dict): Online store configuration updates (default: None). + throughput_config (dict): Throughput configuration updates (default: None). + + Returns: + dict: Response from the UpdateFeatureGroup API. + """ + args = {"FeatureGroupName": feature_group_name} + update_args( + args, + FeatureAdditions=feature_additions, + OnlineStoreConfig=online_store_config, + ThroughputConfig=throughput_config, + ) + return self.sagemaker_client.update_feature_group(**args) + + def list_feature_groups( + self, + name_contains=None, + feature_group_status_equals=None, + offline_store_status_equals=None, + creation_time_after=None, + creation_time_before=None, + sort_order=None, + sort_by=None, + max_results=None, + next_token=None, + ): + """List Amazon SageMaker Feature Groups. + + Args: + name_contains (str): Filter by name substring (default: None). + feature_group_status_equals (str): Filter by status (default: None). + offline_store_status_equals (str): Filter by offline store status (default: None). + creation_time_after (datetime): Filter by creation time lower bound (default: None). + creation_time_before (datetime): Filter by creation time upper bound (default: None). 
+ sort_order (str): Sort order, 'Ascending' or 'Descending' (default: None). + sort_by (str): Sort by field (default: None). + max_results (int): Maximum number of results (default: None). + next_token (str): Pagination token (default: None). + + Returns: + dict: Response from the ListFeatureGroups API. + """ + args = {} + update_args( + args, + NameContains=name_contains, + FeatureGroupStatusEquals=feature_group_status_equals, + OfflineStoreStatusEquals=offline_store_status_equals, + CreationTimeAfter=creation_time_after, + CreationTimeBefore=creation_time_before, + SortOrder=sort_order, + SortBy=sort_by, + MaxResults=max_results, + NextToken=next_token, + ) + return self.sagemaker_client.list_feature_groups(**args) + + def update_feature_metadata( + self, + feature_group_name, + feature_name, + description=None, + parameter_additions=None, + parameter_removals=None, + ): + """Update metadata for a feature in an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Feature Group. + feature_name (str): Name of the feature to update metadata for. + description (str): Updated description for the feature (default: None). + parameter_additions (list): Parameters to add (default: None). + parameter_removals (list): Parameters to remove (default: None). + + Returns: + dict: Response from the UpdateFeatureMetadata API. + """ + args = { + "FeatureGroupName": feature_group_name, + "FeatureName": feature_name, + } + update_args( + args, + Description=description, + ParameterAdditions=parameter_additions, + ParameterRemovals=parameter_removals, + ) + return self.sagemaker_client.update_feature_metadata(**args) + + def describe_feature_metadata(self, feature_group_name, feature_name): + """Describe metadata for a feature in an Amazon SageMaker Feature Group. + + Args: + feature_group_name (str): Name of the Feature Group. + feature_name (str): Name of the feature to describe metadata for. + + Returns: + dict: Response from the DescribeFeatureMetadata API. + """ + return self.sagemaker_client.describe_feature_metadata( + FeatureGroupName=feature_group_name, + FeatureName=feature_name, + ) + def wait_for_optimization_job(self, job, poll=5): """Wait for an Amazon SageMaker Optimization job to complete. 
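The Session helpers added above are thin wrappers over the corresponding sagemaker_client feature-group calls: optional arguments left as None are dropped via update_args, tags and the role ARN are resolved through the SageMaker Config paths, and an online store config inferred from config is force-enabled. A minimal usage sketch follows (illustrative only and not part of the patch; the boto3 session, region, feature group name, role ARN, and feature definitions are placeholder assumptions):

import boto3
from sagemaker.core.helper.session_helper import Session

# Placeholder session; in practice region and credentials come from the environment.
session = Session(boto_session=boto3.Session(region_name="us-west-2"))

# Create a feature group; None-valued options (tags, throughput, description)
# are omitted from the underlying CreateFeatureGroup request.
session.create_feature_group(
    feature_group_name="demo-fg",
    record_identifier_name="record_id",
    event_time_feature_name="event_time",
    feature_definitions=[
        {"FeatureName": "record_id", "FeatureType": "String"},
        {"FeatureName": "event_time", "FeatureType": "Fractional"},
    ],
    role_arn="arn:aws:iam::123456789012:role/MyFeatureStoreRole",
)

# Inspect and update the group; only the provided keyword arguments are forwarded.
described = session.describe_feature_group("demo-fg")
session.update_feature_metadata("demo-fg", "record_id", description="Primary key")
groups = session.list_feature_groups(name_contains="demo", max_results=10)

# Clean up.
session.delete_feature_group("demo-fg")
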
diff --git a/sagemaker-core/tests/unit/session/test_session_helper.py b/sagemaker-core/tests/unit/session/test_session_helper.py index ca4fd81aa8..7e2004c1d0 100644 --- a/sagemaker-core/tests/unit/session/test_session_helper.py +++ b/sagemaker-core/tests/unit/session/test_session_helper.py @@ -29,6 +29,11 @@ update_args, NOTEBOOK_METADATA_FILE, ) +from sagemaker.core.config.config_schema import ( + FEATURE_GROUP_ROLE_ARN_PATH, + FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, + FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, +) class TestSession: @@ -451,3 +456,543 @@ def test_update_args_with_none_values(self): assert args["existing"] == "value" assert "new_key" not in args assert args["another_key"] == "another_value" + +class TestFeatureGroupSessionMethods: + """Test cases for Feature Group session methods""" + + @pytest.fixture + def session_with_mock_client(self): + """Create a Session with a mocked sagemaker_client.""" + mock_boto_session = Mock() + mock_boto_session.region_name = "us-west-2" + mock_boto_session.client.return_value = Mock() + mock_boto_session.resource.return_value = Mock() + session = Session(boto_session=mock_boto_session) + session.sagemaker_client = Mock() + return session + + # --- delete_feature_group --- + + def test_delete_feature_group(self, session_with_mock_client): + """Test delete_feature_group delegates to sagemaker_client.""" + session = session_with_mock_client + session.delete_feature_group("my-feature-group") + + session.sagemaker_client.delete_feature_group.assert_called_once_with( + FeatureGroupName="my-feature-group" + ) + + # --- describe_feature_group --- + + def test_describe_feature_group(self, session_with_mock_client): + """Test describe_feature_group delegates and returns response.""" + session = session_with_mock_client + expected = {"FeatureGroupName": "my-fg", "CreationTime": "2024-01-01"} + session.sagemaker_client.describe_feature_group.return_value = expected + + result = session.describe_feature_group("my-fg") + + session.sagemaker_client.describe_feature_group.assert_called_once_with( + FeatureGroupName="my-fg" + ) + assert result == expected + + def test_describe_feature_group_with_next_token(self, session_with_mock_client): + """Test describe_feature_group includes NextToken when provided.""" + session = session_with_mock_client + session.sagemaker_client.describe_feature_group.return_value = {} + + session.describe_feature_group("my-fg", next_token="abc123") + + session.sagemaker_client.describe_feature_group.assert_called_once_with( + FeatureGroupName="my-fg", NextToken="abc123" + ) + + def test_describe_feature_group_omits_none_next_token(self, session_with_mock_client): + """Test describe_feature_group omits NextToken when None.""" + session = session_with_mock_client + session.sagemaker_client.describe_feature_group.return_value = {} + + session.describe_feature_group("my-fg", next_token=None) + + call_kwargs = session.sagemaker_client.describe_feature_group.call_args[1] + assert "NextToken" not in call_kwargs + + # --- update_feature_group --- + + def test_update_feature_group_all_params(self, session_with_mock_client): + """Test update_feature_group with all optional params provided.""" + session = session_with_mock_client + expected = {"FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123:feature-group/my-fg"} + session.sagemaker_client.update_feature_group.return_value = expected + + additions = [{"FeatureName": "new_feat", "FeatureType": "String"}] + online_cfg = {"EnableOnlineStore": True} + throughput_cfg = {"ThroughputMode": 
"OnDemand"} + + result = session.update_feature_group( + "my-fg", + feature_additions=additions, + online_store_config=online_cfg, + throughput_config=throughput_cfg, + ) + + session.sagemaker_client.update_feature_group.assert_called_once_with( + FeatureGroupName="my-fg", + FeatureAdditions=additions, + OnlineStoreConfig=online_cfg, + ThroughputConfig=throughput_cfg, + ) + assert result == expected + + def test_update_feature_group_omits_none_params(self, session_with_mock_client): + """Test update_feature_group omits None optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_group.return_value = {} + + session.update_feature_group("my-fg") + + call_kwargs = session.sagemaker_client.update_feature_group.call_args[1] + assert call_kwargs == {"FeatureGroupName": "my-fg"} + + def test_update_feature_group_partial_params(self, session_with_mock_client): + """Test update_feature_group with only some optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_group.return_value = {} + + throughput_cfg = {"ThroughputMode": "Provisioned"} + session.update_feature_group("my-fg", throughput_config=throughput_cfg) + + call_kwargs = session.sagemaker_client.update_feature_group.call_args[1] + assert call_kwargs == { + "FeatureGroupName": "my-fg", + "ThroughputConfig": throughput_cfg, + } + + # --- list_feature_groups --- + + def test_list_feature_groups_no_params(self, session_with_mock_client): + """Test list_feature_groups with no filters delegates with empty args.""" + session = session_with_mock_client + expected = {"FeatureGroupSummaries": []} + session.sagemaker_client.list_feature_groups.return_value = expected + + result = session.list_feature_groups() + + session.sagemaker_client.list_feature_groups.assert_called_once_with() + assert result == expected + + def test_list_feature_groups_all_params(self, session_with_mock_client): + """Test list_feature_groups with all params provided.""" + session = session_with_mock_client + session.sagemaker_client.list_feature_groups.return_value = {} + + session.list_feature_groups( + name_contains="test", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after="2024-01-01", + creation_time_before="2024-12-31", + sort_order="Ascending", + sort_by="Name", + max_results=10, + next_token="token123", + ) + + session.sagemaker_client.list_feature_groups.assert_called_once_with( + NameContains="test", + FeatureGroupStatusEquals="Created", + OfflineStoreStatusEquals="Active", + CreationTimeAfter="2024-01-01", + CreationTimeBefore="2024-12-31", + SortOrder="Ascending", + SortBy="Name", + MaxResults=10, + NextToken="token123", + ) + + def test_list_feature_groups_omits_none_params(self, session_with_mock_client): + """Test list_feature_groups omits None params.""" + session = session_with_mock_client + session.sagemaker_client.list_feature_groups.return_value = {} + + session.list_feature_groups(name_contains="test", max_results=5) + + call_kwargs = session.sagemaker_client.list_feature_groups.call_args[1] + assert call_kwargs == {"NameContains": "test", "MaxResults": 5} + + # --- update_feature_metadata --- + + def test_update_feature_metadata_all_params(self, session_with_mock_client): + """Test update_feature_metadata with all optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_metadata.return_value = {} + + additions = [{"Key": "team", "Value": "ml"}] + removals = [{"Key": 
"deprecated"}] + + result = session.update_feature_metadata( + "my-fg", + "my-feature", + description="Updated desc", + parameter_additions=additions, + parameter_removals=removals, + ) + + session.sagemaker_client.update_feature_metadata.assert_called_once_with( + FeatureGroupName="my-fg", + FeatureName="my-feature", + Description="Updated desc", + ParameterAdditions=additions, + ParameterRemovals=removals, + ) + assert result == {} + + def test_update_feature_metadata_omits_none_params(self, session_with_mock_client): + """Test update_feature_metadata omits None optional params.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_metadata.return_value = {} + + session.update_feature_metadata("my-fg", "my-feature") + + call_kwargs = session.sagemaker_client.update_feature_metadata.call_args[1] + assert call_kwargs == { + "FeatureGroupName": "my-fg", + "FeatureName": "my-feature", + } + + def test_update_feature_metadata_partial_params(self, session_with_mock_client): + """Test update_feature_metadata with only description.""" + session = session_with_mock_client + session.sagemaker_client.update_feature_metadata.return_value = {} + + session.update_feature_metadata("my-fg", "my-feature", description="New desc") + + call_kwargs = session.sagemaker_client.update_feature_metadata.call_args[1] + assert call_kwargs == { + "FeatureGroupName": "my-fg", + "FeatureName": "my-feature", + "Description": "New desc", + } + + # --- describe_feature_metadata --- + + def test_describe_feature_metadata(self, session_with_mock_client): + """Test describe_feature_metadata delegates and returns response.""" + session = session_with_mock_client + expected = {"FeatureGroupName": "my-fg", "FeatureName": "my-feature"} + session.sagemaker_client.describe_feature_metadata.return_value = expected + + result = session.describe_feature_metadata("my-fg", "my-feature") + + session.sagemaker_client.describe_feature_metadata.assert_called_once_with( + FeatureGroupName="my-fg", FeatureName="my-feature" + ) + assert result == expected + +MODULE = "sagemaker.core.helper.session_helper" + + +class TestCreateFeatureGroup: + """Test cases for create_feature_group session method.""" + + @pytest.fixture + def session(self): + """Create a Session with a mocked sagemaker_client.""" + mock_boto_session = Mock() + mock_boto_session.region_name = "us-west-2" + mock_boto_session.client.return_value = Mock() + mock_boto_session.resource.return_value = Mock() + session = Session(boto_session=mock_boto_session) + session.sagemaker_client = Mock() + return session + + @pytest.fixture + def base_args(self): + """Minimal required arguments for create_feature_group.""" + return dict( + feature_group_name="my-fg", + record_identifier_name="record_id", + event_time_feature_name="event_time", + feature_definitions=[{"FeatureName": "f1", "FeatureType": "String"}], + ) + + # --- Full parameter pass-through --- + + def test_create_feature_group_all_params(self, session, base_args): + """Test that all parameters are passed through to sagemaker_client.""" + role = "arn:aws:iam::123456789012:role/Role" + online_cfg = {"SecurityConfig": {"KmsKeyId": "key-123"}} + offline_cfg = {"S3StorageConfig": {"S3Uri": "s3://bucket"}} + throughput_cfg = {"ThroughputMode": "ON_DEMAND"} + description = "My feature group" + tags = [{"Key": "team", "Value": "ml"}] + + expected_response = {"FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/my-fg"} + session.sagemaker_client.create_feature_group.return_value = 
expected_response + + with patch(f"{MODULE}.format_tags", return_value=tags) as mock_format, \ + patch(f"{MODULE}._append_project_tags", return_value=tags) as mock_proj, \ + patch.object(session, "_append_sagemaker_config_tags", return_value=tags), \ + patch(f"{MODULE}.resolve_value_from_config", return_value=role), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", side_effect=[online_cfg, offline_cfg]): + + result = session.create_feature_group( + **base_args, + role_arn=role, + online_store_config=online_cfg, + offline_store_config=offline_cfg, + throughput_config=throughput_cfg, + description=description, + tags=tags, + ) + + assert result == expected_response + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["FeatureGroupName"] == "my-fg" + assert call_kwargs["RecordIdentifierFeatureName"] == "record_id" + assert call_kwargs["EventTimeFeatureName"] == "event_time" + assert call_kwargs["FeatureDefinitions"] == base_args["feature_definitions"] + assert call_kwargs["RoleArn"] == role + # EnableOnlineStore is set to True when online config is inferred + assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True + assert call_kwargs["OfflineStoreConfig"] == offline_cfg + assert call_kwargs["ThroughputConfig"] == throughput_cfg + assert call_kwargs["Description"] == description + assert call_kwargs["Tags"] == tags + + # --- Tag processing pipeline --- + + def test_tag_processing_pipeline_order(self, session, base_args): + """Test that tags go through format_tags -> _append_project_tags -> _append_sagemaker_config_tags.""" + raw_tags = {"team": "ml"} + formatted = [{"Key": "team", "Value": "ml"}] + with_project = [{"Key": "team", "Value": "ml"}, {"Key": "project", "Value": "p1"}] + with_config = [{"Key": "team", "Value": "ml"}, {"Key": "project", "Value": "p1"}, {"Key": "cfg", "Value": "v"}] + + with patch(f"{MODULE}.format_tags", return_value=formatted) as mock_format, \ + patch(f"{MODULE}._append_project_tags", return_value=with_project) as mock_proj, \ + patch.object(session, "_append_sagemaker_config_tags", return_value=with_config) as mock_cfg, \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, tags=raw_tags) + + # format_tags is called with the raw input + mock_format.assert_called_once_with(raw_tags) + # _append_project_tags receives the formatted tags + mock_proj.assert_called_once_with(formatted) + # _append_sagemaker_config_tags receives the project-appended tags + mock_cfg.assert_called_once_with(with_project, "SageMaker.FeatureGroup.Tags") + + # Final tags in the API call should be the config-appended tags + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["Tags"] == with_config + + def test_tags_none_still_processed(self, session, base_args): + """Test that None tags still go through the pipeline (format_tags handles None).""" + with patch(f"{MODULE}.format_tags", return_value=None) as mock_format, \ + patch(f"{MODULE}._append_project_tags", return_value=None) as mock_proj, \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, tags=None) + + 
mock_format.assert_called_once_with(None) + mock_proj.assert_called_once_with(None) + # Tags=None should be omitted from the API call via update_args + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "Tags" not in call_kwargs + + # --- role_arn resolution from config --- + + def test_role_arn_resolved_from_config_when_none(self, session, base_args): + """Test that role_arn is resolved from SageMaker Config when not provided.""" + config_role = "arn:aws:iam::123456789012:role/ConfigRole" + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value=config_role) as mock_resolve, \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, role_arn=None) + + mock_resolve.assert_called_once_with( + None, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["RoleArn"] == config_role + + def test_role_arn_passed_through_when_provided(self, session, base_args): + """Test that an explicit role_arn is passed to resolve_value_from_config (which returns it).""" + explicit_role = "arn:aws:iam::123456789012:role/ExplicitRole" + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value=explicit_role) as mock_resolve, \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args, role_arn=explicit_role) + + mock_resolve.assert_called_once_with( + explicit_role, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["RoleArn"] == explicit_role + + # --- online_store_config merging and EnableOnlineStore --- + + def test_online_store_config_merged_and_enable_set(self, session, base_args): + """Test that online_store_config is merged from config and EnableOnlineStore=True is set.""" + inferred_online = {"SecurityConfig": {"KmsKeyId": "config-key"}} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", + side_effect=[inferred_online, None]) as mock_update: + + session.create_feature_group(**base_args, online_store_config=None) + + # First call is for online store config + mock_update.assert_any_call( + None, FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True + assert call_kwargs["OnlineStoreConfig"]["SecurityConfig"]["KmsKeyId"] == "config-key" + + def test_online_store_config_none_when_no_config(self, session, base_args): + """Test that OnlineStoreConfig is omitted when config returns None.""" + with patch(f"{MODULE}.format_tags", return_value=None), \ + 
patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "OnlineStoreConfig" not in call_kwargs + + def test_online_store_config_explicit_gets_enable_set(self, session, base_args): + """Test that explicitly provided online_store_config also gets EnableOnlineStore=True.""" + explicit_online = {"SecurityConfig": {"KmsKeyId": "my-key"}} + # update_nested_dictionary returns the merged result + merged_online = {"SecurityConfig": {"KmsKeyId": "my-key"}} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", + side_effect=[merged_online, None]): + + session.create_feature_group(**base_args, online_store_config=explicit_online) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["OnlineStoreConfig"]["EnableOnlineStore"] is True + + # --- offline_store_config merging --- + + def test_offline_store_config_merged_from_config(self, session, base_args): + """Test that offline_store_config is merged from SageMaker Config.""" + inferred_offline = {"S3StorageConfig": {"S3Uri": "s3://config-bucket"}} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", + side_effect=[None, inferred_offline]) as mock_update: + + session.create_feature_group(**base_args, offline_store_config=None) + + # Second call is for offline store config + mock_update.assert_any_call( + None, FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH, sagemaker_session=session + ) + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["OfflineStoreConfig"] == inferred_offline + + def test_offline_store_config_none_when_no_config(self, session, base_args): + """Test that OfflineStoreConfig is omitted when config returns None.""" + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "OfflineStoreConfig" not in call_kwargs + + # --- None optional parameters omitted --- + + def test_none_optional_params_omitted(self, session, base_args): + """Test that None optional params (throughput, description, tags) are omitted from API call.""" + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, 
"_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group(**base_args) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert "ThroughputConfig" not in call_kwargs + assert "Description" not in call_kwargs + assert "Tags" not in call_kwargs + assert "OnlineStoreConfig" not in call_kwargs + assert "OfflineStoreConfig" not in call_kwargs + # Required params should still be present + assert "FeatureGroupName" in call_kwargs + assert "RecordIdentifierFeatureName" in call_kwargs + assert "EventTimeFeatureName" in call_kwargs + assert "FeatureDefinitions" in call_kwargs + assert "RoleArn" in call_kwargs + + def test_partial_optional_params(self, session, base_args): + """Test that only provided optional params appear in the API call.""" + throughput = {"ThroughputMode": "ON_DEMAND"} + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + session.create_feature_group( + **base_args, + throughput_config=throughput, + description="test desc", + ) + + call_kwargs = session.sagemaker_client.create_feature_group.call_args[1] + assert call_kwargs["ThroughputConfig"] == throughput + assert call_kwargs["Description"] == "test desc" + assert "Tags" not in call_kwargs + assert "OnlineStoreConfig" not in call_kwargs + assert "OfflineStoreConfig" not in call_kwargs + + # --- Return value --- + + def test_returns_api_response(self, session, base_args): + """Test that the method returns the sagemaker_client response.""" + expected = {"FeatureGroupArn": "arn:fg"} + session.sagemaker_client.create_feature_group.return_value = expected + + with patch(f"{MODULE}.format_tags", return_value=None), \ + patch(f"{MODULE}._append_project_tags", return_value=None), \ + patch.object(session, "_append_sagemaker_config_tags", return_value=None), \ + patch(f"{MODULE}.resolve_value_from_config", return_value="arn:role"), \ + patch(f"{MODULE}.update_nested_dictionary_with_values_from_config", return_value=None): + + result = session.create_feature_group(**base_args) + + assert result == expected diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py index 3e82262858..627de943c1 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/_input_loader.py @@ -96,8 +96,8 @@ def load_from_feature_group( sagemaker_session: Session = self.sagemaker_session or Session() feature_group_name = feature_group_data_source.name - feature_group = sagemaker_session.sagemaker_client.describe_feature_group( - FeatureGroupName=self._parse_name_from_arn(feature_group_name) + feature_group = sagemaker_session.describe_feature_group( + self._parse_name_from_arn(feature_group_name) ) logger.debug( "Called describe_feature_group with %s and received: %s", diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py 
b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py index b7ae647ef7..c9039d982c 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/feature_scheduler.py @@ -40,7 +40,6 @@ ) from sagemaker.core.lineage import context from sagemaker.core.lineage._utils import get_resource_name_from_arn -from sagemaker.core.resources import FeatureGroup from sagemaker.core.remote_function.runtime_environment.runtime_environment_manager import ( RuntimeEnvironmentManager, ) @@ -771,12 +770,8 @@ def _validate_fg_lineage_resources(feature_group_name: str, sagemaker_session: S groups. """ - # TODO: Add describe_feature_group to V3 sagemaker_session so we can use - # sagemaker_session.describe_feature_group() directly instead of FeatureGroup.get(). - feature_group = FeatureGroup.get( - feature_group_name=feature_group_name, session=sagemaker_session.boto_session - ) - feature_group_creation_time = feature_group.creation_time.strftime("%s") + feature_group = sagemaker_session.describe_feature_group(feature_group_name=feature_group_name) + feature_group_creation_time = feature_group["CreationTime"].strftime("%s") feature_group_context = _get_feature_group_lineage_context_name( feature_group_name=feature_group_name, feature_group_creation_time=feature_group_creation_time, diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py index fd160eb470..55230d7c1c 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_processor/lineage/_feature_group_lineage_entity_handler.py @@ -99,9 +99,7 @@ def _describe_feature_group( Returns: Dict[str, Any]: The Feature Group details. """ - feature_group = sagemaker_session.sagemaker_client.describe_feature_group( - FeatureGroupName=feature_group_name - ) + feature_group = sagemaker_session.describe_feature_group(feature_group_name) logger.debug( "Called describe_feature_group with %s and received: %s", feature_group_name, diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py new file mode 100644 index 0000000000..12a323f871 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_constants.py @@ -0,0 +1,401 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains constants of feature processor to be used for unit tests.""" +from __future__ import absolute_import + +import datetime +from typing import List, Sequence, Union + +from botocore.exceptions import ClientError +from mock import Mock +from pyspark.sql import DataFrame + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + BaseDataSource, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_contexts import ( + FeatureGroupContexts, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import PipelineTrigger +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.core.lineage._api_types import ContextSource +from sagemaker.core.lineage.artifact import Artifact, ArtifactSource, ArtifactSummary +from sagemaker.core.lineage.context import Context + +PIPELINE_NAME = "test-pipeline-01" +PIPELINE_ARN = "arn:aws:sagemaker:us-west-2:12345789012:pipeline/test-pipeline-01" +CREATION_TIME = "123123123" +LAST_UPDATE_TIME = "234234234" +SAGEMAKER_SESSION_MOCK = Mock(Session) +CONTEXT_MOCK_01 = Mock(Context) +CONTEXT_MOCK_02 = Mock(Context) + + +class MockDataSource(BaseDataSource): + + data_source_unique_id = "test_source_unique_id" + data_source_name = "test_source_name" + + def read_data(self, spark, params) -> DataFrame: + return None + + +FEATURE_GROUP_DATA_SOURCE: List[FeatureGroupDataSource] = [ + FeatureGroupDataSource( + name="feature-group-01", + ), + FeatureGroupDataSource( + name="feature-group-02", + ), +] + +FEATURE_GROUP_INPUT: List[FeatureGroupContexts] = [ + FeatureGroupContexts( + name="feature-group-01", + pipeline_context_arn="feature-group-01-pipeline-context-arn", + pipeline_version_context_arn="feature-group-01-pipeline-version-context-arn", + ), + FeatureGroupContexts( + name="feature-group-02", + pipeline_context_arn="feature-group-02-pipeline-context-arn", + pipeline_version_context_arn="feature-group-02-pipeline-version-context-arn", + ), +] + +RAW_DATA_INPUT: Sequence[Union[CSVDataSource, ParquetDataSource, BaseDataSource]] = [ + CSVDataSource(s3_uri="raw-data-uri-01"), + CSVDataSource(s3_uri="raw-data-uri-02"), + ParquetDataSource(s3_uri="raw-data-uri-03"), + MockDataSource(), +] + +RAW_DATA_INPUT_ARTIFACTS: List[Artifact] = [ + Artifact(artifact_arn="artifact-01-arn"), + Artifact(artifact_arn="artifact-02-arn"), + Artifact(artifact_arn="artifact-03-arn"), + Artifact(artifact_arn="artifact-04-arn"), +] + +PIPELINE_SCHEDULE = PipelineSchedule( + schedule_name="schedule-name", + schedule_arn="schedule-arn", + schedule_expression="schedule-expression", + pipeline_name="pipeline-name", + state="state", + start_date="123123123", +) + +PIPELINE_SCHEDULE_2 = PipelineSchedule( + schedule_name="schedule-name-2", + schedule_arn="schedule-arn", + schedule_expression="schedule-expression-2", + pipeline_name="pipeline-name", + state="state-2", + start_date="234234234", +) + +PIPELINE_TRIGGER = PipelineTrigger( + trigger_name="trigger-name", + trigger_arn="trigger-arn", + pipeline_name="pipeline-name", + event_pattern="event-pattern", + state="Enabled", +) + +PIPELINE_TRIGGER_2 = PipelineTrigger( + trigger_name="trigger-name-2", + trigger_arn="trigger-arn", + pipeline_name="pipeline-name", + 
event_pattern="event-pattern-2", + state="Enabled", +) + +PIPELINE_TRIGGER_ARTIFACT: Artifact = Artifact( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-trigger-trigger-name", + artifact_type="PipelineTrigger", + source={"source_uri": "trigger-arn"}, + properties=dict( + pipeline_name=PIPELINE_TRIGGER.pipeline_name, + event_pattern=PIPELINE_TRIGGER.event_pattern, + state=PIPELINE_TRIGGER.state, + ), +) + +PIPELINE_TRIGGER_ARTIFACT_SUMMARY: ArtifactSummary = ArtifactSummary( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-trigger-trigger-name", + source=ArtifactSource( + source_uri="trigger-arn", + ), + artifact_type="PipelineTrigger", + creation_time=datetime.datetime(2023, 4, 27, 21, 4, 17, 926000), +) + +ARTIFACT_RESULT: Artifact = Artifact( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-raw-data", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz" + }, + artifact_type="DataSet", + creation_time=datetime.datetime(2023, 4, 28, 21, 53, 47, 912000), +) + +SCHEDULE_ARTIFACT_RESULT: Artifact = Artifact( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-raw-data", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz" + }, + properties=dict( + pipeline_name=PIPELINE_SCHEDULE.pipeline_name, + schedule_expression=PIPELINE_SCHEDULE.schedule_expression, + state=PIPELINE_SCHEDULE.state, + start_date=PIPELINE_SCHEDULE.start_date, + ), + artifact_type="DataSet", + creation_time=datetime.datetime(2023, 4, 28, 21, 53, 47, 912000), +) + +ARTIFACT_SUMMARY: ArtifactSummary = ArtifactSummary( + artifact_arn="arn:aws:sagemaker:us-west-2:789975069016:artifact/7be06af3274fd01d1c18c96f97141f32", + artifact_name="sm-fs-fe-raw-data", + source=ArtifactSource( + source_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz", + source_types=[], + ), + artifact_type="DataSet", + creation_time=datetime.datetime(2023, 4, 27, 21, 4, 17, 926000), +) + +TRANSFORMATION_CODE_ARTIFACT_1 = Artifact( + artifact_arn="ts-artifact-01-arn", + artifact_name="sm-fs-fe-transformation-code", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz", + "source_types": [{"source_id_type": "Custom", "value": "1684369626"}], + }, + properties={ + "name": "test-name", + "author": "test-author", + "inclusive_start_date": "1684369626", + "state": "Active", + }, +) + +TRANSFORMATION_CODE_ARTIFACT_2 = Artifact( + artifact_arn="ts-artifact-02-arn", + artifact_name="sm-fs-fe-transformation-code", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz/2", + "source_types": [{"source_id_type": "Custom", "value": "1684369626"}], + }, + properties={ + "name": "test-name", + "author": "test-author", + "inclusive_start_date": "1684369626", + "state": "Active", + }, +) + +INACTIVE_TRANSFORMATION_CODE_ARTIFACT_1 = Artifact( + 
artifact_arn="ts-artifact-02-arn", + artifact_name="sm-fs-fe-transformation-code", + source={ + "source_uri": "s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz/2", + "source_types": [{"source_id_type": "Custom", "value": "1684369307"}], + }, + Properties={ + "name": "test-name", + "author": "test-author", + "exclusive_end_date": "1684369626", + "inclusive_start_date": "1684369307", + "state": "Inactive", + }, +) + +VALIDATION_EXCEPTION = ClientError( + {"Error": {"Code": "ValidationException", "Message": "AssociationAlreadyExists"}}, + "Operation", +) + +RESOURCE_NOT_FOUND_EXCEPTION = ClientError( + {"Error": {"Code": "ResourceNotFound", "Message": "ResourceDoesNotExists"}}, + "Operation", +) + +NON_VALIDATION_EXCEPTION = ClientError( + {"Error": {"Code": "NonValidationException", "Message": "NonValidationError"}}, + "Operation", +) + +FEATURE_GROUP_NAME = "feature-group-name-01" +FEATURE_GROUP = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:789975069016:feature-group/feature-group-name-01", + "FeatureGroupName": "feature-group-name-01", + "RecordIdentifierFeatureName": "model_year_status", + "EventTimeFeatureName": "ingest_time", + "FeatureDefinitions": [ + {"FeatureName": "model_year_status", "FeatureType": "String"}, + {"FeatureName": "avg_mileage", "FeatureType": "String"}, + {"FeatureName": "max_mileage", "FeatureType": "String"}, + {"FeatureName": "avg_price", "FeatureType": "String"}, + {"FeatureName": "max_price", "FeatureType": "String"}, + {"FeatureName": "avg_msrp", "FeatureType": "String"}, + {"FeatureName": "max_msrp", "FeatureType": "String"}, + {"FeatureName": "ingest_time", "FeatureType": "Fractional"}, + ], + "CreationTime": datetime.datetime(2023, 4, 27, 21, 4, 17, 926000), + "OnlineStoreConfig": {"EnableOnlineStore": True}, + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://sagemaker-us-west-2-789975069016/" + "feature-store/feature-processor/" + "suryans-v2/offline-store", + "ResolvedOutputS3Uri": "s3://sagemaker-us-west-2-" + "789975069016/feature-store/" + "feature-processor/suryans-v2/" + "offline-store/789975069016/" + "sagemaker/us-west-2/" + "offline-store/" + "feature-group-name-01-" + "1682629457/data", + }, + "DisableGlueTableCreation": False, + "DataCatalogConfig": { + "TableName": "feature-group-name-01_1682629457", + "Catalog": "AwsDataCatalog", + "Database": "sagemaker_featurestore", + }, + }, + "RoleArn": "arn:aws:iam::789975069016:role/service-role/AmazonSageMaker-ExecutionRole-20230421T100744", + "FeatureGroupStatus": "Created", + "OnlineStoreTotalSizeBytes": 0, + "ResponseMetadata": { + "RequestId": "8f139791-345d-4388-8d6d-40420495a3c4", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "8f139791-345d-4388-8d6d-40420495a3c4", + "content-type": "application/x-amz-json-1.1", + "content-length": "1608", + "date": "Mon, 01 May 2023 21:42:59 GMT", + }, + "RetryAttempts": 0, + }, +} + +PIPELINE = { + "PipelineArn": "arn:aws:sagemaker:us-west-2:597217924798:pipeline/test-pipeline-26", + "PipelineName": "test-pipeline-26", + "PipelineDisplayName": "test-pipeline-26", + "PipelineDefinition": '{"Version": "2020-12-01", "Metadata": {}, ' + '"Parameters": [{"Name": "scheduled-time", "Type": "String"}], ' + '"PipelineExperimentConfig": {"ExperimentName": {"Get": "Execution.PipelineName"}, ' + '"TrialName": {"Get": "Execution.PipelineExecutionId"}}, ' + '"Steps": [{"Name": "test-pipeline-26-training-step", "Type": ' + '"Training", "Arguments": 
{"AlgorithmSpecification": {"TrainingInputMode": ' + '"File", "TrainingImage": "153931337802.dkr.ecr.us-west-2.amazonaws.com/' + 'sagemaker-spark-processing:3.2-cpu-py39-v1.1", "ContainerEntrypoint": ' + '["/bin/bash", "/opt/ml/input/data/sagemaker_remote_function_bootstrap/' + 'job_driver.sh", "--files", "s3://bugbash-schema-update/temp.sh", ' + '"/opt/ml/input/data/sagemaker_remote_function_bootstrap/spark_app.py"], ' + '"ContainerArguments": ["--s3_base_uri", ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26", ' + '"--region", "us-west-2", "--client_python_version", "3.9"]}, ' + '"OutputDataConfig": {"S3OutputPath": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26"}, ' + '"StoppingCondition": {"MaxRuntimeInSeconds": 86400}, "ResourceConfig": ' + '{"VolumeSizeInGB": 30, "InstanceCount": 1, "InstanceType": "ml.m5.xlarge"}, ' + '"RoleArn": "arn:aws:iam::597217924798:role/Admin", "InputDataConfig": ' + '[{"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26/' + 'sagemaker_remote_function_bootstrap", "S3DataDistributionType": ' + '"FullyReplicated"}}, "ChannelName": "sagemaker_remote_function_bootstrap"}, ' + '{"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": ' + '"s3://bugbash-schema-update/sagemaker-2.142.1.dev0-py2.py3-none-any.whl", ' + '"S3DataDistributionType": "FullyReplicated"}}, "ChannelName": ' + '"sagemaker_wheel_file"}], "Environment": {"AWS_DEFAULT_REGION": "us-west-2"}, ' + '"DebugHookConfig": {"S3OutputPath": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26", ' + '"CollectionConfigurations": []},' + ' "ProfilerConfig": {"S3OutputPath": ' + '"s3://bugbash-schema-update-suryans/test-pipeline-26", ' + '"DisableProfiler": false}, "RetryStrategy": {"MaximumRetryAttempts": 1}}}]}', + "RoleArn": "arn:aws:iam::597217924798:role/Admin", + "PipelineStatus": "Active", + "CreationTime": datetime.datetime(2023, 4, 27, 9, 46, 35, 686000), + "LastModifiedTime": datetime.datetime(2023, 4, 27, 20, 27, 36, 648000), + "CreatedBy": {}, + "LastModifiedBy": {}, + "ResponseMetadata": { + "RequestId": "2075bc1c-1b34-4fe5-b7d8-7cfdf784a7d9", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "2075bc1c-1b34-4fe5-b7d8-7cfdf784a7d9", + "content-type": "application/x-amz-json-1.1", + "content-length": "2555", + "date": "Thu, 04 May 2023 00:28:35 GMT", + }, + "RetryAttempts": 0, + }, +} + +PIPELINE_CONTEXT: Context = Context( + context_arn=f"{PIPELINE_NAME}-context-arn", + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{CREATION_TIME}-fep", + context_type="FeatureEngineeringPipeline", + source=ContextSource(source_uri=PIPELINE_ARN, source_types=[]), + properties={ + "PipelineName": PIPELINE_NAME, + "PipelineCreationTime": CREATION_TIME, + "LastUpdateTime": LAST_UPDATE_TIME, + }, +) + +PIPELINE_VERSION_CONTEXT: Context = Context( + context_arn=f"{PIPELINE_NAME}-version-context-arn", + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{LAST_UPDATE_TIME}-fep-ver", + context_type=f"FeatureEngineeringPipelineVersion-{PIPELINE_NAME}", + source=ContextSource(source_uri=PIPELINE_ARN, source_types=LAST_UPDATE_TIME), + properties={"PipelineName": PIPELINE_NAME, "LastUpdateTime": LAST_UPDATE_TIME}, +) + +TRANSFORMATION_CODE_INPUT_1: TransformationCode = TransformationCode( + s3_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz", + author="test-author", + name="test-name", +) + +TRANSFORMATION_CODE_INPUT_2: 
TransformationCode = TransformationCode( + s3_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz/2", + author="test-author", + name="test-name", +) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py new file mode 100644 index 0000000000..bc725570b9 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_group_lineage_entity_handler.py @@ -0,0 +1,62 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +from mock import patch, call + +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_lineage_entity_handler import ( + FeatureGroupLineageEntityHandler, +) +from sagemaker.core.lineage.context import Context + +from test_constants import ( + SAGEMAKER_SESSION_MOCK, + CONTEXT_MOCK_01, + CONTEXT_MOCK_02, + FEATURE_GROUP, + FEATURE_GROUP_NAME, +) + + +def test_retrieve_feature_group_context_arns(): + with patch.object( + SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP + ) as fg_describe_method: + with patch.object( + Context, "load", side_effect=[CONTEXT_MOCK_01, CONTEXT_MOCK_02] + ) as context_load: + type(CONTEXT_MOCK_01).context_arn = "context-arn-fep" + type(CONTEXT_MOCK_02).context_arn = "context-arn-fep-ver" + result = FeatureGroupLineageEntityHandler.retrieve_feature_group_context_arns( + feature_group_name=FEATURE_GROUP_NAME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result.name == FEATURE_GROUP_NAME + assert result.pipeline_context_arn == "context-arn-fep" + assert result.pipeline_version_context_arn == "context-arn-fep-ver" + fg_describe_method.assert_called_once_with(FEATURE_GROUP_NAME) + context_load.assert_has_calls( + [ + call( + context_name=f'{FEATURE_GROUP_NAME}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + f"-feature-group-pipeline", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + context_name=f'{FEATURE_GROUP_NAME}-{FEATURE_GROUP["CreationTime"].strftime("%s")}' + f"-feature-group-pipeline-version", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == context_load.call_count diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_processor_lineage.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_processor_lineage.py new file mode 100644 index 0000000000..fe5783210a --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_feature_processor_lineage.py @@ -0,0 +1,2966 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import copy +import datetime +from typing import Iterator, List + +import pytest +from mock import call, patch, Mock + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import ( + EventBridgeSchedulerHelper, +) +from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import ( + EventBridgeRuleHelper, +) +from sagemaker.mlops.feature_store.feature_processor.lineage.constants import ( + TRANSFORMATION_CODE_STATUS_INACTIVE, +) +from sagemaker.core.lineage.context import Context +from sagemaker.core.lineage.artifact import Artifact +from test_constants import ( + FEATURE_GROUP_DATA_SOURCE, + FEATURE_GROUP_INPUT, + LAST_UPDATE_TIME, + PIPELINE, + PIPELINE_ARN, + PIPELINE_CONTEXT, + PIPELINE_NAME, + PIPELINE_VERSION_CONTEXT, + RAW_DATA_INPUT, + RAW_DATA_INPUT_ARTIFACTS, + RESOURCE_NOT_FOUND_EXCEPTION, + SAGEMAKER_SESSION_MOCK, + SCHEDULE_ARTIFACT_RESULT, + PIPELINE_TRIGGER_ARTIFACT, + TRANSFORMATION_CODE_ARTIFACT_1, + TRANSFORMATION_CODE_ARTIFACT_2, + TRANSFORMATION_CODE_INPUT_1, + TRANSFORMATION_CODE_INPUT_2, + ARTIFACT_SUMMARY, + ARTIFACT_RESULT, +) + +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_group_lineage_entity_handler import ( + FeatureGroupLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._feature_processor_lineage import ( + FeatureProcessorLineageHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._lineage_association_handler import ( + LineageAssociationHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_lineage_entity_handler import ( + PipelineLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_schedule import ( + PipelineSchedule, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import ( + PipelineTrigger, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_version_lineage_entity_handler import ( + PipelineVersionLineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._s3_lineage_entity_handler import ( + S3LineageEntityHandler, +) +from sagemaker.core.lineage._api_types import AssociationSummary + +SCHEDULE_ARN = "" +SCHEDULE_EXPRESSION = "" +STATE = "" +TRIGGER_ARN = "" +EVENT_PATTERN = "" +START_DATE = datetime.datetime(2023, 4, 28, 21, 53, 47, 912000) +TAGS = [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + + +@pytest.fixture +def sagemaker_session(): + boto_session = Mock() + boto_session.client("scheduler").return_value = Mock() + return Mock(Session, boto_session=boto_session) + + +@pytest.fixture +def event_bridge_scheduler_helper(sagemaker_session): + return EventBridgeSchedulerHelper( + sagemaker_session, sagemaker_session.boto_session.client("scheduler") + ) + + +def test_create_lineage_when_no_lineage_exists_with_fg_only(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + 
pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + ): + lineage_handler.create_lineage() + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_not_called() + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + 
pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=[], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_lineage_when_no_lineage_exists_with_raw_data_only(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as 
load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_called_once_with( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=[], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + 
pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_no_lineage_exists_with_fg_and_raw_data_with_tags(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as 
add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_no_lineage_exists_with_no_transformation_code(): + 
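+ # No transformation code is supplied in this scenario: the patched helper returns None,
+ # so no transformation-code association is added and no tags are set (asserted below).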
lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=None, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method, + patch.object( + PipelineLineageEntityHandler, + "create_pipeline_context", + return_value=PIPELINE_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + [], + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + 
retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=None, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_not_called() + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_not_called() + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_not_called() + + +def test_create_lineage_when_already_exist_with_no_version_change(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + 
"load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as create_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + 
pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_pipeline_version_context_method.assert_not_called() + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_raw_data(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=[RAW_DATA_INPUT[0], RAW_DATA_INPUT[1]] + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1]], + ) 
as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 2 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) 
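+ # Lineage already exists, so the stored pipeline context is loaded twice and later
+ # updated with the pipeline's new LastModifiedTime (asserted further below).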
+ + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=[RAW_DATA_INPUT_ARTIFACTS[0], RAW_DATA_INPUT_ARTIFACTS[1]], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert pipeline_context.properties["LastUpdateTime"] == PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_pipeline_context_method.assert_called_once_with(pipeline_context=pipeline_context) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + 
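+ # The TAGS passed to create_lineage are expected to reach Artifact.set_tags exactly once.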
artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_input_fg(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + [FEATURE_GROUP_DATA_SOURCE[0]], + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[FEATURE_GROUP_INPUT[0], FEATURE_GROUP_INPUT[0]], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + 
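+ # Exercise create_lineage when the input feature groups differ from the recorded lineage;
+ # new data associations and a pipeline context update are expected (asserted below).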
lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=[FEATURE_GROUP_INPUT[0]], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + 
pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert pipeline_context.properties["LastUpdateTime"] == PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_pipeline_context_method.assert_called_once_with(pipeline_context=pipeline_context) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_output_fg(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[1].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[1], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_1, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + 
"update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + 
entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + add_upstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_downstream_feature_group_data_associations_method.assert_called_once_with( + feature_group_output=FEATURE_GROUP_INPUT[1], + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_raw_data_associations_method.assert_called_once_with( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert pipeline_context.properties["LastUpdateTime"] == PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_pipeline_context_method.assert_called_once_with(pipeline_context=pipeline_context) + + add_pipeline_and_pipeline_version_association_method.assert_called_once_with( + pipeline_context_arn=PIPELINE_CONTEXT.context_arn, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_changed_transformation_code(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as 
create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_2, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + 
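+ # Only the transformation code changed here: the new code artifact is associated with the
+ # existing pipeline version, and no other associations or pipeline context updates occur.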
load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + assert pipeline_context.properties["LastUpdateTime"] == LAST_UPDATE_TIME + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_2, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_last_transformation_code_as_none(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_1.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_1.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + 
FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + 
call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_2, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + update_transformation_code_artifact_method.assert_not_called() + + assert pipeline_context.properties["LastUpdateTime"] == LAST_UPDATE_TIME + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_2, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_all_previous_transformation_code_as_none(): + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + 
S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + "create_transformation_code_artifact", + return_value=TRANSFORMATION_CODE_ARTIFACT_2, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + iter([]), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == 
retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=TRANSFORMATION_CODE_INPUT_2, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_not_called() + update_transformation_code_artifact_method.assert_not_called() + + assert pipeline_context.properties["LastUpdateTime"] == LAST_UPDATE_TIME + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + add_upstream_transformation_code_associations_method.assert_called_once_with( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_2, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_lineage_when_already_exist_with_removed_transformation_code(): + transformation_code_1 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + FeatureGroupLineageEntityHandler, + "retrieve_feature_group_context_arns", + side_effect=[ + FEATURE_GROUP_INPUT[0], + FEATURE_GROUP_INPUT[1], + FEATURE_GROUP_INPUT[0], + ], + ) as retrieve_feature_group_context_arns_method, + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + S3LineageEntityHandler, + 
"create_transformation_code_artifact", + return_value=None, + ) as create_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + generate_pipeline_version_upstream_raw_data_list(), + generate_pipeline_version_upstream_transformation_code(), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, + "load_artifact_from_arn", + return_value=transformation_code_1, + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, + "update_transformation_code_artifact", + ) as update_transformation_code_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "update_pipeline_context", + ) as update_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "create_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ), + patch.object( + LineageAssociationHandler, "add_upstream_feature_group_data_associations" + ) as add_upstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_downstream_feature_group_data_associations" + ) as add_downstream_feature_group_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_raw_data_associations" + ) as add_upstream_raw_data_associations_method, + patch.object( + LineageAssociationHandler, "add_upstream_transformation_code_associations" + ) as add_upstream_transformation_code_associations_method, + patch.object( + LineageAssociationHandler, "add_pipeline_and_pipeline_version_association" + ) as add_pipeline_and_pipeline_version_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_lineage(TAGS) + + retrieve_feature_group_context_arns_method.assert_has_calls( + [ + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[1].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + feature_group_name=FEATURE_GROUP_DATA_SOURCE[0].name, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == retrieve_feature_group_context_arns_method.call_count + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=SAGEMAKER_SESSION_MOCK), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=SAGEMAKER_SESSION_MOCK), + ] + ) + assert 4 == retrieve_raw_data_artifact_method.call_count + + create_transformation_code_artifact_method.assert_called_once_with( + transformation_code=None, + pipeline_last_update_time=PIPELINE["LastModifiedTime"].strftime("%s"), + 
sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_upstream_associations_method.assert_has_calls( + [ + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="DataSet", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_type="TransformationCode", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 3 == list_upstream_associations_method.call_count + + list_downstream_associations_method.assert_called_once_with( + entity_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_type="FeatureGroupPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_artifact_from_arn_method.assert_called_once_with( + artifact_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + transformation_code_2 = copy.deepcopy(TRANSFORMATION_CODE_ARTIFACT_1) + transformation_code_2.properties["state"] = TRANSFORMATION_CODE_STATUS_INACTIVE + transformation_code_2.properties["exclusive_end_date"] = PIPELINE["LastModifiedTime"].strftime( + "%s" + ) + update_transformation_code_artifact_method.assert_called_once_with( + transformation_code_artifact=transformation_code_2 + ) + + update_pipeline_context_method.assert_not_called() + add_upstream_feature_group_data_associations_method.assert_not_called() + add_downstream_feature_group_data_associations_method.assert_not_called() + add_upstream_raw_data_associations_method.assert_not_called() + add_upstream_transformation_code_associations_method.assert_not_called() + add_pipeline_and_pipeline_version_association_method.assert_not_called() + + artifact_set_tags.assert_not_called() + + +def test_get_pipeline_lineage_names_when_no_lineage_exists(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + side_effect=RESOURCE_NOT_FOUND_EXCEPTION, + ) as load_pipeline_context_method: + return_value = lineage_handler.get_pipeline_lineage_names() + + assert return_value is None + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_get_pipeline_lineage_names_when_lineage_exists(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + 
transformation_code=TRANSFORMATION_CODE_INPUT_1, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + ): + return_value = lineage_handler.get_pipeline_lineage_names() + + assert return_value == dict( + pipeline_context_name=PIPELINE_CONTEXT.context_name, + pipeline_version_context_name=PIPELINE_VERSION_CONTEXT.context_name, + ) + + load_pipeline_context_method.assert_has_calls( + [ + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == load_pipeline_context_method.call_count + + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_schedule_lineage(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + S3LineageEntityHandler, + "retrieve_pipeline_schedule_artifact", + return_value=SCHEDULE_ARTIFACT_RESULT, + ) as retrieve_pipeline_schedule_artifact_method, + patch.object( + LineageAssociationHandler, + "add_upstream_schedule_associations", + ) as add_upstream_schedule_associations_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_schedule_lineage( + pipeline_name=PIPELINE_NAME, + schedule_arn=SCHEDULE_ARN, + schedule_expression=SCHEDULE_EXPRESSION, + state=STATE, + start_date=START_DATE, + tags=TAGS, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + retrieve_pipeline_schedule_artifact_method.assert_called_once_with( + pipeline_schedule=PipelineSchedule( + schedule_name=PIPELINE_NAME, + schedule_arn=SCHEDULE_ARN, + schedule_expression=SCHEDULE_EXPRESSION, + pipeline_name=PIPELINE_NAME, + state=STATE, + start_date=START_DATE.strftime("%s"), + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_upstream_schedule_associations_method.assert_called_once_with( + schedule_artifact=SCHEDULE_ARTIFACT_RESULT, + pipeline_version_context_arn=PIPELINE_VERSION_CONTEXT.context_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + 
artifact_set_tags.assert_called_once_with(TAGS) + + +def test_create_trigger_lineage(): + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + with ( + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=PIPELINE_CONTEXT, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + S3LineageEntityHandler, + "retrieve_pipeline_trigger_artifact", + return_value=PIPELINE_TRIGGER_ARTIFACT, + ) as retrieve_pipeline_trigger_artifact_method, + patch.object( + LineageAssociationHandler, + "_add_association", + ) as add_association_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + ): + lineage_handler.create_trigger_lineage( + pipeline_name=PIPELINE_NAME, + trigger_arn=TRIGGER_ARN, + event_pattern=EVENT_PATTERN, + state=STATE, + tags=TAGS, + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=PIPELINE_CONTEXT.properties["LastUpdateTime"], + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + retrieve_pipeline_trigger_artifact_method.assert_called_once_with( + pipeline_trigger=PipelineTrigger( + trigger_name=PIPELINE_NAME, + trigger_arn=TRIGGER_ARN, + pipeline_name=PIPELINE_NAME, + event_pattern=EVENT_PATTERN, + state=STATE, + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + add_association_method.assert_called_once_with( + source_arn=PIPELINE_TRIGGER_ARTIFACT.artifact_arn, + destination_arn=PIPELINE_VERSION_CONTEXT.context_arn, + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_set_tags.assert_called_once_with(TAGS) + + +def test_upsert_tags_for_lineage_resources(): + pipeline_context = copy.deepcopy(PIPELINE_CONTEXT) + mock_session = Mock(Session) + lineage_handler = FeatureProcessorLineageHandler( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + pipeline=PIPELINE, + inputs=RAW_DATA_INPUT + FEATURE_GROUP_DATA_SOURCE, + output=FEATURE_GROUP_DATA_SOURCE[0].name, + transformation_code=TRANSFORMATION_CODE_INPUT_2, + sagemaker_session=mock_session, + ) + lineage_handler.sagemaker_session.boto_session = Mock() + lineage_handler.sagemaker_session.sagemaker_client = Mock() + with ( + patch.object( + S3LineageEntityHandler, + "retrieve_raw_data_artifact", + side_effect=[ + RAW_DATA_INPUT_ARTIFACTS[0], + RAW_DATA_INPUT_ARTIFACTS[1], + RAW_DATA_INPUT_ARTIFACTS[2], + RAW_DATA_INPUT_ARTIFACTS[3], + ], + ) as retrieve_raw_data_artifact_method, + patch.object( + PipelineLineageEntityHandler, + "load_pipeline_context", + return_value=pipeline_context, + ) as load_pipeline_context_method, + patch.object( + PipelineVersionLineageEntityHandler, + "load_pipeline_version_context", + return_value=PIPELINE_VERSION_CONTEXT, + ) as load_pipeline_version_context_method, + patch.object( + LineageAssociationHandler, + "list_upstream_associations", + side_effect=[ + generate_pipeline_version_upstream_feature_group_list(), + 
generate_pipeline_version_upstream_raw_data_list(), + iter([]), + ], + ) as list_upstream_associations_method, + patch.object( + LineageAssociationHandler, + "list_downstream_associations", + return_value=generate_pipeline_version_downstream_feature_group(), + ) as list_downstream_associations_method, + patch.object( + S3LineageEntityHandler, "load_artifact_from_arn", return_value=ARTIFACT_RESULT + ) as load_artifact_from_arn_method, + patch.object( + S3LineageEntityHandler, "_load_artifact_from_s3_uri", return_value=ARTIFACT_SUMMARY + ) as load_artifact_from_s3_uri_method, + patch.object( + Artifact, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as artifact_set_tags, + patch.object( + Context, + "set_tags", + return_value={ + "Tags": [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + }, + ) as context_set_tags, + patch.object( + EventBridgeSchedulerHelper, "describe_schedule", return_value=dict(Arn="schedule_arn") + ) as get_event_bridge_schedule, + patch.object( + EventBridgeRuleHelper, "describe_rule", return_value=dict(Arn="rule_arn") + ) as get_event_bridge_rule, + ): + lineage_handler.upsert_tags_for_lineage_resources(TAGS) + + retrieve_raw_data_artifact_method.assert_has_calls( + [ + call(raw_data=RAW_DATA_INPUT[0], sagemaker_session=mock_session), + call(raw_data=RAW_DATA_INPUT[1], sagemaker_session=mock_session), + call(raw_data=RAW_DATA_INPUT[2], sagemaker_session=mock_session), + call(raw_data=RAW_DATA_INPUT[3], sagemaker_session=mock_session), + ] + ) + + load_pipeline_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + creation_time=PIPELINE["CreationTime"].strftime("%s"), + sagemaker_session=mock_session, + ) + + load_pipeline_version_context_method.assert_called_once_with( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=mock_session, + ) + + list_upstream_associations_method.assert_not_called() + list_downstream_associations_method.assert_not_called() + load_artifact_from_s3_uri_method.assert_has_calls( + [ + call(s3_uri="schedule_arn", sagemaker_session=mock_session), + call(s3_uri="rule_arn", sagemaker_session=mock_session), + ] + ) + get_event_bridge_schedule.assert_called_once_with(PIPELINE_NAME) + get_event_bridge_rule.assert_called_once_with(PIPELINE_NAME) + load_artifact_from_arn_method.assert_called_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, sagemaker_session=mock_session + ) + + # three raw data artifact, one schedule artifact and one trigger artifact + artifact_set_tags.assert_has_calls( + [ + call(TAGS), + call(TAGS), + call(TAGS), + call(TAGS), + call(TAGS), + ] + ) + # pipeline context and current pipeline version context + context_set_tags.assert_has_calls( + [ + call(TAGS), + call(TAGS), + ] + ) + + +def generate_pipeline_version_upstream_feature_group_list() -> Iterator[AssociationSummary]: + pipeline_version_upstream_fg: List[AssociationSummary] = list() + for feature_group in FEATURE_GROUP_INPUT: + pipeline_version_upstream_fg.append( + AssociationSummary( + source_arn=feature_group.pipeline_version_context_arn, + source_name=f"{feature_group.name}-pipeline-version", + destination_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_name=PIPELINE_VERSION_CONTEXT.context_name, + association_type="ContributedTo", + ) + ) + return iter(pipeline_version_upstream_fg) + + +def generate_pipeline_version_upstream_raw_data_list() -> Iterator[AssociationSummary]: + pipeline_version_upstream_fg: 
List[AssociationSummary] = list() + for raw_data in RAW_DATA_INPUT_ARTIFACTS: + pipeline_version_upstream_fg.append( + AssociationSummary( + source_arn=raw_data.artifact_arn, + source_name="sm-fs-fe-raw-data", + destination_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_name=PIPELINE_VERSION_CONTEXT.context_name, + association_type="ContributedTo", + ) + ) + return iter(pipeline_version_upstream_fg) + + +def generate_pipeline_version_upstream_transformation_code() -> Iterator[AssociationSummary]: + return iter( + [ + AssociationSummary( + source_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + source_name=TRANSFORMATION_CODE_ARTIFACT_1.artifact_name, + destination_arn=PIPELINE_VERSION_CONTEXT.context_arn, + destination_name=PIPELINE_VERSION_CONTEXT.context_name, + association_type="ContributedTo", + ) + ] + ) + + +def generate_pipeline_version_downstream_feature_group() -> Iterator[AssociationSummary]: + return iter( + [ + AssociationSummary( + source_arn=PIPELINE_VERSION_CONTEXT.context_arn, + source_name=PIPELINE_VERSION_CONTEXT.context_name, + destination_arn=FEATURE_GROUP_INPUT[0].pipeline_version_context_arn, + destination_name=f"{FEATURE_GROUP_INPUT[0].name}-pipeline-version", + association_type="ContributedTo", + ) + ] + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_lineage_association_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_lineage_association_handler.py new file mode 100644 index 0000000000..a3d24dd0b5 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_lineage_association_handler.py @@ -0,0 +1,224 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +from mock import patch, call +import pytest + +from sagemaker.mlops.feature_store.feature_processor.lineage._lineage_association_handler import ( + LineageAssociationHandler, +) +from sagemaker.core.lineage.association import Association +from botocore.exceptions import ClientError + +from test_constants import ( + FEATURE_GROUP_INPUT, + RAW_DATA_INPUT_ARTIFACTS, + VALIDATION_EXCEPTION, + NON_VALIDATION_EXCEPTION, + SAGEMAKER_SESSION_MOCK, + TRANSFORMATION_CODE_ARTIFACT_1, +) + + +def test_add_upstream_feature_group_data_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + for feature_group in FEATURE_GROUP_INPUT: + create_association_method.assert_has_calls( + [ + call( + source_arn=feature_group.pipeline_context_arn, + destination_arn="pipeline-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn=feature_group.pipeline_version_context_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert len(FEATURE_GROUP_INPUT) * 2 == create_association_method.call_count + + +def test_add_upstream_feature_group_data_associations_when_association_already_exists(): + with patch.object( + Association, "create", side_effect=VALIDATION_EXCEPTION + ) as create_association_method: + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + for feature_group in FEATURE_GROUP_INPUT: + create_association_method.assert_has_calls( + [ + call( + source_arn=feature_group.pipeline_context_arn, + destination_arn="pipeline-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn=feature_group.pipeline_version_context_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert len(FEATURE_GROUP_INPUT) * 2 == create_association_method.call_count + + +def test_add_upstream_feature_group_data_associations_when_non_validation_exception(): + with patch.object(Association, "create", side_effect=NON_VALIDATION_EXCEPTION): + with pytest.raises(ClientError): + LineageAssociationHandler.add_upstream_feature_group_data_associations( + feature_group_inputs=FEATURE_GROUP_INPUT, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_add_upstream_raw_data_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_upstream_raw_data_associations( + raw_data_inputs=RAW_DATA_INPUT_ARTIFACTS, + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + for raw_data in RAW_DATA_INPUT_ARTIFACTS: + create_association_method.assert_has_calls( + [ + call( + 
source_arn=raw_data.artifact_arn, + destination_arn="pipeline-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn=raw_data.artifact_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert len(RAW_DATA_INPUT_ARTIFACTS) * 2 == create_association_method.call_count + + +def test_add_upstream_transformation_code_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_upstream_transformation_code_associations( + transformation_code_artifact=TRANSFORMATION_CODE_ARTIFACT_1, + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_association_method.assert_called_once_with( + source_arn=TRANSFORMATION_CODE_ARTIFACT_1.artifact_arn, + destination_arn="pipeline-version-context-arn", + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_add_downstream_feature_group_data_associations(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_downstream_feature_group_data_associations( + feature_group_output=FEATURE_GROUP_INPUT[0], + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_association_method.assert_has_calls( + [ + call( + source_arn="pipeline-context-arn", + destination_arn=FEATURE_GROUP_INPUT[0].pipeline_context_arn, + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + call( + source_arn="pipeline-version-context-arn", + destination_arn=FEATURE_GROUP_INPUT[0].pipeline_version_context_arn, + association_type="ContributedTo", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ), + ] + ) + assert 2 == create_association_method.call_count + + +def test_add_pipeline_and_pipeline_version_association(): + with patch.object(Association, "create") as create_association_method: + LineageAssociationHandler.add_pipeline_and_pipeline_version_association( + pipeline_context_arn="pipeline-context-arn", + pipeline_version_context_arn="pipeline-version-context-arn", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + create_association_method.assert_called_once_with( + source_arn="pipeline-context-arn", + destination_arn="pipeline-version-context-arn", + association_type="AssociatedWith", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_list_upstream_associations(): + with patch.object(Association, "list") as list_association_method: + LineageAssociationHandler.list_upstream_associations( + entity_arn="pipeline-context-arn", + source_type="FeatureEngineeringPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_association_method.assert_called_once_with( + source_arn=None, + source_type="FeatureEngineeringPipelineVersion", + destination_arn="pipeline-context-arn", + destination_type=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_list_downstream_associations(): + with patch.object(Association, "list") as list_association_method: + LineageAssociationHandler.list_downstream_associations( + entity_arn="pipeline-context-arn", + destination_type="FeatureEngineeringPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + list_association_method.assert_called_once_with( + source_arn="pipeline-context-arn", + 
source_type=None, + destination_arn=None, + destination_type="FeatureEngineeringPipelineVersion", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_lineage_entity_handler.py new file mode 100644 index 0000000000..deb76a7748 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_lineage_entity_handler.py @@ -0,0 +1,74 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock import patch + +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_lineage_entity_handler import ( + PipelineLineageEntityHandler, +) +from sagemaker.core.lineage.context import Context +from test_constants import ( + PIPELINE_NAME, + PIPELINE_ARN, + CREATION_TIME, + LAST_UPDATE_TIME, + SAGEMAKER_SESSION_MOCK, + CONTEXT_MOCK_01, +) + + +def test_create_pipeline_context(): + with patch.object(Context, "create", return_value=CONTEXT_MOCK_01) as create_method: + result = PipelineLineageEntityHandler.create_pipeline_context( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + creation_time=CREATION_TIME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + create_method.assert_called_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{CREATION_TIME}-fep", + context_type="FeatureEngineeringPipeline", + source_uri=PIPELINE_ARN, + source_type=CREATION_TIME, + properties={ + "PipelineName": PIPELINE_NAME, + "PipelineCreationTime": CREATION_TIME, + "LastUpdateTime": LAST_UPDATE_TIME, + }, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_load_pipeline_context(): + with patch.object(Context, "load", return_value=CONTEXT_MOCK_01) as load_method: + result = PipelineLineageEntityHandler.load_pipeline_context( + pipeline_name=PIPELINE_NAME, + creation_time=CREATION_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + load_method.assert_called_once_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{CREATION_TIME}-fep", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_update_pipeline_context(): + with patch.object(Context, "save", return_value=CONTEXT_MOCK_01): + PipelineLineageEntityHandler.update_pipeline_context(pipeline_context=CONTEXT_MOCK_01) + CONTEXT_MOCK_01.save.assert_called_once() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_trigger.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_trigger.py new file mode 100644 index 0000000000..c936c3c164 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_trigger.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, 
Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_trigger import PipelineTrigger
+
+
+def test_pipeline_trigger():
+
+ trigger = PipelineTrigger(
+ trigger_name="test_trigger",
+ trigger_arn="test_arn",
+ event_pattern="test_pattern",
+ pipeline_name="test_pipeline",
+ state="Enabled",
+ )
+
+ assert trigger.trigger_name == "test_trigger"
+ assert trigger.trigger_arn == "test_arn"
+ assert trigger.event_pattern == "test_pattern"
+ assert trigger.pipeline_name == "test_pipeline"
+ assert trigger.state == "Enabled"
diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_version_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_version_lineage_entity_handler.py
new file mode 100644
index 0000000000..52e65749be
--- /dev/null
+++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_pipeline_version_lineage_entity_handler.py
@@ -0,0 +1,67 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from mock import patch + +from sagemaker.mlops.feature_store.feature_processor.lineage._pipeline_version_lineage_entity_handler import ( + PipelineVersionLineageEntityHandler, +) + +from sagemaker.core.lineage.context import Context + +from test_constants import ( + PIPELINE_NAME, + PIPELINE_ARN, + LAST_UPDATE_TIME, + SAGEMAKER_SESSION_MOCK, + CONTEXT_MOCK_01, +) + + +def test_create_pipeline_version_context(): + with patch.object(Context, "create", return_value=CONTEXT_MOCK_01) as create_method: + result = PipelineVersionLineageEntityHandler.create_pipeline_version_context( + pipeline_name=PIPELINE_NAME, + pipeline_arn=PIPELINE_ARN, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + create_method.assert_called_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{LAST_UPDATE_TIME}-fep-ver", + context_type=f"FeatureEngineeringPipelineVersion-{PIPELINE_NAME}", + source_uri=PIPELINE_ARN, + source_type=LAST_UPDATE_TIME, + properties={ + "PipelineName": PIPELINE_NAME, + "LastUpdateTime": LAST_UPDATE_TIME, + }, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_load_pipeline_version_context(): + with patch.object(Context, "load", return_value=CONTEXT_MOCK_01) as load_method: + result = PipelineVersionLineageEntityHandler.load_pipeline_version_context( + pipeline_name=PIPELINE_NAME, + last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == CONTEXT_MOCK_01 + load_method.assert_called_once_with( + context_name=f"sm-fs-fe-{PIPELINE_NAME}-{LAST_UPDATE_TIME}-fep-ver", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py new file mode 100644 index 0000000000..8b34806f1d --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/lineage/test_s3_lineage_entity_handler.py @@ -0,0 +1,434 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import copy + +from mock import patch +from test_constants import ( + ARTIFACT_RESULT, + ARTIFACT_SUMMARY, + PIPELINE_SCHEDULE, + PIPELINE_SCHEDULE_2, + PIPELINE_TRIGGER, + PIPELINE_TRIGGER_2, + PIPELINE_TRIGGER_ARTIFACT, + PIPELINE_TRIGGER_ARTIFACT_SUMMARY, + SCHEDULE_ARTIFACT_RESULT, + TRANSFORMATION_CODE_ARTIFACT_1, + TRANSFORMATION_CODE_INPUT_1, + LAST_UPDATE_TIME, + MockDataSource, +) +from test_pipeline_lineage_entity_handler import SAGEMAKER_SESSION_MOCK + +from sagemaker.mlops.feature_store.feature_processor import CSVDataSource +from sagemaker.mlops.feature_store.feature_processor.lineage._s3_lineage_entity_handler import ( + S3LineageEntityHandler, +) +from sagemaker.mlops.feature_store.feature_processor.lineage._transformation_code import ( + TransformationCode, +) +from sagemaker.core.lineage.artifact import Artifact + +raw_data = CSVDataSource( + s3_uri="s3://sagemaker-us-west-2-789975069016/transform-2023-04-28-21-50-14-616/" + "transform-2023-04-28-21-50-14-616/output/model.tar.gz" +) + + +def test_retrieve_raw_data_artifact_when_artifact_already_exist(): + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=raw_data, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=raw_data.s3_uri, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_create_method.assert_not_called() + + +def test_retrieve_raw_data_artifact_when_artifact_does_not_exist(): + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=raw_data, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=raw_data.s3_uri, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=raw_data.s3_uri, + artifact_type="DataSet", + artifact_name="sm-fs-fe-raw-data", + properties=None, + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_retrieve_user_defined_raw_data_artifact_when_artifact_already_exist(): + data_source = MockDataSource() + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=data_source, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + 
artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_create_method.assert_not_called() + + +def test_retrieve_user_defined_raw_data_artifact_when_artifact_does_not_exist(): + data_source = MockDataSource() + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_raw_data_artifact( + raw_data=data_source, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, sagemaker_session=SAGEMAKER_SESSION_MOCK + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=data_source.data_source_unique_id, + artifact_type="DataSet", + artifact_name=data_source.data_source_name, + properties=None, + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_transformation_code_artifact(): + with patch.object( + Artifact, "create", return_value=TRANSFORMATION_CODE_ARTIFACT_1 + ) as artifact_create_method: + + result = S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=TRANSFORMATION_CODE_INPUT_1, + pipeline_last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == TRANSFORMATION_CODE_ARTIFACT_1 + + artifact_create_method.assert_called_once_with( + source_uri=TRANSFORMATION_CODE_INPUT_1.s3_uri, + source_types=[dict(SourceIdType="Custom", Value=LAST_UPDATE_TIME)], + artifact_type="TransformationCode", + artifact_name=f"sm-fs-fe-transformation-code-{LAST_UPDATE_TIME}", + properties=dict( + name=TRANSFORMATION_CODE_INPUT_1.name, + author=TRANSFORMATION_CODE_INPUT_1.author, + state="Active", + inclusive_start_date=LAST_UPDATE_TIME, + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_transformation_code_artifact_when_no_author_or_name(): + transformation_code_input = TransformationCode(s3_uri=TRANSFORMATION_CODE_INPUT_1.s3_uri) + with patch.object( + Artifact, "create", return_value=TRANSFORMATION_CODE_ARTIFACT_1 + ) as artifact_create_method: + + result = S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=transformation_code_input, + pipeline_last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == TRANSFORMATION_CODE_ARTIFACT_1 + + artifact_create_method.assert_called_once_with( + source_uri=TRANSFORMATION_CODE_INPUT_1.s3_uri, + source_types=[dict(SourceIdType="Custom", Value=LAST_UPDATE_TIME)], + artifact_type="TransformationCode", + artifact_name=f"sm-fs-fe-transformation-code-{LAST_UPDATE_TIME}", + properties=dict( + state="Active", + inclusive_start_date=LAST_UPDATE_TIME, + ), + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_create_transformation_code_artifact_when_no_code_provided(): + with patch.object( + Artifact, "create", return_value=TRANSFORMATION_CODE_ARTIFACT_1 + ) as artifact_create_method: + + result = S3LineageEntityHandler.create_transformation_code_artifact( + transformation_code=None, + pipeline_last_update_time=LAST_UPDATE_TIME, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result is None + + 
artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_schedule_artifact_when_artifact_does_not_exist(): + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object(Artifact, "load", return_value=ARTIFACT_RESULT) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=PIPELINE_SCHEDULE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + artifact_type="PipelineSchedule", + artifact_name=f"sm-fs-fe-{PIPELINE_SCHEDULE.schedule_name}", + properties=dict( + pipeline_name=PIPELINE_SCHEDULE.pipeline_name, + schedule_expression=PIPELINE_SCHEDULE.schedule_expression, + state=PIPELINE_SCHEDULE.state, + start_date=PIPELINE_SCHEDULE.start_date, + ), + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_retrieve_pipeline_schedule_artifact_when_artifact_exists(): + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=SCHEDULE_ARTIFACT_RESULT + ) as artifact_load_method: + with patch.object(SCHEDULE_ARTIFACT_RESULT, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=SCHEDULE_ARTIFACT_RESULT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=PIPELINE_SCHEDULE, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == SCHEDULE_ARTIFACT_RESULT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_schedule_artifact_when_artifact_updated(): + schedule_artifact_result = copy.deepcopy(SCHEDULE_ARTIFACT_RESULT) + with patch.object(Artifact, "list", return_value=[ARTIFACT_SUMMARY]) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=schedule_artifact_result + ) as artifact_load_method: + with patch.object(schedule_artifact_result, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=schedule_artifact_result + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_schedule_artifact( + pipeline_schedule=PIPELINE_SCHEDULE_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == schedule_artifact_result + assert schedule_artifact_result != SCHEDULE_ARTIFACT_RESULT + assert result.properties["pipeline_name"] == PIPELINE_SCHEDULE_2.pipeline_name + assert result.properties["schedule_expression"] == PIPELINE_SCHEDULE_2.schedule_expression + assert result.properties["state"] == PIPELINE_SCHEDULE_2.state + assert result.properties["start_date"] == PIPELINE_SCHEDULE_2.start_date + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_SCHEDULE.schedule_arn, + 
sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_trigger_artifact_when_artifact_does_not_exist(): + with patch.object(Artifact, "list", return_value=[]) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_load_method: + with patch.object( + Artifact, "create", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=PIPELINE_TRIGGER, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == PIPELINE_TRIGGER_ARTIFACT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_not_called() + + artifact_create_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + artifact_type="PipelineTrigger", + artifact_name=f"sm-fs-fe-trigger-{PIPELINE_TRIGGER.trigger_name}", + properties=dict( + pipeline_name=PIPELINE_TRIGGER.pipeline_name, + event_pattern=PIPELINE_TRIGGER.event_pattern, + state=PIPELINE_TRIGGER.state, + ), + source_types=None, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_retrieve_pipeline_trigger_artifact_when_artifact_exists(): + with patch.object( + Artifact, "list", return_value=[PIPELINE_TRIGGER_ARTIFACT_SUMMARY] + ) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_load_method: + with patch.object(PIPELINE_TRIGGER_ARTIFACT, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=PIPELINE_TRIGGER_ARTIFACT + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=PIPELINE_TRIGGER, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == PIPELINE_TRIGGER_ARTIFACT + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=PIPELINE_TRIGGER_ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() + + +def test_retrieve_pipeline_trigger_artifact_when_artifact_updated(): + trigger_artifact_result = copy.deepcopy(PIPELINE_TRIGGER_ARTIFACT) + with patch.object( + Artifact, "list", return_value=[PIPELINE_TRIGGER_ARTIFACT_SUMMARY] + ) as artifact_list_method: + with patch.object( + Artifact, "load", return_value=trigger_artifact_result + ) as artifact_load_method: + with patch.object(trigger_artifact_result, "save") as artifact_save_method: + with patch.object( + Artifact, "create", return_value=trigger_artifact_result + ) as artifact_create_method: + result = S3LineageEntityHandler.retrieve_pipeline_trigger_artifact( + pipeline_trigger=PIPELINE_TRIGGER_2, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + assert result == trigger_artifact_result + assert trigger_artifact_result != PIPELINE_TRIGGER_ARTIFACT + assert result.properties["pipeline_name"] == PIPELINE_TRIGGER_2.pipeline_name + assert result.properties["event_pattern"] == 
PIPELINE_TRIGGER_2.event_pattern + assert result.properties["state"] == PIPELINE_TRIGGER_2.state + + artifact_list_method.assert_called_once_with( + source_uri=PIPELINE_TRIGGER.trigger_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_load_method.assert_called_once_with( + artifact_arn=PIPELINE_TRIGGER_ARTIFACT_SUMMARY.artifact_arn, + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + artifact_save_method.assert_called_once_with() + + artifact_create_method.assert_not_called() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_config_uploader.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_config_uploader.py new file mode 100644 index 0000000000..25ded72266 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_config_uploader.py @@ -0,0 +1,317 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +from mock import Mock, patch +from sagemaker.mlops.feature_store.feature_processor._config_uploader import ( + ConfigUploader, +) +from sagemaker.mlops.feature_store.feature_processor._constants import ( + SPARK_JAR_FILES_PATH, + SPARK_FILES_PATH, + SPARK_PY_FILES_PATH, + SAGEMAKER_WHL_FILE_S3_PATH, +) +from sagemaker.core.remote_function.job import ( + _JobSettings, + RUNTIME_SCRIPTS_CHANNEL_NAME, + REMOTE_FUNCTION_WORKSPACE, + SPARK_CONF_CHANNEL_NAME, +) +from sagemaker.core.remote_function.spark_config import SparkConfig +from sagemaker.core.helper.session_helper import Session + + +@pytest.fixture +def sagemaker_session(): + return Mock(Session) + + +@pytest.fixture +def wrapped_func(): + return Mock() + + +@pytest.fixture +def runtime_env_manager(): + mocked_runtime_env_manager = Mock() + mocked_runtime_env_manager.snapshot.return_value = "some_dependency_path" + return mocked_runtime_env_manager + + +def custom_file_filter(): + pass + + +@pytest.fixture +def remote_decorator_config(sagemaker_session): + return Mock( + _JobSettings, + sagemaker_session=sagemaker_session, + s3_root_uri="some_s3_uri", + s3_kms_key="some_kms", + spark_config=SparkConfig(), + dependencies=None, + include_local_workdir=True, + workdir_config=None, + pre_execution_commands="some_commands", + pre_execution_script="some_path", + python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH, + custom_file_filter=None, + ) + + +@pytest.fixture +def config_uploader(remote_decorator_config, runtime_env_manager): + return ConfigUploader(remote_decorator_config, runtime_env_manager) + + +@pytest.fixture +def remote_decorator_config_with_filter(sagemaker_session): + return Mock( + _JobSettings, + sagemaker_session=sagemaker_session, + s3_root_uri="some_s3_uri", + s3_kms_key="some_kms", + spark_config=SparkConfig(), + dependencies=None, + include_local_workdir=True, + pre_execution_commands="some_commands", + pre_execution_script="some_path", + python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH, + custom_file_filter=custom_file_filter, + 
) + + +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.StoredFunction") +def test_prepare_and_upload_callable(mock_stored_function, config_uploader, wrapped_func): + mock_stored_function.save(wrapped_func).return_value = None + config_uploader._prepare_and_upload_callable(wrapped_func, "s3_base_uri", sagemaker_session) + assert mock_stored_function.called_once_with( + s3_base_uri="s3_base_uri", + s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_workspace", + return_value="some_s3_uri", +) +def test_prepare_and_upload_workspace(mock_upload, config_uploader): + remote_decorator_config = config_uploader.remote_decorator_config + s3_path = config_uploader._prepare_and_upload_workspace( + local_dependencies_path="some/path/to/dependency", + include_local_workdir=True, + pre_execution_commands=remote_decorator_config.pre_execution_commands, + pre_execution_script_local_path=remote_decorator_config.pre_execution_script, + s3_base_uri=remote_decorator_config.s3_root_uri, + s3_kms_key=remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + assert s3_path == mock_upload.return_value + mock_upload.assert_called_once_with( + local_dependencies_path="some/path/to/dependency", + include_local_workdir=True, + pre_execution_commands=remote_decorator_config.pre_execution_commands, + pre_execution_script_local_path=remote_decorator_config.pre_execution_script, + s3_base_uri=remote_decorator_config.s3_root_uri, + s3_kms_key=remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + custom_file_filter=None, + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_workspace", + return_value="some_s3_uri", +) +def test_prepare_and_upload_workspace_with_filter( + mock_job_upload, remote_decorator_config_with_filter, runtime_env_manager +): + config_uploader_with_filter = ConfigUploader( + remote_decorator_config=remote_decorator_config_with_filter, + runtime_env_manager=runtime_env_manager, + ) + remote_decorator_config = config_uploader_with_filter.remote_decorator_config + config_uploader_with_filter._prepare_and_upload_workspace( + local_dependencies_path="some/path/to/dependency", + include_local_workdir=True, + pre_execution_commands=remote_decorator_config.pre_execution_commands, + pre_execution_script_local_path=remote_decorator_config.pre_execution_script, + s3_base_uri=remote_decorator_config.s3_root_uri, + s3_kms_key=remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + custom_file_filter=remote_decorator_config_with_filter.custom_file_filter, + ) + + mock_job_upload.assert_called_once_with( + local_dependencies_path="some/path/to/dependency", + include_local_workdir=True, + pre_execution_commands=remote_decorator_config.pre_execution_commands, + pre_execution_script_local_path=remote_decorator_config.pre_execution_script, + s3_base_uri=remote_decorator_config.s3_root_uri, + s3_kms_key=remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + custom_file_filter=custom_file_filter, + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_runtime_scripts", + return_value="some_s3_uri", +) +def test_prepare_and_upload_runtime_scripts(mock_upload, config_uploader): + s3_path = config_uploader._prepare_and_upload_runtime_scripts( + 
spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + assert s3_path == mock_upload.return_value + mock_upload.assert_called_once_with( + spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_spark_dependent_files", + return_value=("path_a", "path_b", "path_c", "path_d"), +) +def test_prepare_and_upload_spark_dependent_files(mock_upload, config_uploader): + s3_paths = config_uploader._prepare_and_upload_spark_dependent_files( + spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + assert s3_paths == mock_upload.return_value + mock_upload.assert_called_once_with( + spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key=config_uploader.remote_decorator_config.s3_kms_key, + sagemaker_session=sagemaker_session, + ) + + +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.Channel") +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.DataSource") +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.S3DataSource") +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_spark_dependent_files", + return_value=("path_a", "path_b", "path_c", "path_d"), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_workspace", + return_value="some_s3_uri", +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader._prepare_and_upload_runtime_scripts", + return_value="some_s3_uri", +) +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.StoredFunction") +def test_prepare_step_input_channel( + mock_upload_callable, + mock_script_upload, + mock_dependency_upload, + mock_spark_dependency_upload, + mock_s3_data_source, + mock_data_source, + mock_channel, + config_uploader, + wrapped_func, +): + ( + input_data_config, + spark_dependency_paths, + ) = config_uploader.prepare_step_input_channel_for_spark_mode( + wrapped_func, + config_uploader.remote_decorator_config.s3_root_uri, + sagemaker_session, + ) + remote_decorator_config = config_uploader.remote_decorator_config + + assert mock_upload_callable.called_once_with(wrapped_func) + + mock_script_upload.assert_called_once_with( + spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key="some_kms", + sagemaker_session=sagemaker_session, + ) + + mock_dependency_upload.assert_called_once_with( + local_dependencies_path="some_dependency_path", + include_local_workdir=True, + pre_execution_commands=remote_decorator_config.pre_execution_commands, + pre_execution_script_local_path=remote_decorator_config.pre_execution_script, + s3_base_uri=remote_decorator_config.s3_root_uri, + s3_kms_key="some_kms", + sagemaker_session=sagemaker_session, + 
custom_file_filter=None, + ) + + mock_spark_dependency_upload.assert_called_once_with( + spark_config=config_uploader.remote_decorator_config.spark_config, + s3_base_uri=config_uploader.remote_decorator_config.s3_root_uri, + s3_kms_key="some_kms", + sagemaker_session=sagemaker_session, + ) + + # Verify input_data_config is a list of Channel objects + assert isinstance(input_data_config, list) + # 3 channels: runtime scripts, workspace, spark conf + assert len(input_data_config) == 3 + + # Verify each channel was constructed with the correct data + # Channel 1: runtime scripts + mock_s3_data_source.assert_any_call( + s3_uri="some_s3_uri", + s3_data_type="S3Prefix", + s3_data_distribution_type="FullyReplicated", + ) + # Channel 2: workspace + mock_s3_data_source.assert_any_call( + s3_uri=f"{config_uploader.remote_decorator_config.s3_root_uri}/sm_rf_user_ws", + s3_data_type="S3Prefix", + s3_data_distribution_type="FullyReplicated", + ) + # Channel 3: spark conf + mock_s3_data_source.assert_any_call( + s3_uri="path_d", + s3_data_type="S3Prefix", + s3_data_distribution_type="FullyReplicated", + ) + + assert mock_s3_data_source.call_count == 3 + assert mock_data_source.call_count == 3 + assert mock_channel.call_count == 3 + + # Verify channel names and input_mode + channel_call_kwargs = [call.kwargs for call in mock_channel.call_args_list] + channel_names = [kw["channel_name"] for kw in channel_call_kwargs] + assert RUNTIME_SCRIPTS_CHANNEL_NAME in channel_names + assert REMOTE_FUNCTION_WORKSPACE in channel_names + assert SPARK_CONF_CHANNEL_NAME in channel_names + for kw in channel_call_kwargs: + assert kw["input_mode"] == "File" + + assert spark_dependency_paths == { + SPARK_JAR_FILES_PATH: "path_a", + SPARK_PY_FILES_PATH: "path_b", + SPARK_FILES_PATH: "path_c", + } diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_helpers.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_helpers.py new file mode 100644 index 0000000000..e69499c5b1 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_helpers.py @@ -0,0 +1,166 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
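+# Shared constants and helpers for the feature_processor unit tests below: sample
+# data sources, a canned DescribeFeatureGroup response, a pipeline definition, and
+# a create_fp_config() helper. Other test modules import this file as `tdh`, e.g.
+# fp_config = tdh.create_fp_config(mode=FeatureProcessorMode.PYSPARK).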
+from __future__ import absolute_import + +import datetime +import json + +from dateutil.tz import tzlocal +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + +INPUT_S3_URI = "s3://bucket/prefix/" +INPUT_FEATURE_GROUP_NAME = "input-fg" +INPUT_FEATURE_GROUP_ARN = "arn:aws:sagemaker:us-west-2:12345789012:feature-group/input-fg" +INPUT_FEATURE_GROUP_S3_URI = "s3://bucket/input-fg/" +INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI = ( + "s3://bucket/input-fg/feature-store/12345789012/" + "sagemaker/us-west-2/offline-store/input-fg-12345/data" +) + +FEATURE_GROUP_DATA_SOURCE = FeatureGroupDataSource(name=INPUT_FEATURE_GROUP_ARN) +S3_DATA_SOURCE = CSVDataSource(s3_uri=INPUT_S3_URI) +FEATURE_PROCESSOR_INPUTS = [FEATURE_GROUP_DATA_SOURCE, S3_DATA_SOURCE] +OUTPUT_FEATURE_GROUP_ARN = "arn:aws:sagemaker:us-west-2:12345789012:feature-group/output-fg" + +FEATURE_GROUP_SYSTEM_PARAMS = { + "feature_group_name": "input-fg", + "online_store_enabled": True, + "offline_store_enabled": False, + "offline_store_resolved_s3_uri": None, +} +SYSTEM_PARAMS = {"system": {"scheduled_time": "2023-03-25T02:01:26Z"}} +USER_INPUT_PARAMS = { + "some-key": "some-value", + "some-other-key": {"some-key": "some-value"}, +} + +DATA_SOURCE_UNIQUE_ID_TOO_LONG = """ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\ +""" + +DESCRIBE_FEATURE_GROUP_RESPONSE = { + "FeatureGroupArn": INPUT_FEATURE_GROUP_ARN, + "FeatureGroupName": INPUT_FEATURE_GROUP_NAME, + "RecordIdentifierFeatureName": "id", + "EventTimeFeatureName": "ingest_time", + "FeatureDefinitions": [ + {"FeatureName": "id", "FeatureType": "String"}, + {"FeatureName": "model", "FeatureType": "String"}, + {"FeatureName": "model_year", "FeatureType": "String"}, + {"FeatureName": "status", "FeatureType": "String"}, + {"FeatureName": "mileage", "FeatureType": "String"}, + {"FeatureName": "price", "FeatureType": "String"}, + {"FeatureName": "msrp", "FeatureType": "String"}, + {"FeatureName": "ingest_time", "FeatureType": "Fractional"}, + ], + "CreationTime": datetime.datetime(2023, 3, 29, 19, 15, 47, 20000, tzinfo=tzlocal()), + "OnlineStoreConfig": {"EnableOnlineStore": True}, + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": INPUT_FEATURE_GROUP_S3_URI, + "ResolvedOutputS3Uri": INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI, + }, + "DisableGlueTableCreation": False, + "DataCatalogConfig": { + "TableName": "input_fg_1680142547", + "Catalog": "AwsDataCatalog", + "Database": "sagemaker_featurestore", + }, + }, + "RoleArn": "arn:aws:iam::12345789012:role/role-name", + "FeatureGroupStatus": "Created", + "OnlineStoreTotalSizeBytes": 12345, + "ResponseMetadata": { + 
"RequestId": "d36d3647-1632-4f4e-9f7c-2a4e38e4c6f8", + "HTTPStatusCode": 200, + "HTTPHeaders": { + "x-amzn-requestid": "d36d3647-1632-4f4e-9f7c-2a4e38e4c6f8", + "content-type": "application/x-amz-json-1.1", + "content-length": "1311", + "date": "Fri, 31 Mar 2023 01:05:49 GMT", + }, + "RetryAttempts": 0, + }, +} + +PIPELINE = { + "PipelineArn": "some_pipeline_arn", + "RoleArn": "some_execution_role_arn", + "CreationTime": datetime.datetime(2023, 3, 29, 19, 15, 47, 20000, tzinfo=tzlocal()), + "PipelineDefinition": json.dumps( + { + "Steps": [ + { + "RetryPolicies": [ + { + "BackoffRate": 2.0, + "IntervalSeconds": 1, + "MaxAttempts": 5, + "ExceptionType": ["Step.SERVICE_FAULT", "Step.THROTTLING"], + }, + { + "BackoffRate": 2.0, + "IntervalSeconds": 1, + "MaxAttempts": 5, + "ExceptionType": [ + "SageMaker.JOB_INTERNAL_ERROR", + "SageMaker.CAPACITY_ERROR", + "SageMaker.RESOURCE_LIMIT", + ], + }, + ] + } + ] + } + ), +} + + +def create_fp_config( + inputs=None, + output=OUTPUT_FEATURE_GROUP_ARN, + mode=FeatureProcessorMode.PYSPARK, + target_stores=None, + enable_ingestion=True, + parameters=None, + spark_config=None, +): + """Helper method to create a FeatureProcessorConfig with fewer arguments.""" + + return FeatureProcessorConfig.create( + inputs=inputs or FEATURE_PROCESSOR_INPUTS, + output=output, + mode=mode, + target_stores=target_stores, + enable_ingestion=enable_ingestion, + parameters=parameters, + spark_config=spark_config, + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_source.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_source.py new file mode 100644 index 0000000000..51637fa979 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_data_source.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from pyspark.sql import DataFrame + +from sagemaker.mlops.feature_store.feature_processor._data_source import PySparkDataSource + + +def test_pyspark_data_source(): + class TestDataSource(PySparkDataSource): + + data_source_unique_id = "test_unique_id" + data_source_name = "test_source_name" + + def read_data(self, spark, params) -> DataFrame: + return None + + test_data_source = TestDataSource() + + assert test_data_source.data_source_name == "test_source_name" + assert test_data_source.data_source_unique_id == "test_unique_id" + assert test_data_source.read_data(spark=None, params=None) is None diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_env.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_env.py new file mode 100644 index 0000000000..4aff330087 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_env.py @@ -0,0 +1,122 @@ +# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json + +from mock import mock_open, patch +import pytest +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper + +SINGLE_NODE_RESOURCE_CONFIG = { + "current_host": "algo-1", + "current_instance_type": "ml.m5.xlarge", + "current_group_name": "homogeneousCluster", + "hosts": ["algo-1"], + "instance_groups": [ + { + "instance_group_name": "homogeneousCluster", + "instance_type": "ml.m5.xlarge", + "hosts": ["algo-1"], + } + ], + "network_interface_name": "eth0", +} +MULTI_NODE_COUNT = 3 +MULTI_NODE_RESOURCE_CONFIG = { + "current_host": "algo-1", + "current_instance_type": "ml.m5.xlarge", + "current_group_name": "homogeneousCluster", + "hosts": ["algo-1", "algo-2", "algo-3"], + "instance_groups": [ + { + "instance_group_name": "homogeneousCluster", + "instance_type": "ml.m5.xlarge", + "hosts": ["algo-1"], + }, + { + "instance_group_name": "homogeneousCluster", + "instance_type": "ml.m5.xlarge", + "hosts": ["algo-2"], + }, + { + "instance_group_name": "homogeneousCluster", + "instance_type": "ml.m5.xlarge", + "hosts": ["algo-3"], + }, + ], + "network_interface_name": "eth0", +} + + +@patch("builtins.open") +def test_is_training_job(mocked_open): + mocked_open.side_effect = mock_open(read_data=json.dumps(SINGLE_NODE_RESOURCE_CONFIG)) + + assert EnvironmentHelper().is_training_job() is True + + mocked_open.assert_called_once_with("/opt/ml/input/config/resourceconfig.json", "r") + + +@patch("builtins.open") +def test_is_not_training_job(mocked_open): + mocked_open.side_effect = FileNotFoundError() + + assert EnvironmentHelper().is_training_job() is False + + +@patch("builtins.open") +def test_get_instance_count_single_node(mocked_open): + mocked_open.side_effect = mock_open(read_data=json.dumps(SINGLE_NODE_RESOURCE_CONFIG)) + + assert EnvironmentHelper().get_instance_count() == 1 + + +@patch("builtins.open") +def test_get_instance_count_multi_node(mocked_open): + mocked_open.side_effect = mock_open(read_data=json.dumps(MULTI_NODE_RESOURCE_CONFIG)) + + assert EnvironmentHelper().get_instance_count() == MULTI_NODE_COUNT + + +@patch("builtins.open") +def test_load_training_resource_config(mocked_open): + mocked_open.side_effect = mock_open(read_data=json.dumps(SINGLE_NODE_RESOURCE_CONFIG)) + + assert EnvironmentHelper().load_training_resource_config() == SINGLE_NODE_RESOURCE_CONFIG + + +@patch("builtins.open") +def test_load_training_resource_config_none(mocked_open): + mocked_open.side_effect = FileNotFoundError() + + assert EnvironmentHelper().load_training_resource_config() is None + + +@pytest.mark.parametrize( + "is_training_result", + [(True), (False)], +) +@patch("datetime.now.strftime", return_value="test_current_time") +@patch("sagemaker.mlops.feature_store.feature_processor._env.EnvironmentHelper.is_training_job") +@patch("os.environ", return_value={"scheduled_time": "test_time"}) +def get_job_scheduled_time(mock_env, mock_is_training, 
mock_datetime, is_training_result): + + mock_is_training.return_value = is_training_result + output_time = EnvironmentHelper().get_job_scheduled_time + + if is_training_result: + assert output_time == "test_scheduled_time" + else: + assert output_time == "test_current_time" diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_rule_helper.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_rule_helper.py new file mode 100644 index 0000000000..f44472d519 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_rule_helper.py @@ -0,0 +1,301 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from mock import Mock +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper import ( + EventBridgeRuleHelper, +) +from botocore.exceptions import ClientError +from sagemaker.mlops.feature_store.feature_processor._feature_processor_pipeline_events import ( + FeatureProcessorPipelineEvents, + FeatureProcessorPipelineExecutionStatus, +) +import pytest + + +@pytest.fixture +def sagemaker_session(): + boto_session = Mock() + boto_session.client("events").return_value = Mock() + return Mock(Session, boto_session=boto_session, sagemaker_client=Mock()) + + +@pytest.fixture +def event_bridge_rule_helper(sagemaker_session): + return EventBridgeRuleHelper(sagemaker_session, sagemaker_session.boto_session.client("events")) + + +def test_put_rule_without_event_pattern(event_bridge_rule_helper): + source_pipeline_events = [ + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED], + ) + ] + + event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock( + return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name") + ) + event_bridge_rule_helper.event_bridge_rule_client.put_rule = Mock( + return_value=dict(RuleArn="rule_arn") + ) + event_bridge_rule_helper.put_rule( + source_pipeline_events=source_pipeline_events, + target_pipeline="target_pipeline", + event_pattern=None, + state="Disabled", + ) + + event_bridge_rule_helper.event_bridge_rule_client.put_rule.assert_called_with( + Name="target_pipeline", + EventPattern=( + '{"detail-type": ["SageMaker Model Building Pipeline Execution Status Change"], ' + '"source": ["aws.sagemaker"], "detail": {"currentPipelineExecutionStatus": ' + '["Succeeded"], "pipelineArn": ["pipeline_arn"]}}' + ), + State="Disabled", + ) + + +def test_put_rule_with_event_pattern(event_bridge_rule_helper): + source_pipeline_events = [ + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED], + ) + ] + + 
event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock( + return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name") + ) + event_bridge_rule_helper.event_bridge_rule_client.put_rule = Mock( + return_value=dict(RuleArn="rule_arn") + ) + event_bridge_rule_helper.put_rule( + source_pipeline_events=source_pipeline_events, + target_pipeline="target_pipeline", + event_pattern="event_pattern", + state="Disabled", + ) + + event_bridge_rule_helper.event_bridge_rule_client.put_rule.assert_called_with( + Name="target_pipeline", + EventPattern="event_pattern", + State="Disabled", + ) + + +def test_put_targets_success(event_bridge_rule_helper): + event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock( + return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name") + ) + event_bridge_rule_helper.event_bridge_rule_client.put_targets = Mock( + return_value=dict(FailedEntryCount=0) + ) + event_bridge_rule_helper.put_target( + rule_name="rule_name", + target_pipeline="target_pipeline", + target_pipeline_parameters={"param": "value"}, + role_arn="role_arn", + ) + + event_bridge_rule_helper.event_bridge_rule_client.put_targets.assert_called_with( + Rule="rule_name", + Targets=[ + { + "Id": "pipeline_name", + "Arn": "pipeline_arn", + "RoleArn": "role_arn", + "SageMakerPipelineParameters": {"PipelineParameterList": {"param": "value"}}, + } + ], + ) + + +def test_put_targets_failure(event_bridge_rule_helper): + event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock( + return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name") + ) + event_bridge_rule_helper.event_bridge_rule_client.put_targets = Mock( + return_value=dict( + FailedEntryCount=1, + FailedEntries=[dict(ErrorMessage="test_error_message")], + ) + ) + with pytest.raises( + Exception, match="Failed to add target pipeline to rule. 
Failure reason: test_error_message" + ): + event_bridge_rule_helper.put_target( + rule_name="rule_name", + target_pipeline="target_pipeline", + target_pipeline_parameters={"param": "value"}, + role_arn="role_arn", + ) + + +def test_delete_rule(event_bridge_rule_helper): + event_bridge_rule_helper.event_bridge_rule_client.delete_rule = Mock() + event_bridge_rule_helper.delete_rule("rule_name") + + event_bridge_rule_helper.event_bridge_rule_client.delete_rule.assert_called_with( + Name="rule_name" + ) + + +def test_describe_rule_success(event_bridge_rule_helper): + mock_describe_response = dict(State="ENABLED", RuleName="rule_name") + event_bridge_rule_helper.event_bridge_rule_client.describe_rule = Mock( + return_value=mock_describe_response + ) + assert event_bridge_rule_helper.describe_rule("rule_name") == mock_describe_response + + +def test_describe_rule_non_existent(event_bridge_rule_helper): + mock_describe_response = dict(State="ENABLED", RuleName="rule_name") + event_bridge_rule_helper.event_bridge_rule_client.describe_rule = Mock( + return_value=mock_describe_response, + side_effect=ClientError( + error_response={"Error": {"Code": "ResourceNotFoundException"}}, + operation_name="describe_rule", + ), + ) + assert event_bridge_rule_helper.describe_rule("rule_name") is None + + +def test_remove_targets(event_bridge_rule_helper): + event_bridge_rule_helper.event_bridge_rule_client.remove_targets = Mock() + event_bridge_rule_helper.remove_targets(rule_name="rule_name", ids=["target_pipeline"]) + event_bridge_rule_helper.event_bridge_rule_client.remove_targets.assert_called_with( + Rule="rule_name", + Ids=["target_pipeline"], + ) + + +def test_enable_rule(event_bridge_rule_helper): + event_bridge_rule_helper.event_bridge_rule_client.enable_rule = Mock() + event_bridge_rule_helper.enable_rule("rule_name") + + event_bridge_rule_helper.event_bridge_rule_client.enable_rule.assert_called_with( + Name="rule_name" + ) + + +def test_disable_rule(event_bridge_rule_helper): + event_bridge_rule_helper.event_bridge_rule_client.disable_rule = Mock() + event_bridge_rule_helper.disable_rule("rule_name") + + event_bridge_rule_helper.event_bridge_rule_client.disable_rule.assert_called_with( + Name="rule_name" + ) + + +def test_add_tags(event_bridge_rule_helper): + event_bridge_rule_helper.event_bridge_rule_client.tag_resource = Mock() + event_bridge_rule_helper.add_tags("rule_arn", [{"key": "value"}]) + + event_bridge_rule_helper.event_bridge_rule_client.tag_resource.assert_called_with( + ResourceARN="rule_arn", Tags=[{"key": "value"}] + ) + + +def test_generate_event_pattern_from_feature_processor_pipeline_events(event_bridge_rule_helper): + event_bridge_rule_helper._generate_pipeline_arn_and_name = Mock( + return_value=dict(pipeline_arn="pipeline_arn", pipeline_name="pipeline_name") + ) + event_pattern = ( + event_bridge_rule_helper._generate_event_pattern_from_feature_processor_pipeline_events( + [ + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_1", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED], + ), + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_2", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED], + ), + ] + ) + ) + + assert ( + event_pattern + == '{"detail-type": ["SageMaker Model Building Pipeline Execution Status Change"], ' + '"$or": [{"source": ["aws.sagemaker"], "detail": {"currentPipelineExecutionStatus": ' + '["Failed"], "pipelineArn": ["pipeline_arn"]}}, {"source": ["aws.sagemaker"], "detail": ' 
+ '{"currentPipelineExecutionStatus": ["Failed"], "pipelineArn": ["pipeline_arn"]}}]}' + ) + + +def test_validate_feature_processor_pipeline_events(event_bridge_rule_helper): + events = [ + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_1", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED], + ), + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_1", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED], + ), + ] + + with pytest.raises(ValueError, match="Pipeline names in pipeline_events must be unique."): + event_bridge_rule_helper._validate_feature_processor_pipeline_events(events) + + +def test_aggregate_pipeline_events_with_same_desired_status(event_bridge_rule_helper): + events = [ + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_1", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED], + ), + FeatureProcessorPipelineEvents( + pipeline_name="test_pipeline_2", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.FAILED], + ), + ] + + assert event_bridge_rule_helper._aggregate_pipeline_events_with_same_desired_status(events) == { + (FeatureProcessorPipelineExecutionStatus.FAILED,): [ + "test_pipeline_1", + "test_pipeline_2", + ] + } + + +@pytest.mark.parametrize( + "pipeline_uri,expected_result", + [ + ( + "arn:aws:sagemaker:us-west-2:123456789012:pipeline/test-pipeline", + dict( + pipeline_arn="arn:aws:sagemaker:us-west-2:123456789012:pipeline/test-pipeline", + pipeline_name="test-pipeline", + ), + ), + ( + "test-pipeline", + dict( + pipeline_arn="test-pipeline-arn", + pipeline_name="test-pipeline", + ), + ), + ], +) +def test_generate_pipeline_arn_and_name(event_bridge_rule_helper, pipeline_uri, expected_result): + event_bridge_rule_helper.sagemaker_session.sagemaker_client.describe_pipeline = Mock( + return_value=dict(PipelineArn="test-pipeline-arn") + ) + assert event_bridge_rule_helper._generate_pipeline_arn_and_name(pipeline_uri) == expected_result diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_scheduler_helper.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_scheduler_helper.py new file mode 100644 index 0000000000..baca05e84d --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_event_bridge_scheduler_helper.py @@ -0,0 +1,96 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +from datetime import datetime +import pytest +from botocore.exceptions import ClientError + +from sagemaker.mlops.feature_store.feature_processor._event_bridge_scheduler_helper import ( + EventBridgeSchedulerHelper, +) +from mock import Mock + +from sagemaker.core.helper.session_helper import Session + +SCHEDULE_NAME = "test_schedule" +SCHEDULE_ARN = "test_schedule_arn" +NEW_SCHEDULE_ARN = "test_new_schedule_arn" +TARGET_ARN = "test_arn" +CRON_SCHEDULE = "test_cron" +STATE = "ENABLED" +ROLE = "test_role" +START_DATE = datetime.now() + + +@pytest.fixture +def sagemaker_session(): + boto_session = Mock() + boto_session.client("scheduler").return_value = Mock() + return Mock(Session, boto_session=boto_session) + + +@pytest.fixture +def event_bridge_scheduler_helper(sagemaker_session): + return EventBridgeSchedulerHelper( + sagemaker_session, sagemaker_session.boto_session.client("scheduler") + ) + + +def test_upsert_schedule_already_exists(event_bridge_scheduler_helper): + event_bridge_scheduler_helper.event_bridge_scheduler_client.update_schedule.return_value = ( + SCHEDULE_ARN + ) + schedule_arn = event_bridge_scheduler_helper.upsert_schedule( + schedule_name=SCHEDULE_NAME, + pipeline_arn=TARGET_ARN, + schedule_expression=CRON_SCHEDULE, + state=STATE, + start_date=START_DATE, + role=ROLE, + ) + assert schedule_arn == SCHEDULE_ARN + event_bridge_scheduler_helper.event_bridge_scheduler_client.create_schedule.assert_not_called() + + +def test_upsert_schedule_not_exists(event_bridge_scheduler_helper): + event_bridge_scheduler_helper.event_bridge_scheduler_client.update_schedule.side_effect = ( + ClientError( + error_response={"Error": {"Code": "ResourceNotFoundException"}}, + operation_name="update_schedule", + ) + ) + event_bridge_scheduler_helper.event_bridge_scheduler_client.create_schedule.return_value = ( + NEW_SCHEDULE_ARN + ) + + schedule_arn = event_bridge_scheduler_helper.upsert_schedule( + schedule_name=SCHEDULE_NAME, + pipeline_arn=TARGET_ARN, + schedule_expression=CRON_SCHEDULE, + state=STATE, + start_date=START_DATE, + role=ROLE, + ) + assert schedule_arn == NEW_SCHEDULE_ARN + event_bridge_scheduler_helper.event_bridge_scheduler_client.create_schedule.assert_called_once() + + +def test_delete_schedule(event_bridge_scheduler_helper): + event_bridge_scheduler_helper.sagemaker_session.boto_session = Mock() + event_bridge_scheduler_helper.sagemaker_session.sagemaker_client = Mock() + event_bridge_scheduler_helper.delete_schedule(schedule_name=TARGET_ARN) + event_bridge_scheduler_helper.event_bridge_scheduler_client.delete_schedule.assert_called_with( + Name=TARGET_ARN + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_factory.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_factory.py new file mode 100644 index 0000000000..27a3b96231 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_factory.py @@ -0,0 +1,75 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import test_data_helpers as tdh +from mock import Mock, patch + +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._factory import ( + UDFWrapperFactory, + ValidatorFactory, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper +from sagemaker.mlops.feature_store.feature_processor._validation import ( + FeatureProcessorArgValidator, + InputValidator, + SparkUDFSignatureValidator, + InputOffsetValidator, + BaseDataSourceValidator, +) +from sagemaker.core.helper.session_helper import Session + + +def test_get_validation_chain(): + fp_config = tdh.create_fp_config(mode=FeatureProcessorMode.PYSPARK) + result = ValidatorFactory.get_validation_chain(fp_config) + + assert result.validators is not None + assert { + InputValidator, + FeatureProcessorArgValidator, + InputOffsetValidator, + BaseDataSourceValidator, + SparkUDFSignatureValidator, + } == {type(instance) for instance in result.validators} + + +def test_get_udf_wrapper(): + fp_config = tdh.create_fp_config(mode=FeatureProcessorMode.PYSPARK) + udf_wrapper = Mock(UDFWrapper) + + with patch.object( + UDFWrapperFactory, "_get_spark_udf_wrapper", return_value=udf_wrapper + ) as get_udf_wrapper_method: + result = UDFWrapperFactory.get_udf_wrapper(fp_config) + + assert result == udf_wrapper + get_udf_wrapper_method.assert_called_with(fp_config) + + +def test_get_udf_wrapper_invalid_mode(): + fp_config = Mock(FeatureProcessorConfig) + fp_config.mode = FeatureProcessorMode.PYTHON + fp_config.sagemaker_session = Mock(Session) + + with pytest.raises( + ValueError, + match=r"FeatureProcessorMode FeatureProcessorMode.PYTHON is not supported\.", + ): + UDFWrapperFactory.get_udf_wrapper(fp_config) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor.py new file mode 100644 index 0000000000..6f6c8471aa --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor.py @@ -0,0 +1,122 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
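+# Tests for the feature_processor decorator: it builds a FeatureProcessorConfig,
+# runs the validator chain, and returns the UDF wrapped by the UDFWrapper; a
+# validation error must propagate to the caller before any wrapping occurs.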
+from __future__ import absolute_import + +from typing import Callable + +import pytest +import test_data_helpers as tdh +from mock import Mock, patch + +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._factory import ( + UDFWrapperFactory, + ValidatorFactory, +) +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper +from sagemaker.mlops.feature_store.feature_processor._validation import ValidatorChain +from sagemaker.mlops.feature_store.feature_processor.feature_processor import ( + feature_processor, +) + + +@pytest.fixture +def udf(): + return Mock(Callable) + + +@pytest.fixture +def wrapped_udf(): + return Mock() + + +@pytest.fixture +def udf_wrapper(wrapped_udf): + mock = Mock(UDFWrapper) + mock.wrap.return_value = wrapped_udf + return mock + + +@pytest.fixture +def validator_chain(): + return Mock(ValidatorChain) + + +@pytest.fixture +def fp_config(): + mock = Mock(FeatureProcessorConfig) + mock.mode = FeatureProcessorMode.PYSPARK + return mock + + +def test_feature_processor(udf, udf_wrapper, validator_chain, fp_config, wrapped_udf): + with patch.object( + FeatureProcessorConfig, "create", return_value=fp_config + ) as fp_config_create_method: + with patch.object( + UDFWrapperFactory, "get_udf_wrapper", return_value=udf_wrapper + ) as get_udf_wrapper: + with patch.object( + ValidatorFactory, + "get_validation_chain", + return_value=validator_chain, + ) as get_validation_chain: + decorated_udf = feature_processor( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE], + output="", + )(udf) + + assert decorated_udf == wrapped_udf + + fp_config_create_method.assert_called() + get_udf_wrapper.assert_called_with(fp_config) + get_validation_chain.assert_called() + + validator_chain.validate.assert_called_with(fp_config=fp_config, udf=udf) + udf_wrapper.wrap.assert_called_with(fp_config=fp_config, udf=udf) + + assert decorated_udf.feature_processor_config == fp_config + + +def test_feature_processor_validation_fails(udf, udf_wrapper, validator_chain, fp_config): + with patch.object( + FeatureProcessorConfig, "create", return_value=fp_config + ) as fp_config_create_method: + with patch.object( + UDFWrapperFactory, "get_udf_wrapper", return_value=udf_wrapper + ) as get_udf_wrapper: + with patch.object( + ValidatorFactory, + "get_validation_chain", + return_value=validator_chain, + ) as get_validation_chain: + validator_chain.validate.side_effect = ValueError() + + # Verify validation error is raised to user. + with pytest.raises(ValueError): + feature_processor( + inputs=tdh.FEATURE_PROCESSOR_INPUTS, + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + )(udf) + + # Verify validation failure causes execution to terminate early. + # Verify FeatureProcessorConfig interactions. 
+ fp_config_create_method.assert_called() + get_udf_wrapper.assert_called_once() + get_validation_chain.assert_called_once() + validator_chain.validate.assert_called_with(fp_config=fp_config, udf=udf) + udf_wrapper.wrap.assert_not_called() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_config.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_config.py new file mode 100644 index 0000000000..2914bbb63a --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_config.py @@ -0,0 +1,46 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import attr +import pytest +import test_data_helpers as tdh + +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) + + +def test_feature_processor_config_is_immutable(): + fp_config = FeatureProcessorConfig.create( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + mode=FeatureProcessorMode.PYSPARK, + target_stores=None, + enable_ingestion=True, + parameters=None, + spark_config=None, + ) + + with pytest.raises(attr.exceptions.FrozenInstanceError): + # Only attempting one field, as FrozenInstanceError indicates all fields are frozen + # (as opposed to FrozenAttributeError). + fp_config.inputs = [] + + with pytest.raises( + TypeError, + match="'FeatureProcessorConfig' object does not support item assignment", + ): + fp_config["inputs"] = [] diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_pipeline_events.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_pipeline_events.py new file mode 100644 index 0000000000..78c506313a --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_processor_pipeline_events.py @@ -0,0 +1,30 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
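+# Tests for the FeatureProcessorPipelineEvents data class, which describes the
+# source pipeline events used for EventBridge-based pipeline triggers.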
+from __future__ import absolute_import + +from sagemaker.mlops.feature_store.feature_processor import ( + FeatureProcessorPipelineEvents, + FeatureProcessorPipelineExecutionStatus, +) + + +def test_feature_processor_pipeline_events(): + fe_pipeline_events = FeatureProcessorPipelineEvents( + pipeline_name="pipeline_name", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.EXECUTING], + ) + assert fe_pipeline_events.pipeline_name == "pipeline_name" + assert fe_pipeline_events.pipeline_execution_status == [ + FeatureProcessorPipelineExecutionStatus.EXECUTING + ] diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py new file mode 100644 index 0000000000..1cd7e381e0 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_feature_scheduler.py @@ -0,0 +1,1057 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from datetime import datetime +from typing import Callable + +import pytest +import json +from botocore.exceptions import ClientError +from mock import Mock, patch, call + +from sagemaker.mlops.feature_store.feature_processor.feature_scheduler import ( + FeatureProcessorLineageHandler, +) +from sagemaker.mlops.feature_store.feature_processor import ( + FeatureProcessorPipelineEvents, + FeatureProcessorPipelineExecutionStatus, +) +from sagemaker.core.lineage.context import Context +from sagemaker.core.remote_function.spark_config import SparkConfig + +from sagemaker.core.helper.session_helper import Session +from sagemaker.mlops.feature_store.feature_processor._enums import FeatureProcessorMode +from sagemaker.mlops.feature_store.feature_processor._constants import ( + FEATURE_PROCESSOR_TAG_KEY, + FEATURE_PROCESSOR_TAG_VALUE, + EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT, + PIPELINE_NAME_MAXIMUM_LENGTH, +) +from sagemaker.mlops.feature_store.feature_processor.feature_scheduler import ( + schedule, + to_pipeline, + execute, + delete_schedule, + describe, + list_pipelines, + put_trigger, + enable_trigger, + disable_trigger, + delete_trigger, + _validate_fg_lineage_resources, + _validate_pipeline_lineage_resources, +) +from sagemaker.core.remote_function.job import ( + _JobSettings, + SPARK_APP_SCRIPT_PATH, + RUNTIME_SCRIPTS_CHANNEL_NAME, + REMOTE_FUNCTION_WORKSPACE, + ENTRYPOINT_SCRIPT_NAME, + SPARK_CONF_CHANNEL_NAME, +) +from sagemaker.core.workflow.parameters import Parameter, ParameterTypeEnum +from sagemaker.mlops.workflow.retry import ( + StepRetryPolicy, + StepExceptionTypeEnum, + SageMakerJobStepRetryPolicy, + SageMakerJobExceptionTypeEnum, +) +import test_data_helpers as tdh + +REGION = "us-west-2" +IMAGE = "image_uri" +BUCKET = "my-s3-bucket" +DEFAULT_BUCKET_PREFIX = "default_bucket_prefix" +S3_URI = f"s3://{BUCKET}/keyprefix" +DEFAULT_IMAGE = ( + 
"153931337802.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark-processing:3.2-cpu-py39-v1.1" +) +PIPELINE_ARN = "pipeline_arn" +SCHEDULE_ARN = "schedule_arn" +SCHEDULE_ROLE_ARN = "my_schedule_role_arn" +EXECUTION_ROLE_ARN = "my_execution_role_arn" +EVENT_BRIDGE_RULE_ARN = "arn:aws:events:us-west-2:123456789012:rule/test-rule" +VALID_SCHEDULE_STATE = "ENABLED" +INVALID_SCHEDULE_STATE = "invalid" +TEST_REGION = "us-west-2" +PIPELINE_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-context-name" +PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-version-context-name" +NOW = datetime.now() +SAGEMAKER_SESSION_MOCK = Mock(Session) +CONTEXT_MOCK_01 = Mock(Context) +CONTEXT_MOCK_02 = Mock(Context) +CONTEXT_MOCK_03 = Mock(Context) +FEATURE_GROUP = tdh.DESCRIBE_FEATURE_GROUP_RESPONSE.copy() +PIPELINE = tdh.PIPELINE.copy() +TAGS = [dict(Key="key_1", Value="value_1"), dict(Key="key_2", Value="value_2")] + + +def mock_session(): + session = Mock() + session.default_bucket.return_value = BUCKET + session.default_bucket_prefix = DEFAULT_BUCKET_PREFIX + session.expand_role.return_value = EXECUTION_ROLE_ARN + session.boto_region_name = TEST_REGION + session.sagemaker_config = None + session._append_sagemaker_config_tags.return_value = [] + session.default_bucket_prefix = None + session.sagemaker_client = Mock() + return session + + +def mock_pipeline(): + pipeline = Mock() + pipeline.describe.return_value = {"PipelineArn": PIPELINE_ARN} + pipeline.upsert.return_value = None + return pipeline + + +def mock_event_bridge_scheduler_helper(): + helper = Mock() + helper.upsert_schedule.return_value = dict(ScheduleArn=SCHEDULE_ARN) + helper.delete_schedule.return_value = None + helper.describe_schedule.return_value = { + "Arn": "some_schedule_arn", + "ScheduleExpression": "some_schedule_expression", + "StartDate": NOW, + "State": VALID_SCHEDULE_STATE, + "Target": {"Arn": "some_pipeline_arn", "RoleArn": "some_schedule_role_arn"}, + } + return helper + + +def mock_event_bridge_rule_helper(): + helper = Mock() + helper.describe_rule.return_value = { + "Arn": "some_rule_arn", + "EventPattern": "some_event_pattern", + "State": "ENABLED", + } + return helper + + +def mock_feature_processor_lineage(): + return Mock(FeatureProcessorLineageHandler) + + +@pytest.fixture +def job_function(): + return Mock(Callable) + + +@pytest.fixture +def config_uploader(): + uploader = Mock() + uploader.return_value = "some_s3_uri" + uploader.prepare_and_upload_runtime_scripts.return_value = "some_s3_uri" + return uploader + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler._validate_fg_lineage_resources", + return_value=None, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.Pipeline", + return_value=mock_pipeline(), +) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.mlops.feature_store.feature_processor._config_uploader.TrainingInput") +@patch("sagemaker.mlops.feature_store.feature_processor.feature_scheduler.TrainingStep") +@patch("sagemaker.mlops.feature_store.feature_processor.feature_scheduler.ModelTrainer") +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader" + "._prepare_and_upload_spark_dependent_files", + return_value=("path_a", "path_b", "path_c", "path_d"), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader._prepare_and_upload_workspace", + 
return_value="some_s3_uri", +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader._prepare_and_upload_runtime_scripts", + return_value="some_s3_uri", +) +@patch("sagemaker.mlops.feature_store.feature_processor.feature_scheduler.RuntimeEnvironmentManager") +@patch( + "sagemaker.mlops.feature_store.feature_processor._config_uploader.ConfigUploader._prepare_and_upload_callable" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.lineage." + "_feature_processor_lineage.FeatureProcessorLineageHandler.create_lineage" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.lineage." + "_feature_processor_lineage.FeatureProcessorLineageHandler.get_pipeline_lineage_names", + return_value=dict( + pipeline_context_name="pipeline-context-name", + pipeline_version_context_name="pipeline-version-context-name", + ), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.PipelineSession" +) +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.expand_role", side_effect=lambda session, role: role) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline( + get_execution_role, + expand_role, + session, + mock_pipeline_session, + mock_get_pipeline_lineage_names, + mock_create_lineage, + mock_upload_callable, + mock_runtime_manager, + mock_script_upload, + mock_dependency_upload, + mock_spark_dependency_upload, + mock_model_trainer, + mock_training_step, + mock_training_input, + mock_spark_image, + pipeline, + lineage_validator, +): + session.sagemaker_config = None + session.boto_region_name = TEST_REGION + session.expand_role.return_value = EXECUTION_ROLE_ARN + + # Configure RuntimeEnvironmentManager mock to return proper string values + mock_runtime_manager_instance = mock_runtime_manager.return_value + mock_runtime_manager_instance._current_python_version.return_value = "3.10" + mock_runtime_manager_instance.snapshot.return_value = "/tmp/snapshot" + + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + sagemaker_session=session, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", session.boto_region_name]) + + mock_feature_processor_config = Mock( + mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg" + ) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK + + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + pipeline_arn = to_pipeline( + pipeline_name="pipeline_name", + 
step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + tags=[("tag_key_1", "tag_value_1"), ("tag_key_2", "tag_value_2")], + sagemaker_session=session, + ) + assert pipeline_arn == PIPELINE_ARN + + assert mock_upload_callable.called_once_with(job_function) + + mock_script_upload.assert_called_once_with( + spark_config, + f"{S3_URI}/pipeline_name", + None, + session, + ) + + mock_dependency_upload.assert_called_once_with( + "/tmp/snapshot", + True, + None, + None, + f"{S3_URI}/pipeline_name", + None, + session, + None, + ) + + mock_spark_dependency_upload.assert_called_once_with( + spark_config, + f"{S3_URI}/pipeline_name", + None, + session, + ) + + mock_model_trainer.assert_called_once() + # Verify ModelTrainer was configured correctly + model_trainer_call_kwargs = mock_model_trainer.call_args[1] + assert model_trainer_call_kwargs["training_image"] == "some_image_uri" + assert model_trainer_call_kwargs["role"] == EXECUTION_ROLE_ARN + assert model_trainer_call_kwargs["training_input_mode"] == "File" + + # Verify PipelineSession was passed to ModelTrainer + assert model_trainer_call_kwargs["sagemaker_session"] == mock_pipeline_session.return_value + mock_pipeline_session.assert_called_once_with( + boto_session=session.boto_session, + default_bucket=session.default_bucket(), + default_bucket_prefix=session.default_bucket_prefix, + ) + + # Verify Compute config + compute_arg = model_trainer_call_kwargs["compute"] + assert compute_arg.instance_type == "ml.m5.large" + assert compute_arg.instance_count == 1 + assert compute_arg.volume_size_in_gb == 30 + + # No VPC config was provided, so networking should be None + assert model_trainer_call_kwargs["networking"] is None + + # max_runtime_in_seconds defaults to 86400 in _JobSettings + stopping_condition_arg = model_trainer_call_kwargs["stopping_condition"] + assert stopping_condition_arg.max_runtime_in_seconds == 86400 + + # Verify OutputDataConfig + output_data_config_arg = model_trainer_call_kwargs["output_data_config"] + assert output_data_config_arg.s3_output_path == f"{S3_URI}/pipeline_name" + + # Verify SourceCode has a command string + source_code_arg = model_trainer_call_kwargs["source_code"] + assert source_code_arg.command is not None + assert len(source_code_arg.command) > 0 + assert "--client_python_version 3.10" in source_code_arg.command + + # No tags on _JobSettings, so tags should be None + assert model_trainer_call_kwargs["tags"] is None + + # Verify train() was called with input_data_config + mock_model_trainer.return_value.train.assert_called_once() + train_call_kwargs = mock_model_trainer.return_value.train.call_args[1] + assert "input_data_config" in train_call_kwargs + + mock_training_step.assert_called_once_with( + name="-".join(["pipeline_name", "feature-processor"]), + step_args=mock_model_trainer.return_value.train.return_value, + retry_policies=[ + StepRetryPolicy( + exception_types=[ + StepExceptionTypeEnum.SERVICE_FAULT, + StepExceptionTypeEnum.THROTTLING, + ], + max_attempts=1, + ), + SageMakerJobStepRetryPolicy( + exception_types=[ + SageMakerJobExceptionTypeEnum.INTERNAL_ERROR, + SageMakerJobExceptionTypeEnum.CAPACITY_ERROR, + SageMakerJobExceptionTypeEnum.RESOURCE_LIMIT, + ], + max_attempts=1, + ), + ], + ) + + pipeline.assert_called_once_with( + name="pipeline_name", + steps=[mock_training_step()], + sagemaker_session=session, + parameters=[Parameter(name="scheduled_time", parameter_type=ParameterTypeEnum.STRING)], + ) + + pipeline().upsert.assert_has_calls( + [ + call( + role_arn=EXECUTION_ROLE_ARN, + 
tags=[ + dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE), + dict(Key="tag_key_1", Value="tag_value_1"), + dict(Key="tag_key_2", Value="tag_value_2"), + ], + ), + call( + role_arn=EXECUTION_ROLE_ARN, + tags=[ + { + "Key": PIPELINE_CONTEXT_NAME_TAG_KEY, + "Value": "pipeline-context-name", + }, + { + "Key": PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY, + "Value": "pipeline-version-context-name", + }, + ], + ), + ] + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_not_wrapped_by_feature_processor(get_execution_role, session): + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + wrapped_func = Mock( + Callable, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Please wrap step parameter with feature_processor decorator in order to use to_pipeline API.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_not_wrapped_by_remote(get_execution_role, session): + mock_feature_processor_config = Mock(mode=FeatureProcessorMode.PYTHON) + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=None, + wrapped_func=job_function, + ) + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Please wrap step parameter with remote decorator in order to use to_pipeline API.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_wrong_mode(get_execution_role, mock_spark_image, session): + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", TEST_REGION]) + + mock_feature_processor_config = Mock(mode=FeatureProcessorMode.PYTHON) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYTHON + + wrapped_func = Mock( + Callable, + 
feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Mode FeatureProcessorMode.PYTHON is not supported by to_pipeline API.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_pipeline_name_length_limit_exceeds( + get_execution_role, mock_spark_image, session +): + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", TEST_REGION]) + + mock_feature_processor_config = Mock(mode=FeatureProcessorMode.PYSPARK) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK + + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="Pipeline name used by feature processor should be less than 80 " + "characters. 
Please choose another pipeline name.", + ): + to_pipeline( + pipeline_name="".join(["a" for _ in range(PIPELINE_NAME_MAXIMUM_LENGTH + 1)]), + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch( + "sagemaker.core.remote_function.job._JobSettings._get_default_spark_image", + return_value="some_image_uri", +) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_to_pipeline_used_reserved_tags(get_execution_role, mock_spark_image, session): + session.sagemaker_config = None + session.boto_region_name = TEST_REGION + session.expand_role.return_value = EXECUTION_ROLE_ARN + spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"]) + job_settings = _JobSettings( + spark_config=spark_config, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + sagemaker_session=session, + ) + jobs_container_entrypoint = [ + "/bin/bash", + f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}", + ] + jobs_container_entrypoint.extend(["--jars", "path_a"]) + jobs_container_entrypoint.extend(["--py-files", "path_b"]) + jobs_container_entrypoint.extend(["--files", "path_c"]) + jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH]) + container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"] + container_args.extend(["--region", session.boto_region_name]) + + mock_feature_processor_config = Mock( + mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg" + ) + mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK + + wrapped_func = Mock( + Callable, + feature_processor_config=mock_feature_processor_config, + job_settings=job_settings, + wrapped_func=job_function, + ) + wrapped_func.feature_processor_config.return_value = mock_feature_processor_config + wrapped_func.job_settings.return_value = job_settings + wrapped_func.wrapped_func.return_value = job_function + + with pytest.raises( + ValueError, + match="sm-fs-fe:created-from is a reserved tag key for to_pipeline API. 
Please choose another tag.", + ): + to_pipeline( + pipeline_name="pipeline_name", + step=wrapped_func, + role=EXECUTION_ROLE_ARN, + max_retries=1, + tags=[("sm-fs-fe:created-from", "random")], + sagemaker_session=session, + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler" + "._get_tags_from_pipeline_to_propagate_to_lineage_resources", + return_value=TAGS, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler._validate_pipeline_lineage_resources", + return_value=None, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper", + return_value=mock_event_bridge_scheduler_helper(), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.FeatureProcessorLineageHandler", + return_value=mock_feature_processor_lineage(), +) +def test_schedule(lineage, helper, validation, get_tags): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.describe_pipeline = Mock( + return_value={"PipelineArn": "my:arn", "CreationTime": NOW} + ) + + schedule_arn = schedule( + schedule_expression="some_schedule", + state=VALID_SCHEDULE_STATE, + start_date=NOW, + pipeline_name=PIPELINE_ARN, + role_arn=SCHEDULE_ROLE_ARN, + sagemaker_session=session, + ) + + assert schedule_arn == SCHEDULE_ARN + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeRuleHelper", + return_value=mock_event_bridge_rule_helper(), +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper", + return_value=mock_event_bridge_scheduler_helper(), +) +def test_describe_both_exist(mock_scheduler_helper, mock_rule_helper): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.describe_pipeline.return_value = PIPELINE + describe_schedule_response = describe( + pipeline_name="some_pipeline_arn", sagemaker_session=session + ) + assert describe_schedule_response == dict( + pipeline_arn="some_pipeline_arn", + pipeline_execution_role_arn="some_execution_role_arn", + max_retries=5, + schedule_arn="some_schedule_arn", + schedule_expression="some_schedule_expression", + schedule_state=VALID_SCHEDULE_STATE, + schedule_start_date=NOW.strftime(EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT), + schedule_role="some_schedule_role_arn", + trigger="some_rule_arn", + event_pattern="some_event_pattern", + trigger_state="ENABLED", + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeRuleHelper.describe_rule", + return_value=None, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper.describe_schedule", + return_value=None, +) +def test_describe_only_pipeline_exist(helper, mock_describe_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.describe_pipeline.return_value = { + "PipelineArn": "some_pipeline_arn", + "RoleArn": "some_execution_role_arn", + "PipelineDefinition": json.dumps({"Steps": [{"Arguments": {}}]}), + } + helper.describe_schedule().return_value = None + describe_schedule_response = describe( + pipeline_name="some_pipeline_arn", sagemaker_session=session + ) + assert describe_schedule_response == dict( + pipeline_arn="some_pipeline_arn", + pipeline_execution_role_arn="some_execution_role_arn", + ) + + +def test_list_pipelines(): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + 
session.sagemaker_client.list_contexts.return_value = {
+ "ContextSummaries": [
+ {
+ "Source": {
+ "SourceUri": "arn:aws:sagemaker:us-west-2:12345789012:pipeline/some_pipeline_name"
+ }
+ }
+ ]
+ }
+ list_response = list_pipelines(session)
+ assert list_response == [dict(pipeline_name="some_pipeline_name")]
+
+
+@patch(
+ "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper",
+ return_value=mock_event_bridge_scheduler_helper(),
+)
+def test_delete_schedule_both_exist(helper):
+ session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock())
+ session.sagemaker_client.delete_pipeline = Mock()
+ delete_schedule(pipeline_name=PIPELINE_ARN, sagemaker_session=session)
+ helper().delete_schedule.assert_called_once_with(PIPELINE_ARN)
+
+
+@patch(
+ "sagemaker.mlops.feature_store.feature_processor.feature_scheduler.EventBridgeSchedulerHelper",
+ return_value=mock_event_bridge_scheduler_helper(),
+)
+def test_delete_schedule_not_exist(helper):
+ helper.delete_schedule.side_effect = ClientError(
+ error_response={"Error": {"Code": "ResourceNotFoundException"}},
+ operation_name="update_schedule",
+ )
+ session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock())
+ session.sagemaker_client.delete_pipeline = Mock()
+ delete_schedule(pipeline_name=PIPELINE_ARN, sagemaker_session=session)
+ helper().delete_schedule.assert_called_once_with(PIPELINE_ARN)
+
+
+@patch(
+ "sagemaker.mlops.feature_store.feature_processor.feature_scheduler._validate_pipeline_lineage_resources",
+ return_value=None,
+)
+def test_execute(validation):
+ session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock())
+ session.sagemaker_client.describe_pipeline = Mock(
+ return_value={"PipelineArn": "my:arn", "CreationTime": NOW}
+ )
+ session.sagemaker_client.start_pipeline_execution = Mock(
+ return_value={"PipelineExecutionArn": "my:arn"}
+ )
+ execution_arn = execute(
+ pipeline_name="some_pipeline", execution_time=NOW, sagemaker_session=session
+ )
+ assert execution_arn == "my:arn"
+
+
+def test_validate_fg_lineage_resources_happy_case():
+ with patch.object(
+ SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP
+ ) as fg_describe_method:
+ with patch.object(
+ Context, "load", side_effect=[CONTEXT_MOCK_01, CONTEXT_MOCK_02, CONTEXT_MOCK_03]
+ ) as context_load:
+ type(CONTEXT_MOCK_01).context_arn = "context-arn"
+ type(CONTEXT_MOCK_02).context_arn = "context-arn-fep"
+ type(CONTEXT_MOCK_03).context_arn = "context-arn-fep-ver"
+ _validate_fg_lineage_resources(
+ feature_group_name="some_fg",
+ sagemaker_session=SAGEMAKER_SESSION_MOCK,
+ )
+ fg_describe_method.assert_called_once_with(feature_group_name="some_fg")
+ context_load.assert_has_calls(
+ [
+ call(
+ context_name=f'{"some_fg"}-{FEATURE_GROUP["CreationTime"].strftime("%s")}'
+ f"-feature-group-pipeline",
+ sagemaker_session=SAGEMAKER_SESSION_MOCK,
+ ),
+ call(
+ context_name=f'{"some_fg"}-{FEATURE_GROUP["CreationTime"].strftime("%s")}'
+ f"-feature-group-pipeline-version",
+ sagemaker_session=SAGEMAKER_SESSION_MOCK,
+ ),
+ ]
+ )
+ assert 3 == context_load.call_count
+
+
+def test_validate_fg_lineage_resources_rnf():
+ with patch.object(SAGEMAKER_SESSION_MOCK, "describe_feature_group", return_value=FEATURE_GROUP):
+ with patch.object(
+ Context,
+ "load",
+ side_effect=ClientError(
+ error_response={"Error": {"Code": "ResourceNotFound"}},
+ operation_name="describe_context",
+ ),
+ ):
+ feature_group_name = "some_fg"
+ feature_group_creation_time = 
FEATURE_GROUP["CreationTime"].strftime("%s") + context_name = f"{feature_group_name}-{feature_group_creation_time}" + with pytest.raises( + ValueError, + match=f"Lineage resource {context_name} has not yet been created for feature group" + f" {feature_group_name} or has already been deleted. Please try again later.", + ): + _validate_fg_lineage_resources( + feature_group_name="some_fg", + sagemaker_session=SAGEMAKER_SESSION_MOCK, + ) + + +def test_validate_pipeline_lineage_resources_happy_case(): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.return_value = Mock() + pipeline_name = "some_pipeline" + with patch.object( + session.sagemaker_client, "describe_pipeline", return_value=PIPELINE + ) as pipeline_describe_method: + with patch.object( + Context, "load", side_effect=[CONTEXT_MOCK_01, CONTEXT_MOCK_02] + ) as context_load: + type(CONTEXT_MOCK_01).context_arn = "context-arn" + type(CONTEXT_MOCK_01).properties = {"LastUpdateTime": NOW.strftime("%s")} + type(CONTEXT_MOCK_02).context_arn = "context-arn-fep" + _validate_pipeline_lineage_resources( + pipeline_name=pipeline_name, + sagemaker_session=session, + ) + pipeline_describe_method.assert_called_once_with(PipelineName=pipeline_name) + pipeline_creation_time = PIPELINE["CreationTime"].strftime("%s") + last_updated_time = NOW.strftime("%s") + context_load.assert_has_calls( + [ + call( + context_name=f"sm-fs-fe-{pipeline_name}-{pipeline_creation_time}-fep", + sagemaker_session=session, + ), + call( + context_name=f"sm-fs-fe-{pipeline_name}-{last_updated_time}-fep-ver", + sagemaker_session=session, + ), + ] + ) + assert 2 == context_load.call_count + + +def test_validate_pipeline_lineage_resources_rnf(): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + session.sagemaker_client.return_value = Mock() + pipeline_name = "some_pipeline" + with patch.object(session.sagemaker_client, "describe_pipeline", return_value=PIPELINE): + with patch.object( + Context, + "load", + side_effect=ClientError( + error_response={"Error": {"Code": "ResourceNotFound"}}, + operation_name="describe_context", + ), + ): + with pytest.raises( + ValueError, + match="Pipeline lineage resources have not been created yet or have" + " already been deleted. 
Please try again later.", + ): + _validate_pipeline_lineage_resources( + pipeline_name=pipeline_name, + sagemaker_session=session, + ) + + +@patch("sagemaker.core.remote_function.job.Session", return_value=mock_session()) +@patch("sagemaker.core.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN) +def test_remote_decorator_fields_consistency(get_execution_role, session): + expected_remote_decorator_attributes = { + "sagemaker_session", + "environment_variables", + "image_uri", + "dependencies", + "pre_execution_commands", + "pre_execution_script", + "include_local_workdir", + "instance_type", + "instance_count", + "volume_size", + "max_runtime_in_seconds", + "max_retry_attempts", + "keep_alive_period_in_seconds", + "spark_config", + "job_conda_env", + "job_name_prefix", + "encrypt_inter_container_traffic", + "enable_network_isolation", + "role", + "s3_root_uri", + "s3_kms_key", + "volume_kms_key", + "vpc_config", + "tags", + "use_spot_instances", + "max_wait_time_in_seconds", + "custom_file_filter", + "disable_output_compression", + "use_torchrun", + "use_mpirun", + "nproc_per_node", + } + + job_settings = _JobSettings( + image_uri=IMAGE, + s3_root_uri=S3_URI, + role=EXECUTION_ROLE_ARN, + include_local_workdir=True, + instance_type="ml.m5.large", + encrypt_inter_container_traffic=True, + ) + actual_attributes = {attribute for attribute, _ in job_settings.__dict__.items()} + + assert expected_remote_decorator_attributes == actual_attributes + + +@patch( + "sagemaker.mlops.feature_store.feature_processor.lineage." + "_feature_processor_lineage.FeatureProcessorLineageHandler.create_trigger_lineage" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.describe_rule", + return_value={"EventPattern": "test-pattern"}, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.add_tags" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor.feature_scheduler." 
+ "_get_tags_from_pipeline_to_propagate_to_lineage_resources", + return_value=TAGS, +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.put_target" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.put_rule", + return_value="arn:aws:events:us-west-2:123456789012:rule/test-rule", +) +def test_put_trigger( + mock_put_rule, + mock_put_target, + mock_get_tags, + mock_add_tags, + mock_describe_rule, + mock_create_trigger_lineage, +): + session = Mock( + Session, + sagemaker_client=Mock( + describe_pipeline=Mock(return_value={"PipelineArn": "test-pipeline-arn"}) + ), + boto_session=Mock(), + ) + source_pipeline_events = [ + FeatureProcessorPipelineEvents( + pipeline_name="test-pipeline", + pipeline_execution_status=[FeatureProcessorPipelineExecutionStatus.SUCCEEDED], + ) + ] + put_trigger( + source_pipeline_events=source_pipeline_events, + target_pipeline="test-target-pipeline", + state="Enabled", + event_pattern="test-pattern", + role_arn=SCHEDULE_ROLE_ARN, + sagemaker_session=session, + ) + + mock_put_rule.assert_called_once_with( + source_pipeline_events=source_pipeline_events, + target_pipeline="test-target-pipeline", + state="Enabled", + event_pattern="test-pattern", + ) + mock_put_target.assert_called_once_with( + rule_name="test-rule", + target_pipeline="test-target-pipeline", + target_pipeline_parameters=None, + role_arn=SCHEDULE_ROLE_ARN, + ) + mock_add_tags.assert_called_once_with(rule_arn=EVENT_BRIDGE_RULE_ARN, tags=TAGS) + mock_create_trigger_lineage.assert_called_once_with( + pipeline_name="test-target-pipeline", + trigger_arn=EVENT_BRIDGE_RULE_ARN, + state="Enabled", + tags=TAGS, + event_pattern="test-pattern", + ) + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.enable_rule" +) +def test_enable_trigger(mock_enable_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + enable_trigger(pipeline_name="test-pipeline", sagemaker_session=session) + mock_enable_rule.assert_called_once_with(rule_name="test-pipeline") + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.disable_rule" +) +def test_disable_trigger(mock_disable_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + disable_trigger(pipeline_name="test-pipeline", sagemaker_session=session) + mock_disable_rule.assert_called_once_with(rule_name="test-pipeline") + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.list_targets_by_rule", + return_value=[{"Targets": [{"Id": "target_pipeline"}]}], +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.remove_targets" +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._event_bridge_rule_helper.EventBridgeRuleHelper.delete_rule" +) +def test_delete_trigger(mock_delete_rule, mock_remove_targets, mock_list_targets_by_rule): + session = Mock(Session, sagemaker_client=Mock(), boto_session=Mock()) + delete_trigger(pipeline_name="test-pipeline", sagemaker_session=session) + mock_delete_rule.assert_called_once_with("test-pipeline") + mock_list_targets_by_rule.assert_called_once_with("test-pipeline") + mock_remove_targets.assert_called_once_with(rule_name="test-pipeline", ids=["target_pipeline"]) diff --git 
a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py new file mode 100644 index 0000000000..24d20f96c5 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_loader.py @@ -0,0 +1,320 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import test_data_helpers as tdh +from mock import Mock, patch, call +from pyspark.sql import SparkSession, DataFrame +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + CSVDataSource, + FeatureGroupDataSource, + ParquetDataSource, + IcebergTableDataSource, +) +from sagemaker.mlops.feature_store.feature_processor._input_loader import ( + SparkDataFrameInputLoader, +) +from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper +from sagemaker.core.helper.session_helper import Session + + +@pytest.fixture +def describe_fg_response(): + return tdh.DESCRIBE_FEATURE_GROUP_RESPONSE.copy() + + +@pytest.fixture +def sagemaker_session(describe_fg_response): + return Mock(Session, describe_feature_group=Mock(return_value=describe_fg_response)) + + +@pytest.fixture +def spark_session(mock_data_frame): + return Mock( + SparkSession, + read=Mock( + csv=Mock(return_value=Mock()), + parquet=Mock(return_value=mock_data_frame), + conf=Mock(set=Mock()), + ), + table=Mock(return_value=mock_data_frame), + ) + + +@pytest.fixture +def environment_helper(): + return Mock( + EnvironmentHelper, + get_job_scheduled_time=Mock(return_value="2023-05-05T15:22:57Z"), + ) + + +@pytest.fixture +def mock_data_frame(): + return Mock(DataFrame, filter=Mock()) + + +@pytest.fixture +def spark_session_factory(spark_session): + factory = Mock(SparkSessionFactory) + factory.spark_session = spark_session + factory.get_spark_session_with_iceberg_config = Mock(return_value=spark_session) + return factory + + +@pytest.fixture +def fp_config(): + return tdh.create_fp_config() + + +@pytest.fixture +def input_loader(spark_session_factory, sagemaker_session, environment_helper): + return SparkDataFrameInputLoader( + spark_session_factory, + environment_helper, + sagemaker_session, + ) + + +def test_load_from_s3_with_csv_object(input_loader: SparkDataFrameInputLoader, spark_session): + s3_data_source = CSVDataSource( + s3_uri="s3://bucket/prefix/file.csv", + csv_header=True, + csv_infer_schema=True, + ) + + input_loader.load_from_s3(s3_data_source) + + spark_session.read.csv.assert_called_with( + "s3a://bucket/prefix/file.csv", header=True, inferSchema=True + ) + + +def test_load_from_s3_with_parquet_object(input_loader, spark_session): + s3_data_source = ParquetDataSource(s3_uri="s3://bucket/prefix/file.parquet") + + 
input_loader.load_from_s3(s3_data_source) + + spark_session.read.parquet.assert_called_with("s3a://bucket/prefix/file.parquet") + + +@pytest.mark.parametrize( + "condition", + [(None), ("condition")], +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._input_loader." + "SparkDataFrameInputLoader._get_iceberg_offset_filter_condition" +) +def test_load_from_iceberg_table( + mock_get_filter_condition, + condition, + input_loader, + spark_session, + spark_session_factory, + mock_data_frame, +): + iceberg_table_data_source = IcebergTableDataSource( + warehouse_s3_uri="s3://bucket/prefix/", + catalog="Catalog", + database="Database", + table="Table", + ) + mock_get_filter_condition.return_value = condition + + input_loader.load_from_iceberg_table(iceberg_table_data_source, "event_time", "start", "end") + spark_session_factory.get_spark_session_with_iceberg_config.assert_called_with( + "s3://bucket/prefix/", "catalog" + ) + spark_session.table.assert_called_with("catalog.database.table") + mock_get_filter_condition.assert_called_with("event_time", "start", "end") + + if condition: + mock_data_frame.filter.assert_called_with(condition) + else: + mock_data_frame.filter.assert_not_called() + + +@patch( + "sagemaker.mlops.feature_store.feature_processor._input_loader.SparkDataFrameInputLoader.load_from_date_partitioned_s3" +) +def test_load_from_feature_group_with_arn( + mock_load_from_date_partitioned_s3, sagemaker_session, input_loader +): + fg_arn = tdh.INPUT_FEATURE_GROUP_ARN + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource( + name=fg_arn, input_start_offset="start", input_end_offset="end" + ) + + input_loader.load_from_feature_group(fg_data_source) + + sagemaker_session.describe_feature_group.assert_called_with(fg_name) + mock_load_from_date_partitioned_s3.assert_called_with( + ParquetDataSource(tdh.INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI), + "start", + "end", + ) + + +def test_load_from_feature_group_offline_store_not_enabled(input_loader, describe_fg_response): + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource(name=fg_name) + with pytest.raises( + ValueError, + match=( + f"Input Feature Groups must have an enabled Offline Store." + f" Feature Group: {fg_name} does not have an Offline Store enabled." 
+ ), + ): + del describe_fg_response["OfflineStoreConfig"] + input_loader.load_from_feature_group(fg_data_source) + + +def test_load_from_feature_group_with_default_table_format( + sagemaker_session, input_loader, spark_session +): + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource(name=fg_name) + input_loader.load_from_feature_group(fg_data_source) + + sagemaker_session.describe_feature_group.assert_called_with(fg_name) + spark_session.read.parquet.assert_called_with( + tdh.INPUT_FEATURE_GROUP_RESOLVED_OUTPUT_S3_URI.replace("s3:", "s3a:") + ) + + +def test_load_from_feature_group_with_iceberg_table_format( + describe_fg_response, spark_session_factory, spark_session, environment_helper +): + describe_iceberg_fg_response = describe_fg_response.copy() + describe_iceberg_fg_response["OfflineStoreConfig"]["TableFormat"] = "Iceberg" + mocked_session = Mock( + Session, describe_feature_group=Mock(return_value=describe_iceberg_fg_response) + ) + mock_input_loader = SparkDataFrameInputLoader( + spark_session_factory, environment_helper, mocked_session + ) + + fg_name = tdh.INPUT_FEATURE_GROUP_NAME + fg_data_source = FeatureGroupDataSource(name=fg_name) + mock_input_loader.load_from_feature_group(fg_data_source) + + mocked_session.describe_feature_group.assert_called_with(fg_name) + spark_session.table.assert_called_with( + "awsdatacatalog.sagemaker_featurestore.input_fg_1680142547" + ) + + +@pytest.mark.parametrize( + "param", + [ + (None, None, None), + ("start", None, "event_time >= 'start_time'"), + (None, "end", "event_time < 'end_time'"), + ("start", "end", "event_time >= 'start_time' AND event_time < 'end_time'"), + ], +) +@patch( + "sagemaker.mlops.feature_store.feature_processor._input_offset_parser.InputOffsetParser.get_iso_format_offset_date", + side_effect=[ + "start_time", + "end_time", + ], +) +def test_get_iceberg_offset_filter_condition(mock_get_iso_date, param, input_loader): + start_offset, end_offset, expected_condition_str = param + + condition = input_loader._get_iceberg_offset_filter_condition( + "event_time", start_offset, end_offset + ) + + if start_offset or end_offset: + mock_get_iso_date.assert_has_calls([call(start_offset), call(end_offset)]) + else: + mock_get_iso_date.assert_not_called() + + assert condition == expected_condition_str + + +@pytest.mark.parametrize( + "param", + [ + (None, None, None), + ( + "start", + None, + "(year >= 'year_start') AND NOT ((year = 'year_start' AND month < 'month_start') OR " + "(year = 'year_start' AND month = 'month_start' AND day < 'day_start') OR " + "(year = 'year_start' AND month = 'month_start' AND day = 'day_start' AND hour < 'hour_start'))", + ), + ( + None, + "end", + "(year <= 'year_end') AND NOT ((year = 'year_end' AND month > 'month_end') OR " + "(year = 'year_end' AND month = 'month_end' AND day > 'day_end') OR (year = 'year_end' " + "AND month = 'month_end' AND day = 'day_end' AND hour >= 'hour_end'))", + ), + ( + "start", + "end", + "(year >= 'year_start' AND year <= 'year_end') AND NOT ((year = 'year_start' AND " + "month < 'month_start') OR (year = 'year_start' AND month = 'month_start' AND day < 'day_start') OR " + "(year = 'year_start' AND month = 'month_start' AND day = 'day_start' AND hour < 'hour_start') OR " + "(year = 'year_end' AND month > 'month_end') OR " + "(year = 'year_end' AND month = 'month_end' AND day > 'day_end') OR " + "(year = 'year_end' AND month = 'month_end' AND day = 'day_end' AND hour >= 'hour_end'))", + ), + ], +) +@patch( + 
"sagemaker.mlops.feature_store.feature_processor._input_offset_parser.InputOffsetParser." + "get_offset_date_year_month_day_hour", + side_effect=[ + ("year_start", "month_start", "day_start", "hour_start"), + ("year_end", "month_end", "day_end", "hour_end"), + ], +) +def test_get_s3_partitions_offset_filter_condition(mock_get_ymdh, param, input_loader): + start_offset, end_offset, expected_condition_str = param + + condition = input_loader._get_s3_partitions_offset_filter_condition(start_offset, end_offset) + + if start_offset or end_offset: + mock_get_ymdh.assert_has_calls([call(start_offset), call(end_offset)]) + else: + mock_get_ymdh.assert_not_called() + + assert condition == expected_condition_str + + +@pytest.mark.parametrize( + "condition", + [(None), ("condition")], +) +def test_load_from_date_partitioned_s3(input_loader, spark_session, mock_data_frame, condition): + input_loader._get_s3_partitions_offset_filter_condition = Mock(return_value=condition) + + input_loader.load_from_date_partitioned_s3( + ParquetDataSource("s3://path/to/file"), "start", "end" + ) + df = spark_session.read.parquet + df.assert_called_with("s3a://path/to/file") + + if condition: + mock_data_frame.filter.assert_called_with(condition) + else: + mock_data_frame.filter.assert_not_called() diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_offset_parser.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_offset_parser.py new file mode 100644 index 0000000000..b244a4c2da --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_input_offset_parser.py @@ -0,0 +1,143 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import
+
+from sagemaker.mlops.feature_store.feature_processor._input_offset_parser import (
+ InputOffsetParser,
+)
+from sagemaker.mlops.feature_store.feature_processor._constants import (
+ EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT,
+)
+from datetime import datetime
+from dateutil.relativedelta import relativedelta
+import pytest
+
+
+@pytest.fixture
+def input_offset_parser():
+ time_spec = dict(year=2023, month=5, day=10, hour=17, minute=30, second=20)
+ return InputOffsetParser(now=datetime(**time_spec))
+
+
+@pytest.mark.parametrize(
+ "param",
+ [
+ (None, None),
+ ("1 hour", "2023-05-10T16:30:20Z"),
+ ("1 day", "2023-05-09T17:30:20Z"),
+ ("1 month", "2023-04-10T17:30:20Z"),
+ ("1 year", "2022-05-10T17:30:20Z"),
+ ],
+)
+def test_get_iso_format_offset_date(param, input_offset_parser):
+ input_offset, expected_offset_date = param
+ output_offset_date = input_offset_parser.get_iso_format_offset_date(input_offset)
+
+ assert output_offset_date == expected_offset_date
+
+
+@pytest.mark.parametrize(
+ "param",
+ [
+ (None, None),
+ (
+ "1 hour",
+ datetime.strptime("2023-05-10T16:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT),
+ ),
+ (
+ "1 day",
+ datetime.strptime("2023-05-09T17:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT),
+ ),
+ (
+ "1 month",
+ datetime.strptime("2023-04-10T17:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT),
+ ),
+ (
+ "1 year",
+ datetime.strptime("2022-05-10T17:30:20Z", EXECUTION_TIME_PIPELINE_PARAMETER_FORMAT),
+ ),
+ ],
+)
+def test_get_offset_datetime(param, input_offset_parser):
+ input_offset, expected_offset_datetime = param
+ output_offset_datetime = input_offset_parser.get_offset_datetime(input_offset)
+
+ assert output_offset_datetime == expected_offset_datetime
+
+
+@pytest.mark.parametrize(
+ "param",
+ [
+ (None, (None, None, None, None)),
+ ("1 hour", ("2023", "05", "10", "16")),
+ ("1 day", ("2023", "05", "09", "17")),
+ ("1 month", ("2023", "04", "10", "17")),
+ ("1 year", ("2022", "05", "10", "17")),
+ ],
+)
+def test_get_offset_date_year_month_day_hour(param, input_offset_parser):
+ input_offset, expected_date_tuple = param
+ output_date_tuple = input_offset_parser.get_offset_date_year_month_day_hour(input_offset)
+
+ assert output_date_tuple == expected_date_tuple
+
+
+@pytest.mark.parametrize(
+ "param",
+ [
+ (None, None),
+ ("1 hour", relativedelta(hours=-1)),
+ ("20 hours", relativedelta(hours=-20)),
+ ("1 day", relativedelta(days=-1)),
+ ("20 days", relativedelta(days=-20)),
+ ("1 month", relativedelta(months=-1)),
+ ("20 months", relativedelta(months=-20)),
+ ("1 year", relativedelta(years=-1)),
+ ("20 years", relativedelta(years=-20)),
+ ],
+)
+def test_parse_offset_to_timedelta(param, input_offset_parser):
+ input_offset, expected_deltatime = param
+ output_deltatime = input_offset_parser.parse_offset_to_timedelta(input_offset)
+
+ assert output_deltatime == expected_deltatime
+
+
+@pytest.mark.parametrize(
+ "param",
+ [
+ (
+ "random invalid string",
+ "[random invalid string] is not in a valid offset format. Please pass a valid offset e.g '1 day'.",
+ ),
+ (
+ "1 invalid string",
+ "[1 invalid string] is not in a valid offset format. Please pass a valid offset e.g '1 day'.",
+ ),
+ (
+ "2 days invalid string",
+ "[2 days invalid string] is not in a valid offset format. Please pass a valid offset e.g '1 day'.",
+ ),
+ (
+ "1 second",
+ "[second] is not a valid offset unit. 
Supported units: ['hour', 'day', 'week', 'month', 'year']", + ), + ], +) +def test_parse_offset_to_timedelta_negative(param, input_offset_parser): + input_offset, expected_error_message = param + + with pytest.raises(ValueError) as e: + input_offset_parser.parse_offset_to_timedelta(input_offset) + + assert str(e.value) == expected_error_message diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_params_loader.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_params_loader.py new file mode 100644 index 0000000000..1ac2042a55 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_params_loader.py @@ -0,0 +1,86 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + + +import pytest +import test_data_helpers as tdh +from mock import Mock + +from sagemaker.mlops.feature_store.feature_processor._env import EnvironmentHelper +from sagemaker.mlops.feature_store.feature_processor._params_loader import ( + ParamsLoader, + SystemParamsLoader, +) + + +@pytest.fixture +def system_params_loader_mock(): + system_params_loader = Mock(SystemParamsLoader) + system_params_loader.get_system_args.return_value = tdh.SYSTEM_PARAMS + return system_params_loader + + +@pytest.fixture +def environment_checker(): + environment_checker = Mock(EnvironmentHelper) + environment_checker.is_training_job.return_value = False + environment_checker.get_job_scheduled_time = Mock(return_value="2023-05-05T15:22:57Z") + return environment_checker + + +@pytest.fixture +def params_loader(system_params_loader_mock): + return ParamsLoader(system_params_loader_mock) + + +@pytest.fixture +def system_params_loader(environment_checker): + return SystemParamsLoader(environment_checker) + + +def test_get_parameter_args(params_loader, system_params_loader_mock): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + parameters=tdh.USER_INPUT_PARAMS, + ) + + params = params_loader.get_parameter_args(fp_config) + + system_params_loader_mock.get_system_args.assert_called_once() + assert params == {"params": {**tdh.USER_INPUT_PARAMS, **tdh.SYSTEM_PARAMS}} + + +def test_get_parameter_args_with_no_user_params(params_loader, system_params_loader_mock): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + parameters=None, + ) + + params = params_loader.get_parameter_args(fp_config) + + system_params_loader_mock.get_system_args.assert_called_once() + assert params == {"params": {**tdh.SYSTEM_PARAMS}} + + +def test_get_system_arg_from_pipeline_execution(system_params_loader): + system_params = system_params_loader.get_system_args() + + assert system_params == { + "system": { + "scheduled_time": "2023-05-05T15:22:57Z", + } + } diff --git 
a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py new file mode 100644 index 0000000000..12ac116f53 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_spark_session_factory.py @@ -0,0 +1,175 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import feature_store_pyspark +import pytest +from mock import Mock, patch, call + +from sagemaker.mlops.feature_store.feature_processor._spark_factory import ( + FeatureStoreManagerFactory, + SparkSessionFactory, +) + + +@pytest.fixture +def env_helper(): + return Mock( + is_training_job=Mock(return_value=False), + load_training_resource_config=Mock(return_value=None), + ) + + +def test_spark_session_factory_configuration(): + env_helper = Mock() + spark_config = {"spark.test.key": "spark.test.value"} + spark_session_factory = SparkSessionFactory(env_helper, spark_config) + spark_configs = dict(spark_session_factory._get_spark_configs(is_training_job=False)) + jsc_hadoop_configs = dict(spark_session_factory._get_jsc_hadoop_configs()) + + # General optimizations + assert spark_configs.get("spark.hadoop.fs.s3a.aws.credentials.provider") == ",".join( + [ + "com.amazonaws.auth.ContainerCredentialsProvider", + "com.amazonaws.auth.profile.ProfileCredentialsProvider", + "com.amazonaws.auth.DefaultAWSCredentialsProviderChain", + ] + ) + + assert spark_configs.get("spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version") == "2" + assert ( + spark_configs.get("spark.hadoop.mapreduce.fileoutputcommitter.cleanup-failures.ignored") + == "true" + ) + assert spark_configs.get("spark.hadoop.parquet.enable.summary-metadata") == "false" + + assert spark_configs.get("spark.sql.parquet.mergeSchema") == "false" + assert spark_configs.get("spark.sql.parquet.filterPushdown") == "true" + assert spark_configs.get("spark.sql.hive.metastorePartitionPruning") == "true" + + assert spark_configs.get("spark.hadoop.fs.s3a.threads.max") == "500" + assert spark_configs.get("spark.hadoop.fs.s3a.connection.maximum") == "500" + assert spark_configs.get("spark.hadoop.fs.s3a.experimental.input.fadvise") == "normal" + assert spark_configs.get("spark.hadoop.fs.s3a.block.size") == "128M" + assert spark_configs.get("spark.hadoop.fs.s3a.fast.upload.buffer") == "disk" + assert spark_configs.get("spark.hadoop.fs.trash.interval") == "0" + assert spark_configs.get("spark.port.maxRetries") == "50" + + assert spark_configs.get("spark.test.key") == "spark.test.value" + + assert jsc_hadoop_configs.get("mapreduce.fileoutputcommitter.marksuccessfuljobs") == "false" + + # Verify configurations when not running on a training job + assert ",".join(feature_store_pyspark.classpath_jars()) in spark_configs.get("spark.jars") + assert ",".join( + [ + 
"org.apache.hadoop:hadoop-aws:3.3.1", + "org.apache.hadoop:hadoop-common:3.3.1", + ] + ) in spark_configs.get("spark.jars.packages") + + +def test_spark_session_factory_configuration_on_training_job(): + env_helper = Mock() + spark_config = {"spark.test.key": "spark.test.value"} + spark_session_factory = SparkSessionFactory(env_helper, spark_config) + + spark_config = spark_session_factory._get_spark_configs(is_training_job=True) + assert dict(spark_config).get("spark.test.key") == "spark.test.value" + + assert all(tup[0] != "spark.jars" for tup in spark_config) + assert all(tup[0] != "spark.jars.packages" for tup in spark_config) + + +@patch("pyspark.context.SparkContext.getOrCreate") +def test_spark_session_factory(mock_spark_context): + env_helper = Mock() + env_helper.get_instance_count.return_value = 1 + spark_session_factory = SparkSessionFactory(env_helper) + + spark_session_factory.spark_session + + _, _, kw_args = mock_spark_context.mock_calls[0] + spark_conf = kw_args["conf"] + + mock_spark_context.assert_called_once() + assert spark_conf.get("spark.master") == "local[*]" + for cfg in spark_session_factory._get_spark_configs(True): + assert spark_conf.get(cfg[0]) == cfg[1] + + +@patch("pyspark.context.SparkContext.getOrCreate") +def test_spark_session_factory_with_iceberg_config(mock_spark_context): + mock_env_helper = Mock() + mock_spark_context.side_effect = [Mock(), Mock()] + + spark_session_factory = SparkSessionFactory(mock_env_helper) + + spark_session = spark_session_factory.spark_session + mock_conf = Mock() + spark_session.conf = mock_conf + + spark_session_with_iceberg_config = spark_session_factory.get_spark_session_with_iceberg_config( + "warehouse", "catalog" + ) + + assert spark_session is spark_session_with_iceberg_config + expected_calls = [ + call.set(cfg[0], cfg[1]) + for cfg in spark_session_factory._get_iceberg_configs("warehouse", "catalog") + ] + + mock_conf.assert_has_calls(expected_calls, any_order=False) + + +@patch("pyspark.context.SparkContext.getOrCreate") +def test_spark_session_factory_same_instance(mock_spark_context): + mock_env_helper = Mock() + mock_spark_context.side_effect = [Mock(), Mock()] + + spark_session_factory = SparkSessionFactory(mock_env_helper) + + a_reference = spark_session_factory.spark_session + another_reference = spark_session_factory.spark_session + + assert a_reference is another_reference + + +@patch("feature_store_pyspark.FeatureStoreManager.FeatureStoreManager") +def test_feature_store_manager_same_instance(mock_feature_store_manager): + mock_feature_store_manager.side_effect = [Mock(), Mock()] + + factory = FeatureStoreManagerFactory() + + assert factory.feature_store_manager is factory.feature_store_manager + + +def test_spark_session_factory_get_spark_session_with_iceberg_config(env_helper): + spark_session_factory = SparkSessionFactory(env_helper) + iceberg_configs = dict(spark_session_factory._get_iceberg_configs("s3://test/path", "Catalog")) + + assert ( + iceberg_configs.get("spark.sql.catalog.catalog") + == "smfs.shaded.org.apache.iceberg.spark.SparkCatalog" + ) + assert iceberg_configs.get("spark.sql.catalog.catalog.warehouse") == "s3://test/path" + assert ( + iceberg_configs.get("spark.sql.catalog.catalog.catalog-impl") + == "smfs.shaded.org.apache.iceberg.aws.glue.GlueCatalog" + ) + assert ( + iceberg_configs.get("spark.sql.catalog.catalog.io-impl") + == "smfs.shaded.org.apache.iceberg.aws.s3.S3FileIO" + ) + assert iceberg_configs.get("spark.sql.catalog.catalog.glue.skip-name-validation") == "true" diff --git 
a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_arg_provider.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_arg_provider.py new file mode 100644 index 0000000000..561cc09e76 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_arg_provider.py @@ -0,0 +1,280 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import test_data_helpers as tdh +from mock import Mock, patch +from pyspark.sql import DataFrame, SparkSession + +from sagemaker.mlops.feature_store.feature_processor._input_loader import InputLoader +from sagemaker.mlops.feature_store.feature_processor._params_loader import ParamsLoader +from sagemaker.mlops.feature_store.feature_processor._spark_factory import SparkSessionFactory +from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import SparkArgProvider +from sagemaker.mlops.feature_store.feature_processor._data_source import PySparkDataSource + + +@pytest.fixture +def params_loader(): + params_loader = Mock(ParamsLoader) + params_loader.get_parameter_args = Mock(return_value={"params": {"key": "value"}}) + return params_loader + + +@pytest.fixture +def feature_group_as_spark_df(): + return Mock(DataFrame) + + +@pytest.fixture +def s3_uri_as_spark_df(): + return Mock(DataFrame) + + +@pytest.fixture +def base_data_source_as_spark_df(): + return Mock(DataFrame) + + +@pytest.fixture +def input_loader(feature_group_as_spark_df, s3_uri_as_spark_df): + input_loader = Mock(InputLoader) + input_loader.load_from_s3.return_value = s3_uri_as_spark_df + input_loader.load_from_feature_group.return_value = feature_group_as_spark_df + + return input_loader + + +@pytest.fixture +def spark_session(): + return Mock(SparkSession) + + +@pytest.fixture +def spark_session_factory(spark_session): + return Mock(SparkSessionFactory, spark_session=spark_session) + + +@pytest.fixture +def spark_arg_provider(params_loader, input_loader, spark_session_factory): + return SparkArgProvider(params_loader, input_loader, spark_session_factory) + + +class MockDataSource(PySparkDataSource): + + data_source_unique_id = "test_id" + data_source_name = "test_source" + + def read_data(self, spark, params) -> DataFrame: + return Mock(DataFrame) + + +def test_provide_additional_kw_args(spark_arg_provider, spark_session): + def udf(fg_input, s3_input, params, spark): + return None + + additional_kw_args = spark_arg_provider.provide_additional_kwargs(udf) + + assert additional_kw_args.keys() == {"spark"} + assert additional_kw_args["spark"] == spark_session + + +def test_not_provide_additional_kw_args(spark_arg_provider): + def udf(input, params): + return None + + additional_kw_args = spark_arg_provider.provide_additional_kwargs(udf) + + assert additional_kw_args == {} + + +def test_provide_params(spark_arg_provider, params_loader): + fp_config 
= tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(fg_input, s3_input, params, spark): + return None + + params = spark_arg_provider.provide_params_arg(udf, fp_config) + + params_loader.get_parameter_args.assert_called_with(fp_config) + assert params == params_loader.get_parameter_args.return_value + + +def test_not_provide_params(spark_arg_provider, params_loader): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(fg_input, s3_input, spark): + return None + + params = spark_arg_provider.provide_params_arg(udf, fp_config) + + assert params == {} + + +def test_provide_input_args_with_no_input(spark_arg_provider): + fp_config = tdh.create_fp_config(inputs=[], output=tdh.OUTPUT_FEATURE_GROUP_ARN) + + def udf() -> DataFrame: + return Mock(DataFrame) + + with pytest.raises( + ValueError, match="Expected at least one input to the user defined function." + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args_with_extra_udf_parameters(spark_arg_provider): + fp_config = tdh.create_fp_config( + inputs=[tdh.INPUT_S3_URI], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + with pytest.raises( + ValueError, + match=r"The signature of the user defined function does not match the list of inputs requested." + r" Expected 1 parameter\(s\).", + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args_with_extra_fp_config_inputs(spark_arg_provider): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None) -> DataFrame: + return Mock(DataFrame) + + with pytest.raises( + ValueError, + match=r"The signature of the user defined function does not match the list of inputs requested." 
+ r" Expected 2 parameter\(s\).", + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args( + spark_arg_provider, + feature_group_as_spark_df, + s3_uri_as_spark_df, +): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + inputs = spark_arg_provider.provide_input_args(udf, fp_config) + + assert inputs.keys() == {"input_fg", "input_s3_uri"} + assert inputs["input_fg"] == feature_group_as_spark_df + assert inputs["input_s3_uri"] == s3_uri_as_spark_df + + +def test_provide_input_args_with_reversed_inputs( + spark_arg_provider, + feature_group_as_spark_df, + s3_uri_as_spark_df, +): + fp_config = tdh.create_fp_config( + inputs=[tdh.S3_DATA_SOURCE, tdh.FEATURE_GROUP_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + inputs = spark_arg_provider.provide_input_args(udf, fp_config) + + assert inputs.keys() == {"input_fg", "input_s3_uri"} + assert inputs["input_fg"] == s3_uri_as_spark_df + assert inputs["input_s3_uri"] == feature_group_as_spark_df + + +def test_provide_input_args_with_optional_args_out_of_order(spark_arg_provider): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf_spark_params(spark=None, params=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_params_spark(params=None, spark=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_spark(spark=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_params(params=None, input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + for udf in [udf_spark_params, udf_params_spark, udf_spark, udf_params]: + with pytest.raises( + ValueError, + match="Expected at least one input to the user defined function.", + ): + spark_arg_provider.provide_input_args(udf, fp_config) + + +def test_provide_input_args_with_optional_args( + spark_arg_provider, feature_group_as_spark_df, s3_uri_as_spark_df +): + fp_config = tdh.create_fp_config( + inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE], + output=tdh.OUTPUT_FEATURE_GROUP_ARN, + ) + + def udf_all_optional(input_fg=None, input_s3_uri=None, params=None, spark=None) -> DataFrame: + return Mock(DataFrame) + + def udf_no_optional(input_fg=None, input_s3_uri=None) -> DataFrame: + return Mock(DataFrame) + + def udf_only_params(input_fg=None, input_s3_uri=None, params=None) -> DataFrame: + return Mock(DataFrame) + + def udf_only_spark(input_fg=None, input_s3_uri=None, spark=None) -> DataFrame: + return Mock(DataFrame) + + for udf in [udf_all_optional, udf_no_optional, udf_only_params, udf_only_spark]: + inputs = spark_arg_provider.provide_input_args(udf, fp_config) + + assert inputs.keys() == {"input_fg", "input_s3_uri"} + assert inputs["input_fg"] == feature_group_as_spark_df + assert inputs["input_s3_uri"] == s3_uri_as_spark_df + + +def test_provide_input_arg_for_base_data_source(spark_arg_provider, params_loader, spark_session): + fp_config = tdh.create_fp_config(inputs=[MockDataSource()], output=tdh.OUTPUT_FEATURE_GROUP_ARN) + + def udf(input_df) -> DataFrame: + return input_df + + with patch.object(MockDataSource, "read_data", return_value=Mock(DataFrame)) as 
mock_read: + spark_arg_provider.provide_input_args(udf, fp_config) + mock_read.assert_called_with(spark=spark_session, params={"key": "value"}) + params_loader.get_parameter_args.assert_called_with(fp_config) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_output_receiver.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_output_receiver.py new file mode 100644 index 0000000000..34770a011a --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_output_receiver.py @@ -0,0 +1,106 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import test_data_helpers as tdh +from feature_store_pyspark.FeatureStoreManager import FeatureStoreManager +from mock import Mock +from py4j.protocol import Py4JJavaError +from pyspark.sql import DataFrame + +from sagemaker.mlops.feature_store.feature_processor import IngestionError +from sagemaker.mlops.feature_store.feature_processor._spark_factory import ( + FeatureStoreManagerFactory, +) +from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import ( + SparkOutputReceiver, +) + + +@pytest.fixture +def df() -> Mock: + return Mock(DataFrame) + + +@pytest.fixture +def feature_store_manager(): + return Mock(FeatureStoreManager) + + +@pytest.fixture +def feature_store_manager_factory(feature_store_manager): + return Mock(FeatureStoreManagerFactory, feature_store_manager=feature_store_manager) + + +@pytest.fixture +def spark_output_receiver(feature_store_manager_factory): + return SparkOutputReceiver(feature_store_manager_factory) + + +def test_ingest_udf_output_enable_ingestion_false(df, feature_store_manager, spark_output_receiver): + fp_config = tdh.create_fp_config(enable_ingestion=False) + spark_output_receiver.ingest_udf_output(df, fp_config) + + feature_store_manager.ingest_data.assert_not_called() + + +def test_ingest_udf_output(df, feature_store_manager, spark_output_receiver): + fp_config = tdh.create_fp_config() + spark_output_receiver.ingest_udf_output(df, fp_config) + + feature_store_manager.ingest_data.assert_called_with( + input_data_frame=df, + feature_group_arn=fp_config.output, + target_stores=fp_config.target_stores, + ) + + +def test_ingest_udf_output_failed_records(df, feature_store_manager, spark_output_receiver): + fp_config = tdh.create_fp_config() + + # Simulate streaming ingestion failure. 
+ mock_failed_records_df = Mock() + mock_java_exception = Mock(_target_id="") + mock_java_exception.getClass = Mock( + return_value=Mock(getSimpleName=Mock(return_value="StreamIngestionFailureException")) + ) + + feature_store_manager.ingest_data.side_effect = Py4JJavaError( + msg="", java_exception=mock_java_exception + ) + feature_store_manager.get_failed_stream_ingestion_data_frame.return_value = ( + mock_failed_records_df + ) + + with pytest.raises(IngestionError): + spark_output_receiver.ingest_udf_output(df, fp_config) + + mock_failed_records_df.show.assert_called_with(n=20, truncate=False) + + +def test_ingest_udf_output_all_py4j_error_raised(df, feature_store_manager, spark_output_receiver): + fp_config = tdh.create_fp_config() + + # Simulate ingestion failure. + mock_java_exception = Mock(_target_id="") + mock_java_exception.getClass = Mock( + return_value=Mock(getSimpleName=Mock(return_value="ValidationError")) + ) + feature_store_manager.ingest_data.side_effect = Py4JJavaError( + msg="", java_exception=mock_java_exception + ) + + with pytest.raises(Py4JJavaError): + spark_output_receiver.ingest_udf_output(df, fp_config) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_wrapper.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_wrapper.py new file mode 100644 index 0000000000..51747d78f9 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_udf_wrapper.py @@ -0,0 +1,85 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
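+"""Unit tests for UDFWrapper: verify that a wrapped UDF is invoked with the auto-loaded input, params, and keyword arguments, and that its return value is handed to the UDF output receiver for ingestion."""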
+from __future__ import absolute_import + +from typing import Callable + +import pytest +from mock import Mock + +from sagemaker.mlops.feature_store.feature_processor._feature_processor_config import ( + FeatureProcessorConfig, +) +from sagemaker.mlops.feature_store.feature_processor._udf_arg_provider import UDFArgProvider +from sagemaker.mlops.feature_store.feature_processor._udf_output_receiver import ( + UDFOutputReceiver, +) +from sagemaker.mlops.feature_store.feature_processor._udf_wrapper import UDFWrapper + + +@pytest.fixture +def udf_arg_provider(): + udf_arg_provider = Mock(UDFArgProvider) + udf_arg_provider.provide_input_args.return_value = {"input": Mock()} + udf_arg_provider.provide_params_arg.return_value = {"params": Mock()} + udf_arg_provider.provide_additional_kwargs.return_value = {"kwarg": Mock()} + + return udf_arg_provider + + +@pytest.fixture +def udf_output_receiver(): + udf_output_receiver = Mock(UDFOutputReceiver) + udf_output_receiver.ingest_udf_output.return_value = Mock() + return udf_output_receiver + + +@pytest.fixture +def udf_output(): + udf_output = Mock(Callable) + return udf_output + + +@pytest.fixture +def udf(udf_output): + udf = Mock(Callable) + udf.return_value = udf_output + return udf + + +@pytest.fixture +def fp_config(): + fp_config = Mock(FeatureProcessorConfig) + return fp_config + + +def test_wrap(fp_config, udf_output, udf_arg_provider, udf_output_receiver): + def test_udf(input, params, kwarg): + # Verify wrapped function is called with auto-loaded arguments. + assert input is udf_arg_provider.provide_input_args.return_value["input"] + assert params is udf_arg_provider.provide_params_arg.return_value["params"] + assert kwarg is udf_arg_provider.provide_additional_kwargs.return_value["kwarg"] + return udf_output + + udf_wrapper = UDFWrapper(udf_arg_provider, udf_output_receiver) + + # Execute decorator function and the decorated function. + wrapped_udf = udf_wrapper.wrap(test_udf, fp_config) + wrapped_udf() + + # Verify interactions with dependencies. + udf_arg_provider.provide_input_args.assert_called_with(test_udf, fp_config) + udf_arg_provider.provide_params_arg.assert_called_with(test_udf, fp_config) + udf_arg_provider.provide_additional_kwargs.assert_called_with(test_udf) + udf_output_receiver.ingest_udf_output.assert_called_with(udf_output, fp_config) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_validation.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_validation.py new file mode 100644 index 0000000000..16a8784586 --- /dev/null +++ b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/feature_processor/test_validation.py @@ -0,0 +1,192 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
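+"""Unit tests for feature_processor validation: ValidatorChain delegation and failure propagation, SparkUDFSignatureValidator parameter-count and parameter-position checks, and BaseDataSourceValidator data_source_name/data_source_unique_id pattern checks."""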
+from __future__ import absolute_import + +from typing import Callable +from pyspark.sql import DataFrame + +import pytest + +import test_data_helpers as tdh +from mock import Mock + +from sagemaker.mlops.feature_store.feature_processor._validation import ( + SparkUDFSignatureValidator, + Validator, + ValidatorChain, + BaseDataSourceValidator, +) +from sagemaker.mlops.feature_store.feature_processor._data_source import ( + BaseDataSource, +) + + +def test_validator_chain(): + fp_config = tdh.create_fp_config() + udf = Mock(Callable) + + first_validator = Mock(Validator) + second_validator = Mock(Validator) + validator_chain = ValidatorChain([first_validator, second_validator]) + + validator_chain.validate(udf, fp_config) + + first_validator.validate.assert_called_with(udf, fp_config) + second_validator.validate.assert_called_with(udf, fp_config) + + +def test_validator_chain_validation_fails(): + fp_config = tdh.create_fp_config() + udf = Mock(Callable) + + first_validator = Mock(validate=Mock(side_effect=ValueError())) + second_validator = Mock(validate=Mock()) + validator_chain = ValidatorChain([first_validator, second_validator]) + + with pytest.raises(ValueError): + validator_chain.validate(udf, fp_config) + + +def test_spark_udf_signature_validator_valid(): + # One Input + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE]) + + def one_data_source(fg_data_source, params, spark): + return None + + SparkUDFSignatureValidator().validate(one_data_source, fp_config) + + # Two Inputs + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def two_data_sources(fg_data_source, s3_data_source, params, spark): + return None + + SparkUDFSignatureValidator().validate(two_data_sources, fp_config) + + # No Optional Args (params and spark) + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def no_optional_args(fg_data_source, s3_data_source): + return None + + SparkUDFSignatureValidator().validate(no_optional_args, fp_config) + + # Optional Args (no params) + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def no_optional_params_arg(fg_data_source, s3_data_source, spark): + return None + + SparkUDFSignatureValidator().validate(no_optional_params_arg, fp_config) + + # No Optional Args (no spark) + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def no_optional_spark_arg(fg_data_source, s3_data_source, params): + return None + + SparkUDFSignatureValidator().validate(no_optional_spark_arg, fp_config) + + +def test_spark_udf_signature_validator_udf_input_mismatch(): + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + + def one_input(one, params, spark): + return None + + def three_inputs(one, two, three, params, spark): + return None + + exception_string = ( + r"feature_processor expected a function with \(2\) parameter\(s\) before any" + r" optional 'params' or 'spark' parameters for the \(2\) requested data source\(s\)\." 
+ ) + + with pytest.raises(ValueError, match=exception_string): + SparkUDFSignatureValidator().validate(one_input, fp_config) + + with pytest.raises(ValueError, match=exception_string): + SparkUDFSignatureValidator().validate(three_inputs, fp_config) + + +def test_spark_udf_signature_validator_zero_input_params(): + def zero_inputs(params, spark): + return None + + with pytest.raises(ValueError, match="feature_processor expects at least 1 input parameter."): + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + SparkUDFSignatureValidator().validate(zero_inputs, fp_config) + + +def test_spark_udf_signature_validator_udf_invalid_non_input_position(): + fp_config = tdh.create_fp_config(inputs=[tdh.FEATURE_GROUP_DATA_SOURCE, tdh.S3_DATA_SOURCE]) + with pytest.raises( + ValueError, + match="feature_processor expected the 'params' parameter to be the last or second last" + " parameter after input parameters.", + ): + + def invalid_params_position(params, fg_data_source, s3_data_source): + return None + + SparkUDFSignatureValidator().validate(invalid_params_position, fp_config) + + with pytest.raises( + ValueError, + match="feature_processor expected the 'spark' parameter to be the last or second last" + " parameter after input parameters.", + ): + + def invalid_spark_position(spark, fg_data_source, s3_data_source): + return None + + SparkUDFSignatureValidator().validate(invalid_spark_position, fp_config) + + +@pytest.mark.parametrize( + "data_source_name, data_source_unique_id, error_pattern", + [ + ("$_invalid_source", "unique_id", "data_source_name of input does not match pattern '.*'."), + ("", "unique_id", "data_source_name of input does not match pattern '.*'."), + ( + "source", + tdh.DATA_SOURCE_UNIQUE_ID_TOO_LONG, + "data_source_unique_id of input does not match pattern '.*'.", + ), + ("source", "", "data_source_unique_id of input does not match pattern '.*'."), + ], +) +def test_spark_udf_signature_validator_udf_invalid_base_data_source( + data_source_name, data_source_unique_id, error_pattern +): + class TestInValidCustomDataSource(BaseDataSource): + + data_source_name = None + data_source_unique_id = None + + def read_data(self, spark, params) -> DataFrame: + return None + + test_data_source = TestInValidCustomDataSource() + test_data_source.data_source_name = data_source_name + test_data_source.data_source_unique_id = data_source_unique_id + + fp_config = tdh.create_fp_config(inputs=[test_data_source]) + + def udf(input_data_source, params, spark): + return None + + with pytest.raises(ValueError, match=error_pattern): + BaseDataSourceValidator().validate(udf, fp_config) From 611a1594e8bad60834e0c6d45c471880aefff30e Mon Sep 17 00:00:00 2001 From: BassemHalim Date: Wed, 11 Feb 2026 16:12:54 -0800 Subject: [PATCH 8/8] revert: Remove Lake Formation commits (tracked in feature-store-lakeformation --- docs/api/sagemaker_mlops.rst | 8 - .../sagemaker/mlops/feature_store/__init__.py | 6 +- .../mlops/feature_store/feature_group.py | 711 ------ .../integ/test_featureStore_lakeformation.py | 660 ----- .../mlops/feature_store/test_lakeformation.py | 2141 ----------------- .../v3-feature-store-lake-formation.ipynb | 675 ------ 6 files changed, 1 insertion(+), 4200 deletions(-) delete mode 100644 sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group.py delete mode 100644 sagemaker-mlops/tests/integ/test_featureStore_lakeformation.py delete mode 100644 sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_lakeformation.py delete mode 
100644 v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb diff --git a/docs/api/sagemaker_mlops.rst b/docs/api/sagemaker_mlops.rst index d9f911068e..f67879111d 100644 --- a/docs/api/sagemaker_mlops.rst +++ b/docs/api/sagemaker_mlops.rst @@ -21,14 +21,6 @@ Workflow Management :undoc-members: :show-inheritance: -Feature Store -------------- - -.. automodule:: sagemaker.mlops.feature_store - :members: - :undoc-members: - :show-inheritance: - Local Development ----------------- diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py index 1b635df2a7..f15d6d3845 100644 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py +++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py @@ -2,11 +2,8 @@ # Licensed under the Apache License, Version 2.0 """SageMaker FeatureStore V3 - powered by sagemaker-core.""" -# FeatureGroup with Lake Formation support (local subclass) -from sagemaker.mlops.feature_store.feature_group import FeatureGroup, LakeFormationConfig - # Resources from core -from sagemaker.core.resources import FeatureMetadata +from sagemaker.core.resources import FeatureGroup, FeatureMetadata # Shapes from core (Pydantic - no to_dict() needed) from sagemaker.core.shapes import ( @@ -82,7 +79,6 @@ "FeatureParameter", "FeatureValue", "Filter", - "LakeFormationConfig", "OfflineStoreConfig", "OnlineStoreConfig", "OnlineStoreSecurityConfig", diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group.py deleted file mode 100644 index c5fcb9211a..0000000000 --- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group.py +++ /dev/null @@ -1,711 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# Licensed under the Apache License, Version 2.0 -"""FeatureGroup with Lake Formation support.""" - -import logging -from typing import List, Optional - -import botocore.exceptions - -from sagemaker.core.resources import FeatureGroup as CoreFeatureGroup -from sagemaker.core.resources import Base -from sagemaker.core.shapes import ( - FeatureDefinition, - OfflineStoreConfig, - OnlineStoreConfig, - Tag, - ThroughputConfig, -) -from sagemaker.core.shapes import Unassigned -from sagemaker.core.helper.pipeline_variable import StrPipeVar -from sagemaker.core.s3.utils import parse_s3_url -from sagemaker.core.common_utils import aws_partition -from boto3 import Session - - -logger = logging.getLogger(__name__) - - -class LakeFormationConfig: - """Configuration for Lake Formation governance on Feature Group offline stores. - - Attributes: - enabled: If True, enables Lake Formation governance for the offline store. - Requires offline_store_config and role_arn to be set on the Feature Group. - use_service_linked_role: Whether to use the Lake Formation service-linked role - for S3 registration. If True, Lake Formation uses its service-linked role. - If False, registration_role_arn must be provided. Default is True. - registration_role_arn: IAM role ARN to use for S3 registration with Lake Formation. - Required when use_service_linked_role is False. This can be different from the - Feature Group's execution role. - show_s3_policy: If True, prints the S3 deny policy to the console after successful - Lake Formation setup. This policy should be added to your S3 bucket to restrict - access to only the allowed principals. Default is False. 
- """ - - enabled: bool = False - use_service_linked_role: bool = True - registration_role_arn: Optional[str] = None - show_s3_policy: bool = False - - -class FeatureGroup(CoreFeatureGroup): - - # Inherit parent docstring and append our additions - if CoreFeatureGroup.__doc__ and __doc__: - __doc__ = CoreFeatureGroup.__doc__ - - @staticmethod - def _s3_uri_to_arn(s3_uri: str, region: Optional[str] = None) -> str: - """ - Convert S3 URI to S3 ARN format for Lake Formation. - - Args: - s3_uri: S3 URI in format s3://bucket/path or already an ARN - region: AWS region name (e.g., 'us-west-2'). Used to determine the correct - partition for the ARN. If not provided, defaults to 'aws' partition. - - Returns: - S3 ARN in format arn:{partition}:s3:::bucket/path - - Note: - This format is specifically used for Lake Formation resource registration. - The triple colon (:::) after 's3' is correct - S3 ARNs don't include - region or account ID fields. - """ - if s3_uri.startswith("arn:"): - return s3_uri - - # Determine partition based on region - partition = aws_partition(region) if region else "aws" - - bucket, key = parse_s3_url(s3_uri) - # Reconstruct as ARN - key may be empty string - s3_path = f"{bucket}/{key}" if key else bucket - return f"arn:{partition}:s3:::{s3_path}" - - @staticmethod - def _extract_account_id_from_arn(arn: str) -> str: - """ - Extract AWS account ID from an ARN. - - Args: - arn: AWS ARN in format arn:aws:service:region:account:resource - - Returns: - AWS account ID (the 5th colon-separated field) - - Raises: - ValueError: If ARN format is invalid (fewer than 5 colon-separated parts) - """ - parts = arn.split(":") - if len(parts) < 5: - raise ValueError(f"Invalid ARN format: {arn}") - return parts[4] - - @staticmethod - def _get_lake_formation_service_linked_role_arn( - account_id: str, region: Optional[str] = None - ) -> str: - """ - Generate the Lake Formation service-linked role ARN for an account. - - Args: - account_id: AWS account ID - region: AWS region name (e.g., 'us-west-2'). Used to determine the correct - partition for the ARN. If not provided, defaults to 'aws' partition. - - Returns: - Lake Formation service-linked role ARN in format: - arn:{partition}:iam::{account}:role/aws-service-role/lakeformation.amazonaws.com/ - AWSServiceRoleForLakeFormationDataAccess - """ - partition = aws_partition(region) if region else "aws" - return ( - f"arn:{partition}:iam::{account_id}:role/aws-service-role/" - f"lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" - ) - - def _generate_s3_deny_policy( - self, - bucket_name: str, - s3_prefix: str, - lake_formation_role_arn: str, - feature_store_role_arn: str, - ) -> dict: - """ - Generate an S3 deny policy for Lake Formation governance. - - This policy denies S3 access to the offline store data prefix except for - the Lake Formation role and Feature Store execution role. - - Args: - bucket_name: S3 bucket name. - s3_prefix: S3 prefix path (without bucket name). - lake_formation_role_arn: Lake Formation registration role ARN. - feature_store_role_arn: Feature Store execution role ARN. - - Returns: - S3 bucket policy as a dict with valid JSON structure containing: - - Version: "2012-10-17" - - Statement: List with two deny statements: - 1. Deny GetObject, PutObject, DeleteObject on data prefix except allowed principals - 2. 
Deny ListBucket on bucket with prefix condition except allowed principals - """ - policy = { - "Version": "2012-10-17", - "Statement": [ - { - "Sid": "DenyAllAccessToFeatureStorePrefixExceptAllowedPrincipals", - "Effect": "Deny", - "Principal": "*", - "Action": ["s3:GetObject", "s3:PutObject", "s3:DeleteObject"], - "Resource": f"arn:aws:s3:::{bucket_name}/{s3_prefix}/*", - "Condition": { - "StringNotEquals": { - "aws:PrincipalArn": [ - lake_formation_role_arn, - feature_store_role_arn, - ] - } - }, - }, - { - "Sid": "DenyListOnPrefixExceptAllowedPrincipals", - "Effect": "Deny", - "Principal": "*", - "Action": "s3:ListBucket", - "Resource": f"arn:aws:s3:::{bucket_name}", - "Condition": { - "StringLike": {"s3:prefix": f"{s3_prefix}/*"}, - "StringNotEquals": { - "aws:PrincipalArn": [ - lake_formation_role_arn, - feature_store_role_arn, - ] - }, - }, - }, - ], - } - return policy - - def _get_lake_formation_client( - self, - session: Optional[Session] = None, - region: Optional[str] = None, - ): - """ - Get a Lake Formation client. - - Args: - session: Boto3 session. If not provided, a new session will be created. - region: AWS region name. - - Returns: - A boto3 Lake Formation client. - """ - # TODO: don't create w new client for each call - boto_session = session or Session() - return boto_session.client("lakeformation", region_name=region) - - def _register_s3_with_lake_formation( - self, - s3_location: str, - session: Optional[Session] = None, - region: Optional[str] = None, - use_service_linked_role: bool = True, - role_arn: Optional[str] = None, - ) -> bool: - """ - Register an S3 location with Lake Formation. - - Args: - s3_location: S3 URI or ARN to register. - session: Boto3 session. - region: AWS region. If not provided, will be inferred from the session. - use_service_linked_role: Whether to use the Lake Formation service-linked role. - If True, Lake Formation uses its service-linked role for registration. - If False, role_arn must be provided. - role_arn: IAM role ARN to use for registration. Required when - use_service_linked_role is False. - - Returns: - True if registration succeeded or location already registered. - - Raises: - ValueError: If use_service_linked_role is False but role_arn is not provided. - ClientError: If registration fails for unexpected reasons. - """ - if not use_service_linked_role and not role_arn: - raise ValueError("role_arn must be provided when use_service_linked_role is False") - - # Get region from session if not provided - if region is None and session is not None: - region = session.region_name() - - client = self._get_lake_formation_client(session, region) - resource_arn = self._s3_uri_to_arn(s3_location, region) - - try: - register_params = {"ResourceArn": resource_arn} - - if use_service_linked_role: - register_params["UseServiceLinkedRole"] = True - else: - register_params["RoleArn"] = role_arn - - client.register_resource(**register_params) - logger.info(f"Successfully registered S3 location: {resource_arn}") - return True - except botocore.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "AlreadyExistsException": - logger.info(f"S3 location already registered: {resource_arn}") - return True - raise - - def _revoke_iam_allowed_principal( - self, - database_name: str, - table_name: str, - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> bool: - """ - Revoke IAMAllowedPrincipal permissions from a Glue table. - - Args: - database_name: Glue database name. - table_name: Glue table name. 
- session: Boto3 session. - region: AWS region. If not provided, will be inferred from the session. - - Returns: - True if revocation succeeded or permissions didn't exist. - - Raises: - ClientError: If revocation fails for unexpected reasons. - """ - # Get region from session if not provided - if region is None and session is not None: - region = session.region_name() - - client = self._get_lake_formation_client(session, region) - - try: - client.revoke_permissions( - Principal={"DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS"}, - Resource={ - "Table": { - "DatabaseName": database_name, - "Name": table_name, - } - }, - Permissions=["ALL"], - ) - logger.info(f"Revoked IAMAllowedPrincipal from table: {database_name}.{table_name}") - return True - except botocore.exceptions.ClientError as e: - # if the Table doesn't have that permission because the user already revoked it - # then just return True - if e.response["Error"]["Code"] == "InvalidInputException": - logger.info( - f"IAMAllowedPrincipal permissions may not exist on: {database_name}.{table_name}" - ) - return True - raise - - def _grant_lake_formation_permissions( - self, - role_arn: str, - database_name: str, - table_name: str, - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> bool: - """ - Grant permissions to a role on a Glue table via Lake Formation. - - Args: - role_arn: IAM role ARN to grant permissions to. - database_name: Glue database name. - table_name: Glue table name. - session: Boto3 session. - region: AWS region. If not provided, will be inferred from the session. - - Returns: - True if grant succeeded or permissions already exist. - - Raises: - ClientError: If grant fails for unexpected reasons. - """ - # Get region from session if not provided - if region is None and session is not None: - region = session.region_name() - - client = self._get_lake_formation_client(session, region) - permissions = ["SELECT", "INSERT", "DELETE", "DESCRIBE", "ALTER"] - - try: - client.grant_permissions( - Principal={"DataLakePrincipalIdentifier": role_arn}, - Resource={ - "Table": { - "DatabaseName": database_name, - "Name": table_name, - } - }, - Permissions=permissions, - PermissionsWithGrantOption=[], - ) - logger.info(f"Granted permissions to {role_arn} on table: {database_name}.{table_name}") - return True - except botocore.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "InvalidInputException": - logger.info( - f"Permissions may already exist for {role_arn} on: {database_name}.{table_name}" - ) - return True - raise - - @Base.add_validate_call - def enable_lake_formation( - self, - session: Optional[Session] = None, - region: Optional[str] = None, - use_service_linked_role: bool = True, - registration_role_arn: Optional[str] = None, - wait_for_active: bool = False, - show_s3_policy: bool = False, - ) -> dict: - """ - Enable Lake Formation governance for this Feature Group's offline store. - - This method: - 1. Optionally waits for Feature Group to reach 'Created' status - 2. Validates Feature Group status is 'Created' - 3. Registers the offline store S3 location as data lake location - 4. Grants the execution role permissions on the Glue table - 5. Revokes IAMAllowedPrincipal permissions from the Glue table - - The role ARN is automatically extracted from the Feature Group's configuration. - Each phase depends on the success of the previous phase - if any phase fails, - subsequent phases are not executed. - - Parameters: - session: Boto3 session. - region: Region name. 
- use_service_linked_role: Whether to use the Lake Formation service-linked role - for S3 registration. If True, Lake Formation uses its service-linked role. - If False, registration_role_arn must be provided. Default is True. - registration_role_arn: IAM role ARN to use for S3 registration with Lake Formation. - Required when use_service_linked_role is False. This can be different from the - Feature Group's execution role (role_arn) - wait_for_active: If True, waits for the Feature Group to reach 'Created' status - before enabling Lake Formation. Default is False. - show_s3_policy: If True, prints the S3 deny policy to the console after successful - Lake Formation setup. This policy should be added to your S3 bucket to restrict - access to only the allowed principals. Default is False. - - Returns: - Dict with status of each Lake Formation operation: - - s3_registration: bool - - iam_principal_revoked: bool - - permissions_granted: bool - - Raises: - ValueError: If the Feature Group has no offline store configured, - if role_arn is not set on the Feature Group, if use_service_linked_role - is False but registration_role_arn is not provided, or if the Feature Group - is not in 'Created' status. - ClientError: If Lake Formation operations fail. - RuntimeError: If a phase fails and subsequent phases cannot proceed. - """ - # Get region from session if not provided - if region is None and session is not None: - region = session.region_name() - - # Wait for Created status if requested - if wait_for_active: - self.wait_for_status(target_status="Created") - - # Refresh to get latest state - self.refresh() - - # Validate Feature Group status - if self.feature_group_status not in ["Created"]: - raise ValueError( - f"Feature Group '{self.feature_group_name}' must be in 'Created' status " - f"to enable Lake Formation. Current status: '{self.feature_group_status}'. " - f"Use wait_for_active=True to automatically wait for the Feature Group to be ready." - ) - - # Validate offline store exists - if self.offline_store_config is None or self.offline_store_config == Unassigned(): - raise ValueError( - f"Feature Group '{self.feature_group_name}' does not have an offline store configured. " - "Lake Formation can only be enabled for Feature Groups with offline stores." - ) - - # Get role ARN from Feature Group config - if self.role_arn is None or self.role_arn == Unassigned(): - raise ValueError( - f"Feature Group '{self.feature_group_name}' does not have a role_arn configured. " - "Lake Formation requires a role ARN to grant permissions." - ) - if not use_service_linked_role and registration_role_arn is None: - raise ValueError( - "Either 'use_service_linked_role' must be True or 'registration_role_arn' must be provided." - ) - - # Extract required configuration - s3_config = self.offline_store_config.s3_storage_config - if s3_config is None: - raise ValueError("Offline store S3 configuration is missing") - - resolved_s3_uri = s3_config.resolved_output_s3_uri - if resolved_s3_uri is None or resolved_s3_uri == Unassigned(): - raise ValueError( - "Resolved S3 URI not available. Ensure the Feature Group is in 'Created' status." 
- ) - - data_catalog_config = self.offline_store_config.data_catalog_config - if data_catalog_config is None: - raise ValueError("Data catalog configuration is missing from offline store config") - - database_name = data_catalog_config.database - table_name = data_catalog_config.table_name - - if not database_name or not table_name: - raise ValueError("Database name and table name are required from data catalog config") - - # Convert to str to handle PipelineVariable types - resolved_s3_uri_str = str(resolved_s3_uri) - database_name_str = str(database_name) - table_name_str = str(table_name) - role_arn_str = str(self.role_arn) - - # Execute Lake Formation setup with fail-fast behavior - results = { - "s3_registration": False, - "iam_principal_revoked": False, - "permissions_granted": False, - } - - # Phase 1: Register S3 with Lake Formation - try: - results["s3_registration"] = self._register_s3_with_lake_formation( - resolved_s3_uri_str, - session, - region, - use_service_linked_role=use_service_linked_role, - role_arn=registration_role_arn, - ) - except Exception as e: - raise RuntimeError( - f"Failed to register S3 location with Lake Formation. " - f"Subsequent phases skipped. Results: {results}. Error: {e}" - ) from e - - if not results["s3_registration"]: - raise RuntimeError( - f"Failed to register S3 location with Lake Formation. " - f"Subsequent phases skipped. Results: {results}" - ) - - # Phase 2: Grant Lake Formation permissions to the role - try: - results["permissions_granted"] = self._grant_lake_formation_permissions( - role_arn_str, database_name_str, table_name_str, session, region - ) - except Exception as e: - raise RuntimeError( - f"Failed to grant Lake Formation permissions. " - f"Subsequent phases skipped. Results: {results}. Error: {e}" - ) from e - - if not results["permissions_granted"]: - raise RuntimeError( - f"Failed to grant Lake Formation permissions. " - f"Subsequent phases skipped. Results: {results}" - ) - - # Phase 3: Revoke IAMAllowedPrincipal permissions - try: - results["iam_principal_revoked"] = self._revoke_iam_allowed_principal( - database_name_str, table_name_str, session, region - ) - except Exception as e: - raise RuntimeError( - f"Failed to revoke IAMAllowedPrincipal permissions. Results: {results}. Error: {e}" - ) from e - - if not results["iam_principal_revoked"]: - raise RuntimeError( - f"Failed to revoke IAMAllowedPrincipal permissions. 
Results: {results}" - ) - - logger.info(f"Lake Formation setup complete for {self.feature_group_name}: {results}") - - # Generate and optionally print S3 deny policy - if show_s3_policy: - # Extract bucket name and prefix from resolved S3 URI using core utility - bucket_name, s3_prefix = parse_s3_url(resolved_s3_uri_str) - - # Extract account ID from Feature Group ARN - feature_group_arn_str = str(self.feature_group_arn) if self.feature_group_arn else "" - account_id = self._extract_account_id_from_arn(feature_group_arn_str) - - # Determine Lake Formation role ARN based on use_service_linked_role flag - if use_service_linked_role: - lf_role_arn = self._get_lake_formation_service_linked_role_arn(account_id, region) - else: - # registration_role_arn is validated earlier when use_service_linked_role is False - lf_role_arn = str(registration_role_arn) - - # Generate the S3 deny policy - policy = self._generate_s3_deny_policy( - bucket_name=bucket_name, - s3_prefix=s3_prefix, - lake_formation_role_arn=lf_role_arn, - feature_store_role_arn=role_arn_str, - ) - - # Print policy with clear instructions - import json - - print("\n" + "=" * 80) - print("S3 Bucket Policy Update recommended") - print("=" * 80) - print( - "\nTo complete Lake Formation setup, add the following deny policy to your S3 bucket." - ) - print( - "This policy restricts access to the offline store data to only the allowed principals." - ) - print("\nBucket:", bucket_name) - print("\nPolicy to add:") - print("-" * 40) - print(json.dumps(policy, indent=2)) - print("-" * 40) - print("\nNote: Merge this with your existing bucket policy if one exists.") - print("=" * 80 + "\n") - - return results - - @classmethod - @Base.add_validate_call - def create( - cls, - feature_group_name: StrPipeVar, - record_identifier_feature_name: StrPipeVar, - event_time_feature_name: StrPipeVar, - feature_definitions: List[FeatureDefinition], - online_store_config: Optional[OnlineStoreConfig] = None, - offline_store_config: Optional[OfflineStoreConfig] = None, - throughput_config: Optional[ThroughputConfig] = None, - role_arn: Optional[StrPipeVar] = None, - description: Optional[StrPipeVar] = None, - tags: Optional[List[Tag]] = None, - use_pre_prod_offline_store_replicator_lambda: Optional[bool] = None, - lake_formation_config: Optional[LakeFormationConfig] = None, - session: Optional[Session] = None, - region: Optional[StrPipeVar] = None, - ) -> Optional["FeatureGroup"]: - """ - Create a FeatureGroup resource with optional Lake Formation governance. - - Parameters: - feature_group_name: The name of the FeatureGroup. - record_identifier_feature_name: The name of the Feature whose value uniquely - identifies a Record. - event_time_feature_name: The name of the feature that stores the EventTime. - feature_definitions: A list of Feature names and types. - online_store_config: Configuration for the OnlineStore. - offline_store_config: Configuration for the OfflineStore. - throughput_config: Throughput configuration. - role_arn: IAM execution role ARN for the OfflineStore. - description: A free-form description of the FeatureGroup. - tags: Tags used to identify Features in each FeatureGroup. - use_pre_prod_offline_store_replicator_lambda: Pre-prod replicator flag. - lake_formation_config: Optional LakeFormationConfig to configure Lake Formation - governance. When enabled=True, requires offline_store_config and role_arn. - session: Boto3 session. - region: Region name. - - Returns: - The FeatureGroup resource. 
- - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceInUse: Resource being accessed is in use. - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. - For example, you might have too many training jobs created. - ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 - """ - # Validation for Lake Formation - if lake_formation_config is not None and lake_formation_config.enabled: - if offline_store_config is None: - raise ValueError( - "lake_formation_config with enabled=True requires offline_store_config to be configured" - ) - if role_arn is None: - raise ValueError( - "lake_formation_config with enabled=True requires role_arn to be specified" - ) - if ( - not lake_formation_config.use_service_linked_role - and not lake_formation_config.registration_role_arn - ): - raise ValueError( - "registration_role_arn must be provided in lake_formation_config " - "when use_service_linked_role is False" - ) - - # Build kwargs, only including non-None values so parent uses its defaults - create_kwargs = { - "feature_group_name": feature_group_name, - "record_identifier_feature_name": record_identifier_feature_name, - "event_time_feature_name": event_time_feature_name, - "feature_definitions": feature_definitions, - "session": session, - "region": region, - } - if online_store_config is not None: - create_kwargs["online_store_config"] = online_store_config - if offline_store_config is not None: - create_kwargs["offline_store_config"] = offline_store_config - if throughput_config is not None: - create_kwargs["throughput_config"] = throughput_config - if role_arn is not None: - create_kwargs["role_arn"] = role_arn - if description is not None: - create_kwargs["description"] = description - if tags is not None: - create_kwargs["tags"] = tags - if use_pre_prod_offline_store_replicator_lambda is not None: - create_kwargs["use_pre_prod_offline_store_replicator_lambda"] = use_pre_prod_offline_store_replicator_lambda - - feature_group = super().create(**create_kwargs) - - # Enable Lake Formation if requested - if lake_formation_config is not None and lake_formation_config.enabled: - feature_group.wait_for_status(target_status="Created") - feature_group.enable_lake_formation( - session=session, - region=region, - use_service_linked_role=lake_formation_config.use_service_linked_role, - registration_role_arn=lake_formation_config.registration_role_arn, - show_s3_policy=lake_formation_config.show_s3_policy, - ) - return feature_group diff --git a/sagemaker-mlops/tests/integ/test_featureStore_lakeformation.py b/sagemaker-mlops/tests/integ/test_featureStore_lakeformation.py deleted file mode 100644 index dc6f12181a..0000000000 --- a/sagemaker-mlops/tests/integ/test_featureStore_lakeformation.py +++ /dev/null @@ -1,660 +0,0 @@ -""" -Integration tests for Lake Formation with FeatureGroup. 
- -These tests require: -- AWS credentials with Lake Formation and SageMaker permissions -- An S3 bucket for offline store (uses default SageMaker bucket) -- An IAM role for Feature Store (uses execution role) - -Run with: pytest tests/integ/test_featureStore_lakeformation.py -v -m integ -""" - -import uuid - -import boto3 -import pytest -from botocore.exceptions import ClientError - -from sagemaker.core.helper.session_helper import Session, get_execution_role -from sagemaker.mlops.feature_store import ( - FeatureGroup, - LakeFormationConfig, - OfflineStoreConfig, - OnlineStoreConfig, - S3StorageConfig, - StringFeatureDefinition, - FractionalFeatureDefinition, -) - -feature_definitions = [ - StringFeatureDefinition(feature_name="record_id"), - StringFeatureDefinition(feature_name="event_time"), - FractionalFeatureDefinition(feature_name="feature_value"), -] - - -@pytest.fixture(scope="module") -def sagemaker_session(): - return Session() - - -@pytest.fixture(scope="module") -def role(sagemaker_session): - return get_execution_role(sagemaker_session) - - -@pytest.fixture(scope="module") -def s3_uri(sagemaker_session): - bucket = sagemaker_session.default_bucket() - return f"s3://{bucket}/feature-store-test" - - -@pytest.fixture(scope="module") -def region(): - return "us-west-2" - - -@pytest.fixture(scope="module") -def shared_feature_group_for_negative_tests(s3_uri, role, region): - """ - Create a single FeatureGroup for negative tests that only need to verify - error conditions without modifying the resource. - - This fixture is module-scoped to be created once and shared across tests, - reducing test execution time. - """ - fg_name = f"test-lf-negative-{uuid.uuid4().hex[:8]}" - fg = None - - try: - fg = create_test_feature_group(fg_name, s3_uri, role, region) - fg.wait_for_status(target_status="Created", poll=30, timeout=300) - yield fg - finally: - if fg: - cleanup_feature_group(fg) - - -def generate_feature_group_name(): - """Generate a unique feature group name for testing.""" - return f"test-lf-fg-{uuid.uuid4().hex[:8]}" - - -def create_test_feature_group(name: str, s3_uri: str, role_arn: str, region: str) -> FeatureGroup: - """Create a FeatureGroup with offline store for testing.""" - - offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) - - fg = FeatureGroup.create( - feature_group_name=name, - record_identifier_feature_name="record_id", - event_time_feature_name="event_time", - feature_definitions=feature_definitions, - offline_store_config=offline_store_config, - role_arn=role_arn, - region=region, - ) - - return fg - - -def cleanup_feature_group(fg: FeatureGroup): - """ - Delete a FeatureGroup and its associated Glue table. - - Args: - fg: The FeatureGroup to delete. 
- """ - try: - # Delete the Glue table if it exists - if fg.offline_store_config is not None: - try: - fg.refresh() # Ensure we have latest config - data_catalog_config = fg.offline_store_config.data_catalog_config - if data_catalog_config is not None: - database_name = data_catalog_config.database - table_name = data_catalog_config.table_name - - if database_name and table_name: - glue_client = boto3.client("glue") - try: - glue_client.delete_table(DatabaseName=database_name, Name=table_name) - except ClientError as e: - # Ignore if table doesn't exist - if e.response["Error"]["Code"] != "EntityNotFoundException": - raise - except Exception: - # Don't fail cleanup if Glue table deletion fails - pass - - # Delete the FeatureGroup - fg.delete() - except ClientError: - # Don't fail cleanup if Glue table deletion fails - pass - - -@pytest.mark.serial -@pytest.mark.slow_test -def test_create_feature_group_and_enable_lake_formation(s3_uri, role, region): - """ - Test creating a FeatureGroup and enabling Lake Formation governance. - - This test: - 1. Creates a new FeatureGroup with offline store - 2. Waits for it to reach Created status - 3. Enables Lake Formation governance (registers S3, grants permissions, revokes IAM principals) - 4. Cleans up the FeatureGroup - """ - - fg_name = generate_feature_group_name() - fg = None - - try: - # Create the FeatureGroup - fg = create_test_feature_group(fg_name, s3_uri, role, region) - assert fg is not None - - # Wait for Created status - fg.wait_for_status(target_status="Created", poll=30, timeout=300) - assert fg.feature_group_status == "Created" - - # Enable Lake Formation governance - result = fg.enable_lake_formation() - - # Verify all phases completed successfully - assert result["s3_registration"] is True - assert result["permissions_granted"] is True - assert result["iam_principal_revoked"] is True - - finally: - print('done') - # Cleanup - if fg: - cleanup_feature_group(fg) - - -@pytest.mark.serial -@pytest.mark.slow_test -def test_create_feature_group_with_lake_formation_enabled(s3_uri, role, region): - """ - Test creating a FeatureGroup with lake_formation_config.enabled=True. - - This test verifies the integrated workflow where Lake Formation is enabled - automatically during FeatureGroup creation: - 1. Creates a new FeatureGroup with lake_formation_config.enabled=True - 2. Verifies the FeatureGroup is created and Lake Formation is configured - 3. 
Cleans up the FeatureGroup - """ - - fg_name = generate_feature_group_name() - fg = None - - try: - # Create the FeatureGroup with Lake Formation enabled - - offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) - lake_formation_config = LakeFormationConfig() - lake_formation_config.enabled = True - - fg = FeatureGroup.create( - feature_group_name=fg_name, - record_identifier_feature_name="record_id", - event_time_feature_name="event_time", - feature_definitions=feature_definitions, - offline_store_config=offline_store_config, - role_arn=role, - lake_formation_config=lake_formation_config, - ) - - # Verify the FeatureGroup was created - assert fg is not None - assert fg.feature_group_name == fg_name - assert fg.feature_group_status == "Created" - - # Verify Lake Formation is configured by checking we can refresh without errors - fg.refresh() - assert fg.offline_store_config is not None - - finally: - # Cleanup - if fg: - cleanup_feature_group(fg) - - -@pytest.mark.serial -def test_create_feature_group_without_lake_formation(s3_uri, role, region): - """ - Test creating a FeatureGroup without Lake Formation enabled. - - This test verifies that when lake_formation_config is not provided or enabled=False, - the FeatureGroup is created successfully without any Lake Formation operations: - 1. Creates a new FeatureGroup without lake_formation_config - 2. Verifies the FeatureGroup is created successfully - 3. Verifies no Lake Formation operations were performed - 4. Cleans up the FeatureGroup - """ - fg_name = generate_feature_group_name() - fg = None - - try: - # Create the FeatureGroup without Lake Formation - offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) - - # Create without lake_formation_config (default behavior) - fg = FeatureGroup.create( - feature_group_name=fg_name, - record_identifier_feature_name="record_id", - event_time_feature_name="event_time", - feature_definitions=feature_definitions, - offline_store_config=offline_store_config, - role_arn=role, - ) - - # Verify the FeatureGroup was created - assert fg is not None - assert fg.feature_group_name == fg_name - - # Wait for Created status to ensure it's fully provisioned - fg.wait_for_status(target_status="Created", poll=30, timeout=300) - assert fg.feature_group_status == "Created" - - # Verify offline store is configured - fg.refresh() - assert fg.offline_store_config is not None - assert fg.offline_store_config.s3_storage_config is not None - - finally: - # Cleanup - if fg: - cleanup_feature_group(fg) - - -# ============================================================================ -# Negative Integration Tests -# ============================================================================ - - -def test_create_feature_group_with_lake_formation_fails_without_offline_store(role, region): - """ - Test that creating a FeatureGroup with enable_lake_formation=True fails - when no offline store is configured. - - Expected behavior: ValueError should be raised indicating offline store is required. 
- """ - fg_name = generate_feature_group_name() - - lake_formation_config = LakeFormationConfig() - lake_formation_config.enabled = True - - # Attempt to create without offline store but with Lake Formation enabled - with pytest.raises(ValueError) as exc_info: - FeatureGroup.create( - feature_group_name=fg_name, - record_identifier_feature_name="record_id", - event_time_feature_name="event_time", - feature_definitions=feature_definitions, - role_arn=role, - lake_formation_config=lake_formation_config, - ) - - # Verify error message mentions offline_store_config requirement - assert "lake_formation_config with enabled=True requires offline_store_config to be configured" in str( - exc_info.value - ) - - -def test_create_feature_group_with_lake_formation_fails_without_role(s3_uri, region): - """ - Test that creating a FeatureGroup with lake_formation_config.enabled=True fails - when no role_arn is provided. - - Expected behavior: ValueError should be raised indicating role_arn is required. - """ - fg_name = generate_feature_group_name() - - offline_store_config = OfflineStoreConfig(s3_storage_config=S3StorageConfig(s3_uri=s3_uri)) - lake_formation_config = LakeFormationConfig() - lake_formation_config.enabled = True - - # Attempt to create without role_arn but with Lake Formation enabled - with pytest.raises(ValueError) as exc_info: - FeatureGroup.create( - feature_group_name=fg_name, - record_identifier_feature_name="record_id", - event_time_feature_name="event_time", - feature_definitions=feature_definitions, - offline_store_config=offline_store_config, - lake_formation_config=lake_formation_config, - ) - - # Verify error message mentions role_arn requirement - assert "lake_formation_config with enabled=True requires role_arn to be specified" in str(exc_info.value) - - -def test_enable_lake_formation_fails_for_non_created_status(s3_uri, role, region): - """ - Test that enable_lake_formation() fails when called on a FeatureGroup - that is not in 'Created' status. - - Expected behavior: ValueError should be raised indicating the Feature Group - must be in 'Created' status. - - Note: This test creates its own FeatureGroup because it needs to test - behavior during the 'Creating' status, which requires a fresh resource. - """ - fg_name = generate_feature_group_name() - fg = None - - try: - # Create the FeatureGroup - fg = create_test_feature_group(fg_name, s3_uri, role, region) - assert fg is not None - - # Immediately try to enable Lake Formation without waiting for Created status - # The Feature Group will be in 'Creating' status - with pytest.raises(ValueError) as exc_info: - fg.enable_lake_formation(wait_for_active=False) - - # Verify error message mentions status requirement - error_msg = str(exc_info.value) - assert "must be in 'Created' status to enable Lake Formation" in error_msg - - finally: - # Cleanup - if fg: - fg.wait_for_status(target_status="Created", poll=30, timeout=300) - cleanup_feature_group(fg) - - -def test_enable_lake_formation_without_offline_store(role, region): - """ - Test that enable_lake_formation() fails when called on a FeatureGroup - without an offline store configured. - - Expected behavior: ValueError should be raised indicating offline store is required. - - Note: This test creates a FeatureGroup with only online store, which is a valid - configuration, but Lake Formation cannot be enabled for it. 
- """ - fg_name = generate_feature_group_name() - fg = None - - try: - # Create a FeatureGroup with only online store (no offline store) - online_store_config = OnlineStoreConfig(enable_online_store=True) - - fg = FeatureGroup.create( - feature_group_name=fg_name, - record_identifier_feature_name="record_id", - event_time_feature_name="event_time", - feature_definitions=feature_definitions, - online_store_config=online_store_config, - role_arn=role, - ) - - # Wait for Created status - fg.wait_for_status(target_status="Created", poll=30, timeout=300) - - # Attempt to enable Lake Formation - with pytest.raises(ValueError) as exc_info: - fg.enable_lake_formation() - # Verify error message mentions offline store requirement - assert "does not have an offline store configured" in str(exc_info.value) - - finally: - # Cleanup - if fg: - cleanup_feature_group(fg) - - -def test_enable_lake_formation_fails_with_invalid_registration_role( - shared_feature_group_for_negative_tests, -): - """ - Test that enable_lake_formation() fails when use_service_linked_role=False - but no registration_role_arn is provided. - - Expected behavior: ValueError should be raised indicating registration_role_arn - is required when not using service-linked role. - """ - fg = shared_feature_group_for_negative_tests - - # Attempt to enable Lake Formation without service-linked role and without registration_role_arn - with pytest.raises(ValueError) as exc_info: - fg.enable_lake_formation( - use_service_linked_role=False, - registration_role_arn=None, - ) - - # Verify error message mentions role requirement - error_msg = str(exc_info.value) - assert "registration_role_arn" in error_msg - - -def test_enable_lake_formation_fails_with_nonexistent_role( - shared_feature_group_for_negative_tests, role -): - """ - Test that enable_lake_formation() properly bubbles errors when using - a nonexistent role ARN for Lake Formation registration. - - Expected behavior: RuntimeError or ClientError should be raised with details - about the registration failure. - - Note: This test uses a nonexistent role ARN (current role with random suffix) - to trigger an error during S3 registration with Lake Formation. - """ - fg = shared_feature_group_for_negative_tests - - # Generate a nonexistent role ARN by appending a random string to the current role - nonexistent_role = f"{role}-nonexistent-{uuid.uuid4().hex[:8]}" - - with pytest.raises(RuntimeError) as exc_info: - fg.enable_lake_formation( - use_service_linked_role=False, - registration_role_arn=nonexistent_role, - ) - - # Verify we got an appropriate error - error_msg = str(exc_info.value) - print(exc_info) - # Should mention role-related issues (not found, invalid, access denied, etc.) - assert "EntityNotFoundException" in error_msg - - -# ============================================================================ -# Full Flow Integration Tests with Policy Output -# ============================================================================ - - -@pytest.mark.serial -@pytest.mark.slow_test -def test_enable_lake_formation_full_flow_with_policy_output(s3_uri, role, region, capsys): - """ - Test the full Lake Formation flow with S3 deny policy output. - - This test verifies: - 1. Creates a FeatureGroup with offline store - 2. Enables Lake Formation with show_s3_policy=True - 3. Verifies all Lake Formation phases complete successfully - 4. Verifies the S3 deny policy is printed to the console - 5. 
Verifies the policy structure contains expected elements - - This validates Requirements 6.1-6.9 from the design document. - """ - fg_name = generate_feature_group_name() - fg = None - - try: - # Create the FeatureGroup - fg = create_test_feature_group(fg_name, s3_uri, role, region) - assert fg is not None - - # Wait for Created status - fg.wait_for_status(target_status="Created", poll=30, timeout=300) - assert fg.feature_group_status == "Created" - - # Enable Lake Formation governance with policy output - result = fg.enable_lake_formation(show_s3_policy=True) - - # Verify all phases completed successfully - assert result["s3_registration"] is True - assert result["permissions_granted"] is True - assert result["iam_principal_revoked"] is True - - # Capture the printed output - captured = capsys.readouterr() - output = captured.out - - # Re-print the output so it's visible in terminal with -s flag - print(output) - - # Verify the policy header is printed - assert "S3 Bucket Policy Update recommended" in output - assert "=" * 80 in output - - # Verify bucket information is printed - # Extract bucket name from s3_uri (s3://bucket/path -> bucket) - expected_bucket = s3_uri.replace("s3://", "").split("/")[0] - assert f"Bucket: {expected_bucket}" in output - - # Verify policy structure elements are present - assert '"Version": "2012-10-17"' in output - assert '"Statement"' in output - assert '"Effect": "Deny"' in output - assert '"Principal": "*"' in output - - # Verify the deny actions are present - assert "s3:GetObject" in output - assert "s3:PutObject" in output - assert "s3:DeleteObject" in output - assert "s3:ListBucket" in output - - # Verify the condition structure is present - assert "StringNotEquals" in output - assert "aws:PrincipalArn" in output - - # Verify the role ARN is in the allowed principals - assert role in output - - # Verify the service-linked role pattern is present (default use_service_linked_role=True) - assert "aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" in output - - # Verify instructions are printed - assert "Merge this with your existing bucket policy" in output - - finally: - # Cleanup - if fg: - cleanup_feature_group(fg) - - -@pytest.mark.serial -@pytest.mark.slow_test -def test_enable_lake_formation_no_policy_output_by_default(s3_uri, role, region, capsys): - """ - Test that S3 deny policy is NOT printed when show_s3_policy=False (default). - - This test verifies: - 1. Creates a FeatureGroup with offline store - 2. Enables Lake Formation without show_s3_policy (defaults to False) - 3. Verifies all Lake Formation phases complete successfully - 4. Verifies the S3 deny policy is NOT printed to the console - - This validates Requirement 6.2 from the design document. 
- """ - fg_name = generate_feature_group_name() - fg = None - - try: - # Create the FeatureGroup - fg = create_test_feature_group(fg_name, s3_uri, role, region) - assert fg is not None - - # Wait for Created status - fg.wait_for_status(target_status="Created", poll=30, timeout=300) - assert fg.feature_group_status == "Created" - - # Enable Lake Formation governance WITHOUT policy output (default) - result = fg.enable_lake_formation() - - # Verify all phases completed successfully - assert result["s3_registration"] is True - assert result["permissions_granted"] is True - assert result["iam_principal_revoked"] is True - - # Capture the printed output - captured = capsys.readouterr() - output = captured.out - - # Verify the policy is NOT printed - assert "S3 Bucket Policy Update recommended" not in output - assert '"Version": "2012-10-17"' not in output - assert "s3:GetObject" not in output - - finally: - # Cleanup - if fg: - cleanup_feature_group(fg) - - -@pytest.mark.serial -@pytest.mark.slow_test -def test_enable_lake_formation_with_custom_role_policy_output(s3_uri, role, region, capsys): - """ - Test the full Lake Formation flow with custom registration role and policy output. - - This test verifies: - 1. Creates a FeatureGroup with offline store - 2. Enables Lake Formation with use_service_linked_role=False and a custom registration_role_arn - 3. Verifies the S3 deny policy uses the custom role ARN instead of service-linked role - - This validates Requirements 6.4, 6.5 from the design document. - - Note: This test uses the same execution role as the registration role for simplicity. - In production, these would typically be different roles. - """ - fg_name = generate_feature_group_name() - fg = None - - try: - # Create the FeatureGroup - fg = create_test_feature_group(fg_name, s3_uri, role, region) - assert fg is not None - - # Wait for Created status - fg.wait_for_status(target_status="Created", poll=30, timeout=300) - assert fg.feature_group_status == "Created" - - # Enable Lake Formation with custom registration role and policy output - # Using the same role for both execution and registration for test simplicity - result = fg.enable_lake_formation( - use_service_linked_role=False, - registration_role_arn=role, - show_s3_policy=True, - ) - - # Verify all phases completed successfully - assert result["s3_registration"] is True - assert result["permissions_granted"] is True - assert result["iam_principal_revoked"] is True - - # Capture the printed output - captured = capsys.readouterr() - output = captured.out - - # Verify the policy header is printed - assert "S3 Bucket Policy Update recommended" in output - - # Verify the custom role ARN is used in the policy (appears twice - once for each principal) - # The role should appear as both the Lake Formation role and the Feature Store role - assert output.count(role) >= 2 - - # Verify the service-linked role is NOT used - assert "aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" not in output - - finally: - # Cleanup - if fg: - cleanup_feature_group(fg) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_lakeformation.py b/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_lakeformation.py deleted file mode 100644 index e4d44df37a..0000000000 --- a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_lakeformation.py +++ /dev/null @@ -1,2141 +0,0 @@ -"""Unit tests for Lake Formation integration with FeatureGroup.""" -from unittest.mock import MagicMock, patch - 
-import botocore.exceptions -import pytest - -from sagemaker.mlops.feature_store import FeatureGroup, LakeFormationConfig - - -class TestS3UriToArn: - """Tests for _s3_uri_to_arn static method.""" - - def test_converts_s3_uri_to_arn(self): - """Test S3 URI is converted to ARN format.""" - uri = "s3://my-bucket/my-prefix/data" - result = FeatureGroup._s3_uri_to_arn(uri) - assert result == "arn:aws:s3:::my-bucket/my-prefix/data" - - def test_handles_bucket_only_uri(self): - """Test S3 URI with bucket only.""" - uri = "s3://my-bucket" - result = FeatureGroup._s3_uri_to_arn(uri) - assert result == "arn:aws:s3:::my-bucket" - - def test_returns_arn_unchanged(self): - """Test ARN input is returned unchanged (idempotent).""" - arn = "arn:aws:s3:::my-bucket/path" - result = FeatureGroup._s3_uri_to_arn(arn) - assert result == arn - - def test_uses_region_for_partition(self): - """Test that region is used to determine partition.""" - uri = "s3://my-bucket/path" - result = FeatureGroup._s3_uri_to_arn(uri, region="cn-north-1") - assert result.startswith("arn:aws-cn:s3:::") - - - -class TestGetLakeFormationClient: - """Tests for _get_lake_formation_client method.""" - - @patch("sagemaker.mlops.feature_store.feature_group.Session") - def test_creates_client_with_default_session(self, mock_session_class): - """Test client creation with default session.""" - mock_session = MagicMock() - mock_client = MagicMock() - mock_session.client.return_value = mock_client - mock_session_class.return_value = mock_session - - fg = MagicMock(spec=FeatureGroup) - fg._get_lake_formation_client = FeatureGroup._get_lake_formation_client.__get__(fg) - - client = fg._get_lake_formation_client(region="us-west-2") - - mock_session.client.assert_called_with("lakeformation", region_name="us-west-2") - assert client == mock_client - - def test_creates_client_with_provided_session(self): - """Test client creation with provided session.""" - mock_session = MagicMock() - mock_client = MagicMock() - mock_session.client.return_value = mock_client - - fg = MagicMock(spec=FeatureGroup) - fg._get_lake_formation_client = FeatureGroup._get_lake_formation_client.__get__(fg) - - client = fg._get_lake_formation_client(session=mock_session, region="us-west-2") - - mock_session.client.assert_called_with("lakeformation", region_name="us-west-2") - assert client == mock_client - - -class TestRegisterS3WithLakeFormation: - """Tests for _register_s3_with_lake_formation method.""" - - def setup_method(self): - """Set up test fixtures.""" - self.fg = MagicMock(spec=FeatureGroup) - self.fg._s3_uri_to_arn = FeatureGroup._s3_uri_to_arn - self.fg._register_s3_with_lake_formation = ( - FeatureGroup._register_s3_with_lake_formation.__get__(self.fg) - ) - self.mock_client = MagicMock() - self.fg._get_lake_formation_client = MagicMock(return_value=self.mock_client) - - def test_successful_registration_returns_true(self): - """Test successful S3 registration returns True.""" - self.mock_client.register_resource.return_value = {} - - result = self.fg._register_s3_with_lake_formation("s3://test-bucket/prefix") - - assert result is True - self.mock_client.register_resource.assert_called_with( - ResourceArn="arn:aws:s3:::test-bucket/prefix", - UseServiceLinkedRole=True, - ) - - def test_already_exists_exception_returns_true(self): - """Test AlreadyExistsException is handled gracefully.""" - self.mock_client.register_resource.side_effect = botocore.exceptions.ClientError( - {"Error": {"Code": "AlreadyExistsException", "Message": "Already exists"}}, - 
"RegisterResource", - ) - - result = self.fg._register_s3_with_lake_formation("s3://test-bucket/prefix") - - assert result is True - - def test_other_exceptions_are_propagated(self): - """Test non-AlreadyExistsException errors are propagated.""" - self.mock_client.register_resource.side_effect = botocore.exceptions.ClientError( - {"Error": {"Code": "AccessDeniedException", "Message": "Access denied"}}, - "RegisterResource", - ) - - with pytest.raises(botocore.exceptions.ClientError) as exc_info: - self.fg._register_s3_with_lake_formation("s3://test-bucket/prefix") - - assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" - - def test_uses_service_linked_role(self): - """Test UseServiceLinkedRole is set to True.""" - self.mock_client.register_resource.return_value = {} - - self.fg._register_s3_with_lake_formation("s3://bucket/path") - - call_args = self.mock_client.register_resource.call_args - assert call_args[1]["UseServiceLinkedRole"] is True - - def test_uses_custom_role_arn_when_service_linked_role_disabled(self): - """Test custom role ARN is used when use_service_linked_role is False.""" - self.mock_client.register_resource.return_value = {} - custom_role = "arn:aws:iam::123456789012:role/CustomLakeFormationRole" - - self.fg._register_s3_with_lake_formation( - "s3://bucket/path", - use_service_linked_role=False, - role_arn=custom_role, - ) - - call_args = self.mock_client.register_resource.call_args - assert call_args[1]["RoleArn"] == custom_role - assert "UseServiceLinkedRole" not in call_args[1] - - def test_raises_error_when_role_arn_missing_and_service_linked_role_disabled(self): - """Test ValueError when use_service_linked_role is False but role_arn not provided.""" - with pytest.raises(ValueError) as exc_info: - self.fg._register_s3_with_lake_formation( - "s3://bucket/path", use_service_linked_role=False - ) - - assert "role_arn must be provided when use_service_linked_role is False" in str( - exc_info.value - ) - - - -class TestRevokeIamAllowedPrincipal: - """Tests for _revoke_iam_allowed_principal method.""" - - def setup_method(self): - """Set up test fixtures.""" - self.fg = MagicMock(spec=FeatureGroup) - self.fg._revoke_iam_allowed_principal = FeatureGroup._revoke_iam_allowed_principal.__get__( - self.fg - ) - self.mock_client = MagicMock() - self.fg._get_lake_formation_client = MagicMock(return_value=self.mock_client) - - def test_successful_revocation_returns_true(self): - """Test successful revocation returns True.""" - self.mock_client.revoke_permissions.return_value = {} - - result = self.fg._revoke_iam_allowed_principal("test_database", "test_table") - - assert result is True - self.mock_client.revoke_permissions.assert_called_once() - - def test_revoke_permissions_call_structure(self): - """Test that revoke_permissions is called with correct parameters.""" - self.mock_client.revoke_permissions.return_value = {} - database_name = "my_database" - table_name = "my_table" - - self.fg._revoke_iam_allowed_principal(database_name, table_name) - - call_args = self.mock_client.revoke_permissions.call_args - assert call_args[1]["Principal"] == { - "DataLakePrincipalIdentifier": "IAM_ALLOWED_PRINCIPALS" - } - assert call_args[1]["Permissions"] == ["ALL"] - assert call_args[1]["Resource"] == { - "Table": { - "DatabaseName": database_name, - "Name": table_name, - } - } - - def test_invalid_input_exception_returns_true(self): - """Test InvalidInputException is handled gracefully (permissions may not exist).""" - self.mock_client.revoke_permissions.side_effect 
= botocore.exceptions.ClientError( - {"Error": {"Code": "InvalidInputException", "Message": "Permissions not found"}}, - "RevokePermissions", - ) - - result = self.fg._revoke_iam_allowed_principal("test_database", "test_table") - - assert result is True - - def test_other_exceptions_are_propagated(self): - """Test non-InvalidInputException errors are propagated.""" - self.mock_client.revoke_permissions.side_effect = botocore.exceptions.ClientError( - {"Error": {"Code": "AccessDeniedException", "Message": "Access denied"}}, - "RevokePermissions", - ) - - with pytest.raises(botocore.exceptions.ClientError) as exc_info: - self.fg._revoke_iam_allowed_principal("test_database", "test_table") - - assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" - - def test_passes_session_and_region_to_client(self): - """Test session and region are passed to get_lake_formation_client.""" - self.mock_client.revoke_permissions.return_value = {} - mock_session = MagicMock() - - self.fg._revoke_iam_allowed_principal( - "test_database", "test_table", session=mock_session, region="us-west-2" - ) - - self.fg._get_lake_formation_client.assert_called_with(mock_session, "us-west-2") - - - -class TestGrantLakeFormationPermissions: - """Tests for _grant_lake_formation_permissions method.""" - - def setup_method(self): - """Set up test fixtures.""" - self.fg = MagicMock(spec=FeatureGroup) - self.fg._grant_lake_formation_permissions = ( - FeatureGroup._grant_lake_formation_permissions.__get__(self.fg) - ) - self.mock_client = MagicMock() - self.fg._get_lake_formation_client = MagicMock(return_value=self.mock_client) - - def test_successful_grant_returns_true(self): - """Test successful permission grant returns True.""" - self.mock_client.grant_permissions.return_value = {} - - result = self.fg._grant_lake_formation_permissions( - "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" - ) - - assert result is True - self.mock_client.grant_permissions.assert_called_once() - - def test_grant_permissions_call_structure(self): - """Test that grant_permissions is called with correct parameters.""" - self.mock_client.grant_permissions.return_value = {} - role_arn = "arn:aws:iam::123456789012:role/MyExecutionRole" - - self.fg._grant_lake_formation_permissions(role_arn, "my_database", "my_table") - - call_args = self.mock_client.grant_permissions.call_args - assert call_args[1]["Principal"] == {"DataLakePrincipalIdentifier": role_arn} - assert call_args[1]["Resource"] == { - "Table": { - "DatabaseName": "my_database", - "Name": "my_table", - } - } - assert call_args[1]["Permissions"] == ["SELECT", "INSERT", "DELETE", "DESCRIBE", "ALTER"] - assert call_args[1]["PermissionsWithGrantOption"] == [] - - def test_invalid_input_exception_returns_true(self): - """Test InvalidInputException is handled gracefully (permissions may exist).""" - self.mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( - {"Error": {"Code": "InvalidInputException", "Message": "Permissions already exist"}}, - "GrantPermissions", - ) - - result = self.fg._grant_lake_formation_permissions( - "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" - ) - - assert result is True - - def test_other_exceptions_are_propagated(self): - """Test non-InvalidInputException errors are propagated.""" - self.mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( - {"Error": {"Code": "AccessDeniedException", "Message": "Access denied"}}, - "GrantPermissions", - ) - - with 
pytest.raises(botocore.exceptions.ClientError) as exc_info: - self.fg._grant_lake_formation_permissions( - "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" - ) - - assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" - - def test_passes_session_and_region_to_client(self): - """Test session and region are passed to get_lake_formation_client.""" - self.mock_client.grant_permissions.return_value = {} - mock_session = MagicMock() - - self.fg._grant_lake_formation_permissions( - "arn:aws:iam::123456789012:role/TestRole", - "test_database", - "test_table", - session=mock_session, - region="us-west-2", - ) - - self.fg._get_lake_formation_client.assert_called_with(mock_session, "us-west-2") - - - -class TestEnableLakeFormationValidation: - """Tests for enable_lake_formation validation logic.""" - - @patch.object(FeatureGroup, "refresh") - def test_raises_error_when_no_offline_store(self, mock_refresh): - """Test that enable_lake_formation raises ValueError when no offline store is configured.""" - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = None - fg.feature_group_status = "Created" - - with pytest.raises(ValueError, match="does not have an offline store configured"): - fg.enable_lake_formation() - - # Verify refresh was called - mock_refresh.assert_called_once() - - @patch.object(FeatureGroup, "refresh") - def test_raises_error_when_no_role_arn(self, mock_refresh): - """Test that enable_lake_formation raises ValueError when no role_arn is configured.""" - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = None - fg.feature_group_status = "Created" - - with pytest.raises(ValueError, match="does not have a role_arn configured"): - fg.enable_lake_formation() - - # Verify refresh was called - mock_refresh.assert_called_once() - - @patch.object(FeatureGroup, "refresh") - def test_raises_error_when_invalid_status(self, mock_refresh): - """Test enable_lake_formation raises ValueError when Feature Group not in Created status.""" - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/TestRole" - fg.feature_group_status = "Creating" - - with pytest.raises(ValueError, match="must be in 'Created' status"): - fg.enable_lake_formation() - - # Verify refresh was called - mock_refresh.assert_called_once() - - @patch.object(FeatureGroup, "wait_for_status") - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - def test_wait_for_active_calls_wait_for_status( - self, mock_revoke, mock_grant, mock_register, mock_refresh, 
mock_wait - ): - """Test that wait_for_active=True calls wait_for_status with 'Created' target.""" - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/TestRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - - # Call with wait_for_active=True - fg.enable_lake_formation(wait_for_active=True) - - # Verify wait_for_status was called with "Created" - mock_wait.assert_called_once_with(target_status="Created") - # Verify refresh was called after wait - mock_refresh.assert_called_once() - - @patch.object(FeatureGroup, "wait_for_status") - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - def test_wait_for_active_false_does_not_call_wait( - self, mock_revoke, mock_grant, mock_register, mock_refresh, mock_wait - ): - """Test that wait_for_active=False does not call wait_for_status.""" - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/TestRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - - # Call with wait_for_active=False (default) - fg.enable_lake_formation(wait_for_active=False) - - # Verify wait_for_status was NOT called - mock_wait.assert_not_called() - # Verify refresh was still called - mock_refresh.assert_called_once() - - - @pytest.mark.parametrize( - "feature_group_name,role_arn,s3_uri,database_name,table_name", - [ - ("test-fg", "TestRole", "path1", "db1", "table1"), - ("my_feature_group", "ExecutionRole", "data/features", "feature_db", "feature_table"), - ("fg123", "MyRole123", "ml/features/v1", "analytics", "features_v1"), - ("simple", "SimpleRole", "simple-path", "simple_db", "simple_table"), - ( - "complex-name", - "ComplexExecutionRole", - "complex/path/structure", - "complex_database", - "complex_table_name", - ), - ( - "underscore_name", - "Underscore_Role", - "underscore_path", - "underscore_db", - "underscore_table", - ), - ("mixed-123", "Mixed123Role", "mixed/path/123", "mixed_db_123", "mixed_table_123"), - ("x", "XRole", "x", "x", "x"), - ( - "very-long-name", - "VeryLongRoleName", - "very/long/path/structure", - "very_long_database_name", - 
"very_long_table_name", - ), - ], - ) - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - def test_fail_fast_phase_execution( - self, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - feature_group_name, - role_arn, - s3_uri, - database_name, - table_name, - ): - """ - Test fail-fast behavior for Lake Formation phases. - - If Phase 1 (S3 registration) fails, Phase 2 and 3 should not execute. - If Phase 2 fails, Phase 3 should not execute. - RuntimeError should indicate which phase failed. - """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - fg = FeatureGroup(feature_group_name=feature_group_name) - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri=f"s3://test-bucket/{s3_uri}", - resolved_output_s3_uri=f"s3://test-bucket/resolved-{s3_uri}", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database=database_name, table_name=table_name - ), - ) - fg.role_arn = f"arn:aws:iam::123456789012:role/{role_arn}" - fg.feature_group_status = "Created" - - # Test Phase 1 failure - subsequent phases should not be called - mock_register.side_effect = Exception("Phase 1 failed") - mock_grant.return_value = True - mock_revoke.return_value = True - - with pytest.raises( - RuntimeError, match="Failed to register S3 location with Lake Formation" - ): - fg.enable_lake_formation() - - # Verify Phase 1 was called but Phase 2 and 3 were not - mock_register.assert_called_once() - mock_grant.assert_not_called() - mock_revoke.assert_not_called() - - # Reset mocks for Phase 2 failure test - mock_register.reset_mock() - mock_grant.reset_mock() - mock_revoke.reset_mock() - - # Test Phase 2 failure - Phase 3 should not be called - mock_register.side_effect = None - mock_register.return_value = True - mock_grant.side_effect = Exception("Phase 2 failed") - mock_revoke.return_value = True - - with pytest.raises(RuntimeError, match="Failed to grant Lake Formation permissions"): - fg.enable_lake_formation() - - # Verify Phase 1 and 2 were called but Phase 3 was not - mock_register.assert_called_once() - mock_grant.assert_called_once() - mock_revoke.assert_not_called() - - # Reset mocks for Phase 3 failure test - mock_register.reset_mock() - mock_grant.reset_mock() - mock_revoke.reset_mock() - - # Test Phase 3 failure - all phases should be called - mock_register.side_effect = None - mock_register.return_value = True - mock_grant.side_effect = None - mock_grant.return_value = True - mock_revoke.side_effect = Exception("Phase 3 failed") - - with pytest.raises(RuntimeError, match="Failed to revoke IAMAllowedPrincipal permissions"): - fg.enable_lake_formation() - - # Verify all phases were called - mock_register.assert_called_once() - mock_grant.assert_called_once() - mock_revoke.assert_called_once() - - - -class TestUnhandledExceptionPropagation: - """Tests for proper propagation of unhandled boto3 exceptions.""" - - def test_register_s3_propagates_unhandled_exceptions(self): - """ - Non-AlreadyExists Errors Propagate from S3 Registration - - For any error from Lake Formation's register_resource API that is not - AlreadyExistsException, the error should be propagated to the caller unchanged. 
- - """ - fg = MagicMock(spec=FeatureGroup) - fg._s3_uri_to_arn = FeatureGroup._s3_uri_to_arn - fg._register_s3_with_lake_formation = FeatureGroup._register_s3_with_lake_formation.__get__( - fg - ) - mock_client = MagicMock() - fg._get_lake_formation_client = MagicMock(return_value=mock_client) - - # Configure mock to raise an unhandled error - mock_client.register_resource.side_effect = botocore.exceptions.ClientError( - { - "Error": { - "Code": "AccessDeniedException", - "Message": "User does not have permission", - } - }, - "RegisterResource", - ) - - # Verify the exception is propagated unchanged - with pytest.raises(botocore.exceptions.ClientError) as exc_info: - fg._register_s3_with_lake_formation("s3://test-bucket/path") - - # Verify error details are preserved - assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" - assert exc_info.value.response["Error"]["Message"] == "User does not have permission" - assert exc_info.value.operation_name == "RegisterResource" - - def test_revoke_iam_principal_propagates_unhandled_exceptions(self): - """ - Non-InvalidInput Errors Propagate from IAM Principal Revocation - - For any error from Lake Formation's revoke_permissions API that is not - InvalidInputException, the error should be propagated to the caller unchanged. - - """ - fg = MagicMock(spec=FeatureGroup) - fg._revoke_iam_allowed_principal = FeatureGroup._revoke_iam_allowed_principal.__get__(fg) - mock_client = MagicMock() - fg._get_lake_formation_client = MagicMock(return_value=mock_client) - - # Configure mock to raise an unhandled error - mock_client.revoke_permissions.side_effect = botocore.exceptions.ClientError( - { - "Error": { - "Code": "AccessDeniedException", - "Message": "User does not have permission", - } - }, - "RevokePermissions", - ) - - # Verify the exception is propagated unchanged - with pytest.raises(botocore.exceptions.ClientError) as exc_info: - fg._revoke_iam_allowed_principal("test_database", "test_table") - - # Verify error details are preserved - assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" - assert exc_info.value.response["Error"]["Message"] == "User does not have permission" - assert exc_info.value.operation_name == "RevokePermissions" - - def test_grant_permissions_propagates_unhandled_exceptions(self): - """ - Non-InvalidInput Errors Propagate from Permission Grant - - For any error from Lake Formation's grant_permissions API that is not - InvalidInputException, the error should be propagated to the caller unchanged. 
- - """ - fg = MagicMock(spec=FeatureGroup) - fg._grant_lake_formation_permissions = ( - FeatureGroup._grant_lake_formation_permissions.__get__(fg) - ) - mock_client = MagicMock() - fg._get_lake_formation_client = MagicMock(return_value=mock_client) - - # Configure mock to raise an unhandled error - mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( - { - "Error": { - "Code": "AccessDeniedException", - "Message": "User does not have permission", - } - }, - "GrantPermissions", - ) - - # Verify the exception is propagated unchanged - with pytest.raises(botocore.exceptions.ClientError) as exc_info: - fg._grant_lake_formation_permissions( - "arn:aws:iam::123456789012:role/TestRole", "test_database", "test_table" - ) - - # Verify error details are preserved - assert exc_info.value.response["Error"]["Code"] == "AccessDeniedException" - assert exc_info.value.response["Error"]["Message"] == "User does not have permission" - assert exc_info.value.operation_name == "GrantPermissions" - - def test_handled_exceptions_do_not_propagate(self): - """ - Verify that specifically handled exceptions (AlreadyExistsException, InvalidInputException) - do NOT propagate but return True instead, while all other exceptions are propagated. - """ - fg = MagicMock(spec=FeatureGroup) - fg._s3_uri_to_arn = FeatureGroup._s3_uri_to_arn - fg._register_s3_with_lake_formation = FeatureGroup._register_s3_with_lake_formation.__get__( - fg - ) - fg._revoke_iam_allowed_principal = FeatureGroup._revoke_iam_allowed_principal.__get__(fg) - fg._grant_lake_formation_permissions = ( - FeatureGroup._grant_lake_formation_permissions.__get__(fg) - ) - mock_client = MagicMock() - fg._get_lake_formation_client = MagicMock(return_value=mock_client) - - # Test AlreadyExistsException is handled (not propagated) - mock_client.register_resource.side_effect = botocore.exceptions.ClientError( - {"Error": {"Code": "AlreadyExistsException", "Message": "Already exists"}}, - "RegisterResource", - ) - result = fg._register_s3_with_lake_formation("s3://test-bucket/path") - assert result is True # Should return True, not raise - - # Test InvalidInputException is handled for revoke (not propagated) - mock_client.revoke_permissions.side_effect = botocore.exceptions.ClientError( - {"Error": {"Code": "InvalidInputException", "Message": "Invalid input"}}, - "RevokePermissions", - ) - result = fg._revoke_iam_allowed_principal("db", "table") - assert result is True # Should return True, not raise - - # Test InvalidInputException is handled for grant (not propagated) - mock_client.grant_permissions.side_effect = botocore.exceptions.ClientError( - {"Error": {"Code": "InvalidInputException", "Message": "Invalid input"}}, - "GrantPermissions", - ) - result = fg._grant_lake_formation_permissions( - "arn:aws:iam::123456789012:role/TestRole", "db", "table" - ) - assert result is True # Should return True, not raise - - - -class TestCreateWithLakeFormation: - """Tests for create() method with Lake Formation integration.""" - - @pytest.mark.parametrize( - "feature_group_name,record_id_feature,event_time_feature", - [ - ("test-fg", "record_id", "event_time"), - ("my_feature_group", "id", "timestamp"), - ("fg123", "identifier", "time"), - ("simple", "rec_id", "evt_time"), - ("complex-name", "record_identifier", "event_timestamp"), - ("underscore_name", "record_id_field", "event_time_field"), - ("mixed-123", "id_123", "time_123"), - ("x", "x_id", "x_time"), - ("very-long-name", "very_long_record_id", "very_long_event_time"), - ], - ) - 
@patch("sagemaker.core.resources.Base.get_sagemaker_client") - @patch.object(FeatureGroup, "get") - @patch.object(FeatureGroup, "wait_for_status") - @patch.object(FeatureGroup, "enable_lake_formation") - def test_no_lake_formation_operations_when_disabled( - self, - mock_enable_lf, - mock_wait, - mock_get, - mock_get_client, - feature_group_name, - record_id_feature, - event_time_feature, - ): - """ - No Lake Formation Operations When Disabled - - For any call to FeatureGroup.create() where lake_formation_config is None or has enabled=False, - no Lake Formation client methods should be invoked. - - """ - from sagemaker.core.shapes import FeatureDefinition - - # Mock the SageMaker client - mock_client = MagicMock() - mock_client.create_feature_group.return_value = { - "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" - } - mock_get_client.return_value = mock_client - - # Mock the get method to return a feature group - mock_fg = MagicMock(spec=FeatureGroup) - mock_fg.feature_group_name = feature_group_name - mock_get.return_value = mock_fg - - # Create feature definitions - feature_definitions = [ - FeatureDefinition(feature_name=record_id_feature, feature_type="String"), - FeatureDefinition(feature_name=event_time_feature, feature_type="String"), - ] - - # Test 1: lake_formation_config with enabled=False (explicit) - lf_config = LakeFormationConfig() - lf_config.enabled = False - result = FeatureGroup.create( - feature_group_name=feature_group_name, - record_identifier_feature_name=record_id_feature, - event_time_feature_name=event_time_feature, - feature_definitions=feature_definitions, - lake_formation_config=lf_config, - ) - - # Verify enable_lake_formation was NOT called - mock_enable_lf.assert_not_called() - # Verify wait_for_status was NOT called - mock_wait.assert_not_called() - # Verify the feature group was returned - assert result == mock_fg - - # Reset mocks for next test - mock_enable_lf.reset_mock() - mock_wait.reset_mock() - mock_get.reset_mock() - mock_get.return_value = mock_fg - - # Test 2: lake_formation_config not specified (defaults to None) - result = FeatureGroup.create( - feature_group_name=feature_group_name, - record_identifier_feature_name=record_id_feature, - event_time_feature_name=event_time_feature, - feature_definitions=feature_definitions, - # lake_formation_config not specified, should default to None - ) - - # Verify enable_lake_formation was NOT called - mock_enable_lf.assert_not_called() - # Verify wait_for_status was NOT called - mock_wait.assert_not_called() - # Verify the feature group was returned - assert result == mock_fg - - @pytest.mark.parametrize( - "feature_group_name,record_id_feature,event_time_feature,role_arn,s3_uri,database,table", - [ - ("test-fg", "record_id", "event_time", "TestRole", "path1", "db1", "table1"), - ( - "my_feature_group", - "id", - "timestamp", - "ExecutionRole", - "data/features", - "feature_db", - "feature_table", - ), - ( - "fg123", - "identifier", - "time", - "MyRole123", - "ml/features/v1", - "analytics", - "features_v1", - ), - ], - ) - @patch("sagemaker.core.resources.Base.get_sagemaker_client") - @patch.object(FeatureGroup, "get") - @patch.object(FeatureGroup, "wait_for_status") - @patch.object(FeatureGroup, "enable_lake_formation") - def test_enable_lake_formation_called_when_enabled( - self, - mock_enable_lf, - mock_wait, - mock_get, - mock_get_client, - feature_group_name, - record_id_feature, - event_time_feature, - role_arn, - s3_uri, - database, - table, - ): - """ - Test that 
enable_lake_formation is called when lake_formation_config has enabled=True. - - This verifies the integration between create() and enable_lake_formation(). - """ - from sagemaker.core.shapes import ( - FeatureDefinition, - OfflineStoreConfig, - S3StorageConfig, - DataCatalogConfig, - ) - - # Mock the SageMaker client - mock_client = MagicMock() - mock_client.create_feature_group.return_value = { - "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" - } - mock_get_client.return_value = mock_client - - # Mock the get method to return a feature group - mock_fg = MagicMock(spec=FeatureGroup) - mock_fg.feature_group_name = feature_group_name - mock_fg.wait_for_status = mock_wait - mock_fg.enable_lake_formation = mock_enable_lf - mock_get.return_value = mock_fg - - # Create feature definitions - feature_definitions = [ - FeatureDefinition(feature_name=record_id_feature, feature_type="String"), - FeatureDefinition(feature_name=event_time_feature, feature_type="String"), - ] - - # Create offline store config - offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig(s3_uri=f"s3://test-bucket/{s3_uri}"), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database=database, table_name=table - ), - ) - - # Create LakeFormationConfig with enabled=True - lf_config = LakeFormationConfig() - lf_config.enabled = True - - # Create with lake_formation_config enabled=True - result = FeatureGroup.create( - feature_group_name=feature_group_name, - record_identifier_feature_name=record_id_feature, - event_time_feature_name=event_time_feature, - feature_definitions=feature_definitions, - offline_store_config=offline_store_config, - role_arn=f"arn:aws:iam::123456789012:role/{role_arn}", - lake_formation_config=lf_config, - ) - - # Verify wait_for_status was called with "Created" - mock_wait.assert_called_once_with(target_status="Created") - # Verify enable_lake_formation was called with default use_service_linked_role=True - mock_enable_lf.assert_called_once_with( - session=None, - region=None, - use_service_linked_role=True, - registration_role_arn=None, - show_s3_policy=False, - ) - # Verify the feature group was returned - assert result == mock_fg - - @pytest.mark.parametrize( - "feature_group_name,record_id_feature,event_time_feature", - [ - ("test-fg", "record_id", "event_time"), - ("my_feature_group", "id", "timestamp"), - ("fg123", "identifier", "time"), - ], - ) - @patch("sagemaker.core.resources.Base.get_sagemaker_client") - def test_validation_error_when_lake_formation_enabled_without_offline_store( - self, mock_get_client, feature_group_name, record_id_feature, event_time_feature - ): - """Test create() raises ValueError when lake_formation_config enabled=True without offline_store.""" - from sagemaker.core.shapes import FeatureDefinition - - # Mock the SageMaker client - mock_client = MagicMock() - mock_get_client.return_value = mock_client - - # Create feature definitions - feature_definitions = [ - FeatureDefinition(feature_name=record_id_feature, feature_type="String"), - FeatureDefinition(feature_name=event_time_feature, feature_type="String"), - ] - - # Create LakeFormationConfig with enabled=True - lf_config = LakeFormationConfig() - lf_config.enabled = True - - # Test with lake_formation_config enabled=True but no offline_store_config - with pytest.raises( - ValueError, - match="lake_formation_config with enabled=True requires offline_store_config to be configured", - ): - FeatureGroup.create( - 
feature_group_name=feature_group_name, - record_identifier_feature_name=record_id_feature, - event_time_feature_name=event_time_feature, - feature_definitions=feature_definitions, - lake_formation_config=lf_config, - # offline_store_config not provided - ) - - @pytest.mark.parametrize( - "feature_group_name,record_id_feature,event_time_feature,s3_uri,database,table", - [ - ("test-fg", "record_id", "event_time", "path1", "db1", "table1"), - ("my_feature_group", "id", "timestamp", "data/features", "feature_db", "feature_table"), - ("fg123", "identifier", "time", "ml/features/v1", "analytics", "features_v1"), - ], - ) - @patch("sagemaker.core.resources.Base.get_sagemaker_client") - def test_validation_error_when_lake_formation_enabled_without_role_arn( - self, - mock_get_client, - feature_group_name, - record_id_feature, - event_time_feature, - s3_uri, - database, - table, - ): - """Test create() raises ValueError when lake_formation_config enabled=True without role_arn.""" - from sagemaker.core.shapes import ( - FeatureDefinition, - OfflineStoreConfig, - S3StorageConfig, - DataCatalogConfig, - ) - - # Mock the SageMaker client - mock_client = MagicMock() - mock_get_client.return_value = mock_client - - # Create feature definitions - feature_definitions = [ - FeatureDefinition(feature_name=record_id_feature, feature_type="String"), - FeatureDefinition(feature_name=event_time_feature, feature_type="String"), - ] - - # Create offline store config - offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig(s3_uri=f"s3://test-bucket/{s3_uri}"), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database=database, table_name=table - ), - ) - - # Create LakeFormationConfig with enabled=True - lf_config = LakeFormationConfig() - lf_config.enabled = True - - # Test with lake_formation_config enabled=True but no role_arn - with pytest.raises( - ValueError, match="lake_formation_config with enabled=True requires role_arn to be specified" - ): - FeatureGroup.create( - feature_group_name=feature_group_name, - record_identifier_feature_name=record_id_feature, - event_time_feature_name=event_time_feature, - feature_definitions=feature_definitions, - offline_store_config=offline_store_config, - lake_formation_config=lf_config, - # role_arn not provided - ) - - - @pytest.mark.parametrize( - "feature_group_name,record_id_feature,event_time_feature,role_arn,s3_uri,database,table,use_slr", - [ - ("test-fg", "record_id", "event_time", "TestRole", "path1", "db1", "table1", True), - ("my_feature_group", "id", "timestamp", "ExecutionRole", "data/features", "feature_db", "feature_table", False), - ("fg123", "identifier", "time", "MyRole123", "ml/features/v1", "analytics", "features_v1", True), - ], - ) - @patch("sagemaker.core.resources.Base.get_sagemaker_client") - @patch.object(FeatureGroup, "get") - @patch.object(FeatureGroup, "wait_for_status") - @patch.object(FeatureGroup, "enable_lake_formation") - def test_use_service_linked_role_extraction_from_config( - self, - mock_enable_lf, - mock_wait, - mock_get, - mock_get_client, - feature_group_name, - record_id_feature, - event_time_feature, - role_arn, - s3_uri, - database, - table, - use_slr, - ): - """ - Test that use_service_linked_role is correctly extracted from lake_formation_config. 
- - Verifies: - - use_service_linked_role defaults to True when not specified - - use_service_linked_role is passed correctly to enable_lake_formation() - """ - from sagemaker.core.shapes import ( - FeatureDefinition, - OfflineStoreConfig, - S3StorageConfig, - DataCatalogConfig, - ) - - # Mock the SageMaker client - mock_client = MagicMock() - mock_client.create_feature_group.return_value = { - "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" - } - mock_get_client.return_value = mock_client - - # Mock the get method to return a feature group - mock_fg = MagicMock(spec=FeatureGroup) - mock_fg.feature_group_name = feature_group_name - mock_fg.wait_for_status = mock_wait - mock_fg.enable_lake_formation = mock_enable_lf - mock_get.return_value = mock_fg - - # Create feature definitions - feature_definitions = [ - FeatureDefinition(feature_name=record_id_feature, feature_type="String"), - FeatureDefinition(feature_name=event_time_feature, feature_type="String"), - ] - - # Create offline store config - offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig(s3_uri=f"s3://test-bucket/{s3_uri}"), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database=database, table_name=table - ), - ) - - # Build LakeFormationConfig with use_service_linked_role - lf_config = LakeFormationConfig() - lf_config.enabled = True - lf_config.use_service_linked_role = use_slr - # When use_service_linked_role is False, registration_role_arn is required - expected_registration_role = None - if not use_slr: - lf_config.registration_role_arn = "arn:aws:iam::123456789012:role/LFRegistrationRole" - expected_registration_role = "arn:aws:iam::123456789012:role/LFRegistrationRole" - - # Create with lake_formation_config - result = FeatureGroup.create( - feature_group_name=feature_group_name, - record_identifier_feature_name=record_id_feature, - event_time_feature_name=event_time_feature, - feature_definitions=feature_definitions, - offline_store_config=offline_store_config, - role_arn=f"arn:aws:iam::123456789012:role/{role_arn}", - lake_formation_config=lf_config, - ) - - # Verify enable_lake_formation was called with correct use_service_linked_role value - mock_enable_lf.assert_called_once_with( - session=None, - region=None, - use_service_linked_role=use_slr, - registration_role_arn=expected_registration_role, - show_s3_policy=False, - ) - # Verify the feature group was returned - assert result == mock_fg - - -class TestExtractAccountIdFromArn: - """Tests for _extract_account_id_from_arn static method.""" - - def test_extracts_account_id_from_sagemaker_arn(self): - """Test extracting account ID from a SageMaker Feature Group ARN.""" - arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/my-feature-group" - result = FeatureGroup._extract_account_id_from_arn(arn) - assert result == "123456789012" - - def test_raises_value_error_for_invalid_arn_too_few_parts(self): - """Test that ValueError is raised for ARN with fewer than 5 colon-separated parts.""" - invalid_arn = "arn:aws:sagemaker:us-west-2" # Only 4 parts - with pytest.raises(ValueError, match="Invalid ARN format"): - FeatureGroup._extract_account_id_from_arn(invalid_arn) - - def test_raises_value_error_for_empty_string(self): - """Test that ValueError is raised for empty string.""" - with pytest.raises(ValueError, match="Invalid ARN format"): - FeatureGroup._extract_account_id_from_arn("") - - def test_raises_value_error_for_non_arn_string(self): - """Test that ValueError is raised for 
non-ARN string.""" - with pytest.raises(ValueError, match="Invalid ARN format"): - FeatureGroup._extract_account_id_from_arn("not-an-arn") - - def test_raises_value_error_for_s3_uri(self): - """Test that ValueError is raised for S3 URI (not ARN).""" - with pytest.raises(ValueError, match="Invalid ARN format"): - FeatureGroup._extract_account_id_from_arn("s3://my-bucket/my-prefix") - - def test_handles_arn_with_resource_path(self): - """Test extracting account ID from ARN with complex resource path.""" - arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/my-fg/version/1" - result = FeatureGroup._extract_account_id_from_arn(arn) - assert result == "123456789012" - - -class TestGetLakeFormationServiceLinkedRoleArn: - """Tests for _get_lake_formation_service_linked_role_arn static method.""" - - def test_generates_correct_service_linked_role_arn(self): - """Test that the method generates the correct service-linked role ARN format.""" - account_id = "123456789012" - result = FeatureGroup._get_lake_formation_service_linked_role_arn(account_id) - expected = "arn:aws:iam::123456789012:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" - assert result == expected - - def test_uses_region_for_partition(self): - """Test that region is used to determine partition.""" - account_id = "123456789012" - result = FeatureGroup._get_lake_formation_service_linked_role_arn(account_id, region="cn-north-1") - assert result.startswith("arn:aws-cn:iam::") - - - -class TestGenerateS3DenyPolicy: - """Tests for _generate_s3_deny_policy method.""" - - def setup_method(self): - """Set up test fixtures.""" - self.fg = MagicMock(spec=FeatureGroup) - self.fg._generate_s3_deny_policy = FeatureGroup._generate_s3_deny_policy.__get__(self.fg) - - def test_policy_includes_correct_bucket_arn_in_object_statement(self): - """Test that the policy includes correct bucket ARN and prefix in object actions statement.""" - bucket_name = "my-feature-store-bucket" - s3_prefix = "feature-store/data/my-feature-group" - lf_role_arn = "arn:aws:iam::123456789012:role/LakeFormationRole" - fs_role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - - policy = self.fg._generate_s3_deny_policy( - bucket_name=bucket_name, - s3_prefix=s3_prefix, - lake_formation_role_arn=lf_role_arn, - feature_store_role_arn=fs_role_arn, - ) - - # Verify the object actions statement has correct Resource ARN - object_statement = policy["Statement"][0] - expected_resource = f"arn:aws:s3:::{bucket_name}/{s3_prefix}/*" - assert object_statement["Resource"] == expected_resource - - def test_policy_includes_correct_bucket_arn_in_list_statement(self): - """Test that the policy includes correct bucket ARN in ListBucket statement.""" - bucket_name = "my-feature-store-bucket" - s3_prefix = "feature-store/data/my-feature-group" - lf_role_arn = "arn:aws:iam::123456789012:role/LakeFormationRole" - fs_role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - - policy = self.fg._generate_s3_deny_policy( - bucket_name=bucket_name, - s3_prefix=s3_prefix, - lake_formation_role_arn=lf_role_arn, - feature_store_role_arn=fs_role_arn, - ) - - # Verify the ListBucket statement has correct Resource ARN (bucket only) - list_statement = policy["Statement"][1] - expected_resource = f"arn:aws:s3:::{bucket_name}" - assert list_statement["Resource"] == expected_resource - - def test_policy_includes_correct_prefix_condition_in_list_statement(self): - """Test that the policy includes correct prefix condition in ListBucket statement.""" 
- bucket_name = "my-feature-store-bucket" - s3_prefix = "feature-store/data/my-feature-group" - lf_role_arn = "arn:aws:iam::123456789012:role/LakeFormationRole" - fs_role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - - policy = self.fg._generate_s3_deny_policy( - bucket_name=bucket_name, - s3_prefix=s3_prefix, - lake_formation_role_arn=lf_role_arn, - feature_store_role_arn=fs_role_arn, - ) - - # Verify the ListBucket statement has correct prefix condition - list_statement = policy["Statement"][1] - expected_prefix = f"{s3_prefix}/*" - assert list_statement["Condition"]["StringLike"]["s3:prefix"] == expected_prefix - - def test_policy_preserves_bucket_name_exactly(self): - """Test that bucket name is preserved exactly without modification.""" - # Test with various bucket name formats - test_cases = [ - "simple-bucket", - "bucket.with.dots", - "bucket-with-dashes-123", - "mybucket", - "a" * 63, # Max bucket name length - ] - - for bucket_name in test_cases: - policy = self.fg._generate_s3_deny_policy( - bucket_name=bucket_name, - s3_prefix="prefix", - lake_formation_role_arn="arn:aws:iam::123456789012:role/LFRole", - feature_store_role_arn="arn:aws:iam::123456789012:role/FSRole", - ) - - # Verify bucket name is preserved in both statements - assert bucket_name in policy["Statement"][0]["Resource"] - assert bucket_name in policy["Statement"][1]["Resource"] - - def test_policy_preserves_prefix_exactly(self): - """Test that S3 prefix is preserved exactly without modification.""" - # Test with various prefix formats - test_cases = [ - "simple-prefix", - "path/to/data", - "feature-store/account-id/region/feature-group-name", - "deep/nested/path/structure/data", - "prefix_with_underscores", - "prefix-with-dashes", - ] - - for s3_prefix in test_cases: - policy = self.fg._generate_s3_deny_policy( - bucket_name="test-bucket", - s3_prefix=s3_prefix, - lake_formation_role_arn="arn:aws:iam::123456789012:role/LFRole", - feature_store_role_arn="arn:aws:iam::123456789012:role/FSRole", - ) - - # Verify prefix is preserved in object statement Resource - assert f"{s3_prefix}/*" in policy["Statement"][0]["Resource"] - # Verify prefix is preserved in list statement Condition - assert policy["Statement"][1]["Condition"]["StringLike"]["s3:prefix"] == f"{s3_prefix}/*" - - def test_policy_has_correct_s3_arn_format(self): - """Test that the policy uses correct S3 ARN format (arn:aws:s3:::bucket/path).""" - bucket_name = "test-bucket" - s3_prefix = "test/prefix" - - policy = self.fg._generate_s3_deny_policy( - bucket_name=bucket_name, - s3_prefix=s3_prefix, - lake_formation_role_arn="arn:aws:iam::123456789012:role/LFRole", - feature_store_role_arn="arn:aws:iam::123456789012:role/FSRole", - ) - - # Verify object statement Resource starts with correct ARN prefix - object_resource = policy["Statement"][0]["Resource"] - assert object_resource.startswith("arn:aws:s3:::") - assert object_resource == f"arn:aws:s3:::{bucket_name}/{s3_prefix}/*" - - # Verify list statement Resource is bucket-only ARN - list_resource = policy["Statement"][1]["Resource"] - assert list_resource.startswith("arn:aws:s3:::") - assert list_resource == f"arn:aws:s3:::{bucket_name}" - - def test_policy_structure_validation(self): - """Test that the policy has correct overall structure.""" - policy = self.fg._generate_s3_deny_policy( - bucket_name="test-bucket", - s3_prefix="test/prefix", - lake_formation_role_arn="arn:aws:iam::123456789012:role/LFRole", - feature_store_role_arn="arn:aws:iam::123456789012:role/FSRole", - ) - - # Verify 
policy version - assert policy["Version"] == "2012-10-17" - - # Verify exactly two statements - assert len(policy["Statement"]) == 2 - - # Verify first statement structure (object actions) - object_statement = policy["Statement"][0] - assert object_statement["Sid"] == "DenyAllAccessToFeatureStorePrefixExceptAllowedPrincipals" - assert object_statement["Effect"] == "Deny" - assert object_statement["Principal"] == "*" - assert "Condition" in object_statement - assert "StringNotEquals" in object_statement["Condition"] - - # Verify second statement structure (list bucket) - list_statement = policy["Statement"][1] - assert list_statement["Sid"] == "DenyListOnPrefixExceptAllowedPrincipals" - assert list_statement["Effect"] == "Deny" - assert list_statement["Principal"] == "*" - assert "Condition" in list_statement - assert "StringLike" in list_statement["Condition"] - assert "StringNotEquals" in list_statement["Condition"] - - def test_policy_includes_both_principals_in_allowed_list(self): - """Test that both Lake Formation role and Feature Store role are in allowed principals.""" - lf_role_arn = "arn:aws:iam::123456789012:role/LakeFormationRole" - fs_role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - - policy = self.fg._generate_s3_deny_policy( - bucket_name="test-bucket", - s3_prefix="test/prefix", - lake_formation_role_arn=lf_role_arn, - feature_store_role_arn=fs_role_arn, - ) - - # Verify both principals in object statement - object_principals = policy["Statement"][0]["Condition"]["StringNotEquals"]["aws:PrincipalArn"] - assert lf_role_arn in object_principals - assert fs_role_arn in object_principals - assert len(object_principals) == 2 - - # Verify both principals in list statement - list_principals = policy["Statement"][1]["Condition"]["StringNotEquals"]["aws:PrincipalArn"] - assert lf_role_arn in list_principals - assert fs_role_arn in list_principals - assert len(list_principals) == 2 - - def test_policy_has_correct_actions_in_each_statement(self): - """Test that each statement has the correct S3 actions.""" - policy = self.fg._generate_s3_deny_policy( - bucket_name="test-bucket", - s3_prefix="test/prefix", - lake_formation_role_arn="arn:aws:iam::123456789012:role/LFRole", - feature_store_role_arn="arn:aws:iam::123456789012:role/FSRole", - ) - - # Verify object statement has correct actions - object_actions = policy["Statement"][0]["Action"] - assert "s3:GetObject" in object_actions - assert "s3:PutObject" in object_actions - assert "s3:DeleteObject" in object_actions - assert len(object_actions) == 3 - - # Verify list statement has correct action - list_action = policy["Statement"][1]["Action"] - assert list_action == "s3:ListBucket" - - - -class TestEnableLakeFormationServiceLinkedRoleInPolicy: - """Tests for service-linked role ARN usage in S3 deny policy generation.""" - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch.object(FeatureGroup, "_generate_s3_deny_policy") - @patch("builtins.print") - def test_uses_service_linked_role_arn_when_use_service_linked_role_true( - self, - mock_print, - mock_generate_policy, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that enable_lake_formation uses the auto-generated service-linked role ARN - when use_service_linked_role=True. 
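Taken together, the `TestGenerateS3DenyPolicy` assertions above describe the full shape of the generated bucket policy: two Deny statements, the object-level and bucket-level resource ARNs, and the two allowed principals. The sketch below produces exactly that shape as a standalone function; the real `_generate_s3_deny_policy` is an instance method and may differ in details the tests do not cover.

```python
# Sketch only: a policy shaped the way the TestGenerateS3DenyPolicy assertions describe it.
from typing import Any, Dict


def generate_s3_deny_policy(
    bucket_name: str,
    s3_prefix: str,
    lake_formation_role_arn: str,
    feature_store_role_arn: str,
) -> Dict[str, Any]:
    """Build a deny-by-default bucket policy that exempts only the two given roles."""
    allowed_principals = [lake_formation_role_arn, feature_store_role_arn]
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "DenyAllAccessToFeatureStorePrefixExceptAllowedPrincipals",
                "Effect": "Deny",
                "Principal": "*",
                "Action": ["s3:GetObject", "s3:PutObject", "s3:DeleteObject"],
                "Resource": f"arn:aws:s3:::{bucket_name}/{s3_prefix}/*",
                "Condition": {"StringNotEquals": {"aws:PrincipalArn": allowed_principals}},
            },
            {
                "Sid": "DenyListOnPrefixExceptAllowedPrincipals",
                "Effect": "Deny",
                "Principal": "*",
                "Action": "s3:ListBucket",
                "Resource": f"arn:aws:s3:::{bucket_name}",
                "Condition": {
                    "StringLike": {"s3:prefix": f"{s3_prefix}/*"},
                    "StringNotEquals": {"aws:PrincipalArn": allowed_principals},
                },
            },
        ],
    }
```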
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} - - # Call with use_service_linked_role=True (default) - fg.enable_lake_formation(use_service_linked_role=True, show_s3_policy=True) - - # Verify _generate_s3_deny_policy was called with the service-linked role ARN - expected_slr_arn = "arn:aws:iam::123456789012:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" - mock_generate_policy.assert_called_once() - call_kwargs = mock_generate_policy.call_args[1] - assert call_kwargs["lake_formation_role_arn"] == expected_slr_arn - assert call_kwargs["feature_store_role_arn"] == fg.role_arn - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch.object(FeatureGroup, "_generate_s3_deny_policy") - @patch("builtins.print") - def test_uses_service_linked_role_arn_by_default( - self, - mock_print, - mock_generate_policy, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that enable_lake_formation uses the service-linked role ARN by default - (when use_service_linked_role is not explicitly specified). 
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::987654321098:role/MyFeatureStoreRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-east-1:987654321098:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} - - # Call without specifying use_service_linked_role (should default to True) - fg.enable_lake_formation(show_s3_policy=True) - - # Verify _generate_s3_deny_policy was called with the service-linked role ARN - expected_slr_arn = "arn:aws:iam::987654321098:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" - mock_generate_policy.assert_called_once() - call_kwargs = mock_generate_policy.call_args[1] - assert call_kwargs["lake_formation_role_arn"] == expected_slr_arn - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch.object(FeatureGroup, "_generate_s3_deny_policy") - @patch("builtins.print") - def test_service_linked_role_arn_uses_correct_account_id( - self, - mock_print, - mock_generate_policy, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that the service-linked role ARN is generated with the correct account ID - extracted from the Feature Group ARN. 
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - # Use a specific account ID to verify it's extracted correctly - account_id = "111222333444" - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = f"arn:aws:iam::{account_id}:role/FeatureStoreRole" - fg.feature_group_arn = f"arn:aws:sagemaker:us-west-2:{account_id}:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} - - # Call with use_service_linked_role=True - fg.enable_lake_formation(use_service_linked_role=True, show_s3_policy=True) - - # Verify the service-linked role ARN contains the correct account ID - expected_slr_arn = f"arn:aws:iam::{account_id}:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess" - mock_generate_policy.assert_called_once() - call_kwargs = mock_generate_policy.call_args[1] - assert call_kwargs["lake_formation_role_arn"] == expected_slr_arn - assert account_id in call_kwargs["lake_formation_role_arn"] - - - -class TestRegistrationRoleArnUsedWhenServiceLinkedRoleFalse: - """Tests for verifying registration_role_arn is used when use_service_linked_role=False.""" - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch.object(FeatureGroup, "_generate_s3_deny_policy") - @patch("builtins.print") - def test_uses_registration_role_arn_when_use_service_linked_role_false( - self, - mock_print, - mock_generate_policy, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that when use_service_linked_role=False, the registration_role_arn is used - in the S3 deny policy instead of the auto-generated service-linked role ARN. 
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - # Set up Feature Group with required configuration - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} - - # Custom registration role ARN - custom_registration_role = "arn:aws:iam::123456789012:role/CustomLakeFormationRole" - - # Call with use_service_linked_role=False and registration_role_arn - fg.enable_lake_formation( - use_service_linked_role=False, - registration_role_arn=custom_registration_role, - show_s3_policy=True, - ) - - # Verify _generate_s3_deny_policy was called with the custom registration role ARN - mock_generate_policy.assert_called_once() - call_kwargs = mock_generate_policy.call_args[1] - assert call_kwargs["lake_formation_role_arn"] == custom_registration_role - - # Verify it's NOT the service-linked role ARN - service_linked_role_pattern = "aws-service-role/lakeformation.amazonaws.com" - assert service_linked_role_pattern not in call_kwargs["lake_formation_role_arn"] - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch.object(FeatureGroup, "_generate_s3_deny_policy") - @patch("builtins.print") - def test_registration_role_arn_passed_to_s3_registration( - self, - mock_print, - mock_generate_policy, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that when use_service_linked_role=False, the registration_role_arn is also - passed to _register_s3_with_lake_formation. 
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - # Set up Feature Group with required configuration - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} - - # Custom registration role ARN - custom_registration_role = "arn:aws:iam::123456789012:role/CustomLakeFormationRole" - - # Call with use_service_linked_role=False and registration_role_arn - fg.enable_lake_formation( - use_service_linked_role=False, - registration_role_arn=custom_registration_role, - show_s3_policy=True, - ) - - # Verify _register_s3_with_lake_formation was called with the correct parameters - mock_register.assert_called_once() - call_args = mock_register.call_args - assert call_args[1]["use_service_linked_role"] == False - assert call_args[1]["role_arn"] == custom_registration_role - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch.object(FeatureGroup, "_generate_s3_deny_policy") - @patch("builtins.print") - def test_different_registration_role_arns_produce_different_policies( - self, - mock_print, - mock_generate_policy, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that different registration_role_arn values result in different - lake_formation_role_arn values in the generated policy. 
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - # Set up Feature Group with required configuration - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - mock_generate_policy.return_value = {"Version": "2012-10-17", "Statement": []} - - # First call with one registration role - first_role = "arn:aws:iam::123456789012:role/FirstLakeFormationRole" - fg.enable_lake_formation( - use_service_linked_role=False, - registration_role_arn=first_role, - show_s3_policy=True, - ) - - first_call_kwargs = mock_generate_policy.call_args[1] - first_lf_role = first_call_kwargs["lake_formation_role_arn"] - - # Reset mocks - mock_generate_policy.reset_mock() - mock_register.reset_mock() - mock_grant.reset_mock() - mock_revoke.reset_mock() - - # Second call with different registration role - second_role = "arn:aws:iam::123456789012:role/SecondLakeFormationRole" - fg.enable_lake_formation( - use_service_linked_role=False, - registration_role_arn=second_role, - show_s3_policy=True, - ) - - second_call_kwargs = mock_generate_policy.call_args[1] - second_lf_role = second_call_kwargs["lake_formation_role_arn"] - - # Verify different roles were used - assert first_lf_role == first_role - assert second_lf_role == second_role - assert first_lf_role != second_lf_role - - - -class TestPolicyPrintedWithClearInstructions: - """Tests for verifying the S3 deny policy is printed with clear instructions.""" - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch("builtins.print") - def test_policy_printed_with_header_and_instructions( - self, - mock_print, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that enable_lake_formation prints the S3 deny policy with clear - header and instructions for the user. 
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - # Set up Feature Group with required configuration - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - - # Call enable_lake_formation with show_s3_policy=True - fg.enable_lake_formation(show_s3_policy=True) - - # Collect all print calls - print_calls = [str(call) for call in mock_print.call_args_list] - all_printed_text = " ".join(print_calls) - - # Verify header is printed - assert "S3 Bucket Policy" in all_printed_text, "Header should mention 'S3 Bucket Policy'" - - # Verify instructions are printed - assert ( - "Lake Formation" in all_printed_text - or "deny policy" in all_printed_text - ), "Instructions should mention Lake Formation or deny policy" - - # Verify bucket name is printed - assert "test-bucket" in all_printed_text, "Bucket name should be printed" - - # Verify note about merging with existing policy is printed - assert ( - "Merge" in all_printed_text or "existing" in all_printed_text - ), "Note about merging with existing policy should be printed" - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch("builtins.print") - def test_policy_json_is_printed( - self, - mock_print, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that the S3 deny policy JSON is printed to the console when show_s3_policy=True. 
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - # Set up Feature Group with required configuration - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - - # Call enable_lake_formation with show_s3_policy=True - fg.enable_lake_formation(show_s3_policy=True) - - # Collect all print calls - print_calls = [str(call) for call in mock_print.call_args_list] - all_printed_text = " ".join(print_calls) - - # Verify policy JSON structure elements are printed - assert "Version" in all_printed_text, "Policy JSON should contain 'Version'" - assert "Statement" in all_printed_text, "Policy JSON should contain 'Statement'" - assert "Effect" in all_printed_text, "Policy JSON should contain 'Effect'" - assert "Deny" in all_printed_text, "Policy JSON should contain 'Deny' effect" - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch("builtins.print") - def test_policy_printed_only_after_successful_setup( - self, - mock_print, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that the S3 deny policy is only printed after all Lake Formation - phases complete successfully. 
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - # Set up Feature Group with required configuration - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock Phase 1 failure - mock_register.side_effect = Exception("Phase 1 failed") - mock_grant.return_value = True - mock_revoke.return_value = True - - # Call enable_lake_formation with show_s3_policy=True - should fail - with pytest.raises(RuntimeError): - fg.enable_lake_formation(show_s3_policy=True) - - # Collect all print calls - print_calls = [str(call) for call in mock_print.call_args_list] - all_printed_text = " ".join(print_calls) - - # Verify policy was NOT printed when setup failed - assert "S3 Bucket Policy" not in all_printed_text, "Policy should not be printed when setup fails" - - # Reset mocks - mock_print.reset_mock() - mock_register.reset_mock() - mock_register.side_effect = None - mock_register.return_value = True - - # Mock Phase 2 failure - mock_grant.side_effect = Exception("Phase 2 failed") - - # Call enable_lake_formation with show_s3_policy=True - should fail - with pytest.raises(RuntimeError): - fg.enable_lake_formation(show_s3_policy=True) - - # Collect all print calls - print_calls = [str(call) for call in mock_print.call_args_list] - all_printed_text = " ".join(print_calls) - - # Verify policy was NOT printed when setup fails at Phase 2 - assert "S3 Bucket Policy" not in all_printed_text, "Policy should not be printed when Phase 2 fails" - - # Reset mocks - mock_print.reset_mock() - mock_grant.reset_mock() - mock_grant.side_effect = None - mock_grant.return_value = True - - # Mock Phase 3 failure - mock_revoke.side_effect = Exception("Phase 3 failed") - - # Call enable_lake_formation with show_s3_policy=True - should fail - with pytest.raises(RuntimeError): - fg.enable_lake_formation(show_s3_policy=True) - - # Collect all print calls - print_calls = [str(call) for call in mock_print.call_args_list] - all_printed_text = " ".join(print_calls) - - # Verify policy was NOT printed when setup fails at Phase 3 - assert "S3 Bucket Policy" not in all_printed_text, "Policy should not be printed when Phase 3 fails" - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch("builtins.print") - def test_policy_includes_both_allowed_principals( - self, - mock_print, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that the printed policy includes both the Lake Formation role - and the Feature Store execution role as allowed principals. 
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - # Set up Feature Group with required configuration - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - feature_store_role = "arn:aws:iam::123456789012:role/FeatureStoreRole" - fg.role_arn = feature_store_role - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - - # Call enable_lake_formation with service-linked role and show_s3_policy=True - fg.enable_lake_formation(use_service_linked_role=True, show_s3_policy=True) - - # Collect all print calls - print_calls = [str(call) for call in mock_print.call_args_list] - all_printed_text = " ".join(print_calls) - - # Verify Feature Store role is in the printed output - assert feature_store_role in all_printed_text, "Feature Store role should be in printed policy" - - # Verify Lake Formation service-linked role pattern is in the printed output - assert "AWSServiceRoleForLakeFormationDataAccess" in all_printed_text, \ - "Lake Formation service-linked role should be in printed policy" - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch("builtins.print") - def test_policy_not_printed_when_show_s3_policy_false( - self, - mock_print, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that the S3 deny policy is NOT printed when show_s3_policy=False (default). 
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - # Set up Feature Group with required configuration - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - - # Call enable_lake_formation with show_s3_policy=False (default) - fg.enable_lake_formation(show_s3_policy=False) - - # Collect all print calls - print_calls = [str(call) for call in mock_print.call_args_list] - all_printed_text = " ".join(print_calls) - - # Verify policy was NOT printed - assert "S3 Bucket Policy" not in all_printed_text, "Policy should not be printed when show_s3_policy=False" - assert "Version" not in all_printed_text, "Policy JSON should not be printed when show_s3_policy=False" - - @patch.object(FeatureGroup, "refresh") - @patch.object(FeatureGroup, "_register_s3_with_lake_formation") - @patch.object(FeatureGroup, "_grant_lake_formation_permissions") - @patch.object(FeatureGroup, "_revoke_iam_allowed_principal") - @patch("builtins.print") - def test_policy_not_printed_by_default( - self, - mock_print, - mock_revoke, - mock_grant, - mock_register, - mock_refresh, - ): - """ - Test that the S3 deny policy is NOT printed by default (when show_s3_policy is not specified). 
- """ - from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig - - # Set up Feature Group with required configuration - fg = FeatureGroup(feature_group_name="test-fg") - fg.offline_store_config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig( - s3_uri="s3://test-bucket/path", - resolved_output_s3_uri="s3://test-bucket/resolved-path/data", - ), - data_catalog_config=DataCatalogConfig( - catalog="AwsDataCatalog", database="test_db", table_name="test_table" - ), - ) - fg.role_arn = "arn:aws:iam::123456789012:role/FeatureStoreRole" - fg.feature_group_arn = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg" - fg.feature_group_status = "Created" - - # Mock successful Lake Formation operations - mock_register.return_value = True - mock_grant.return_value = True - mock_revoke.return_value = True - - # Call enable_lake_formation without specifying show_s3_policy (should default to False) - fg.enable_lake_formation() - - # Collect all print calls - print_calls = [str(call) for call in mock_print.call_args_list] - all_printed_text = " ".join(print_calls) - - # Verify policy was NOT printed - assert "S3 Bucket Policy" not in all_printed_text, "Policy should not be printed by default" - assert "Version" not in all_printed_text, "Policy JSON should not be printed by default" diff --git a/v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb b/v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb deleted file mode 100644 index f8c5a8ced0..0000000000 --- a/v3-examples/ml-ops-examples/v3-feature-store-lake-formation.ipynb +++ /dev/null @@ -1,675 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Feature Store with Lake Formation Governance\n", - "\n", - "This notebook demonstrates two workflows for using SageMaker Feature Store with Lake Formation governance:\n", - "\n", - "1. **Example 1**: Create Feature Group with Lake Formation enabled at creation time\n", - "2. **Example 2**: Create Feature Group first, then enable Lake Formation separately\n", - "\n", - "Both workflows include record ingestion to verify everything works end-to-end.\n", - "\n", - "## Prerequisites\n", - "\n", - "- AWS credentials configured with permissions for SageMaker, S3, Glue, and Lake Formation\n", - "- An S3 bucket for the offline store\n", - "- An IAM role with Feature Store permissions\n", - "\n", - "## Required IAM Permissions\n", - "\n", - "This notebook uses two separate IAM roles:\n", - "1. **Execution Role**: The SageMaker execution role running this notebook\n", - "2. 
**Offline Store Role**: A dedicated role for Feature Store S3 access\n", - "\n", - "### Execution Role Policy\n", - "\n", - "The execution role needs permissions to manage Feature Groups and configure Lake Formation:\n", - "\n", - "```json\n", - "{\n", - " \"Version\": \"2012-10-17\",\n", - " \"Statement\": [\n", - " {\n", - " \"Sid\": \"FeatureGroupManagement\",\n", - " \"Effect\": \"Allow\",\n", - " \"Action\": [\n", - " \"sagemaker:*\"\n", - " ],\n", - " \"Resource\": \"arn:aws:sagemaker:*:*:feature-group/*\"\n", - " },\n", - " {\n", - " \"Sid\": \"LakeFormation\",\n", - " \"Effect\": \"Allow\",\n", - " \"Action\": [\n", - " \"lakeformation:RegisterResource\",\n", - " \"lakeformation:DeregisterResource\",\n", - " \"lakeformation:GrantPermissions\",\n", - " \"lakeformation:RevokePermissions\"\n", - " ],\n", - " \"Resource\": \"*\"\n", - " },\n", - " {\n", - " \"Sid\": \"GlueCatalogRead\",\n", - " \"Effect\": \"Allow\",\n", - " \"Action\": [\n", - " \"glue:GetTable\",\n", - " \"glue:GetDatabase\",\n", - " \"glue:DeleteTable\"\n", - " ],\n", - " \"Resource\": [\n", - " \"arn:aws:glue:*:*:catalog\",\n", - " \"arn:aws:glue:*:*:database/sagemaker_featurestore\",\n", - " \"arn:aws:glue:*:*:table/sagemaker_featurestore/*\"\n", - " ]\n", - " },\n", - " {\n", - " \"Sid\": \"PassOfflineStoreRole\",\n", - " \"Effect\": \"Allow\",\n", - " \"Action\": \"iam:PassRole\",\n", - " \"Resource\": \"arn:aws:iam::*:role/SagemakerFeatureStoreOfflineRole\"\n", - " },\n", - " {\n", - " \"Sid\": \"LakeFormationServiceLinkedRole\",\n", - " \"Effect\": \"Allow\",\n", - " \"Action\": [\n", - " \"iam:GetRole\",\n", - " \"iam:PutRolePolicy\",\n", - " \"iam:GetRolePolicy\"\n", - " ],\n", - " \"Resource\": \"arn:aws:iam::*:role/aws-service-role/lakeformation.amazonaws.com/AWSServiceRoleForLakeFormationDataAccess\"\n", - " },\n", - " {\n", - " \"Sid\": \"S3SagemakerDefaultBucket\",\n", - " \"Effect\": \"Allow\",\n", - " \"Action\": [\n", - " \"s3:CreateBucket\",\n", - " \"s3:GetBucketAcl\",\n", - " \"s3:ListBucket\"\n", - " ],\n", - " \"Resource\": [\n", - " \"arn:aws:s3:::sagemaker-*\"\n", - " ]\n", - " },\n", - " {\n", - " \"Sid\": \"CreateGlueTable\",\n", - " \"Effect\": \"Allow\",\n", - " \"Action\": [\n", - " \"glue:CreateTable\"\n", - " ],\n", - " \"Resource\": [\n", - " \"*\"\n", - " ]\n", - " }\n", - " ]\n", - "}\n", - "```\n", - "\n", - "## Lake Formation Admin Requirements\n", - "\n", - "The person enabling Lake Formation governance must be a **Data Lake Administrator** in Lake Formation. There are two options depending on your organization's setup:\n", - "\n", - "### Option 1: Single User (Data Lake Admin + Feature Store Admin)\n", - "\n", - "If the caller has both:\n", - "- Data Lake Administrator privileges in Lake Formation\n", - "- Permissions to create Feature Groups in SageMaker\n", - "\n", - "Then they can use `FeatureGroup.create()` with `lake_formation_config` to enable governance at creation time (Example 1).\n", - "\n", - "### Option 2: Separate Roles (ML Engineer + Data Lake Admin)\n", - "\n", - "If the person creating the Feature Group is different from the Data Lake Administrator:\n", - "\n", - "1. **ML Engineer** creates the Feature Group without Lake Formation using `FeatureGroup.create()`\n", - "2. 
**Data Lake Admin** later enables governance by calling `enable_lake_formation()` on the existing Feature Group (Example 2)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "from datetime import datetime\n", - "from datetime import timezone\n", - "\n", - "import boto3\n", - "from botocore.exceptions import ClientError\n", - "\n", - "# Import the FeatureGroup with Lake Formation support\n", - "from sagemaker.mlops.feature_store.feature_group import FeatureGroup, LakeFormationConfig\n", - "from sagemaker.core.shapes import (\n", - " FeatureDefinition,\n", - " FeatureValue,\n", - " OfflineStoreConfig,\n", - " OnlineStoreConfig,\n", - " S3StorageConfig,\n", - ")\n", - "from sagemaker.core.helper.session_helper import Session as SageMakerSession, get_execution_role" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Configuration" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Use SageMaker session to get default bucket and execution role\n", - "sagemaker_session = SageMakerSession()\n", - "S3_BUCKET = sagemaker_session.default_bucket()\n", - "REGION = sagemaker_session.boto_session.region_name\n", - "\n", - "# Execution role (for running this notebook)\n", - "EXECUTION_ROLE_ARN = get_execution_role(sagemaker_session)\n", - "\n", - "# Offline store role (dedicated role for Feature Store S3 access)\n", - "# Replace with your dedicated offline store role ARN\n", - "OFFLINE_STORE_ROLE_ARN = \"arn:aws:iam:::role/\"\n", - "\n", - "print(f\"S3 Bucket: {S3_BUCKET}\")\n", - "print(f\"Execution Role ARN: {EXECUTION_ROLE_ARN}\")\n", - "print(f\"Offline Store Role ARN: {OFFLINE_STORE_ROLE_ARN}\")\n", - "print(f\"Region: {REGION}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Common Feature Definitions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "feature_definitions = [\n", - " FeatureDefinition(feature_name=\"customer_id\", feature_type=\"String\"),\n", - " FeatureDefinition(feature_name=\"event_time\", feature_type=\"String\"),\n", - " FeatureDefinition(feature_name=\"age\", feature_type=\"Integral\"),\n", - " FeatureDefinition(feature_name=\"total_purchases\", feature_type=\"Integral\"),\n", - " FeatureDefinition(feature_name=\"avg_order_value\", feature_type=\"Fractional\"),\n", - "]\n", - "\n", - "print(\"Feature Definitions:\")\n", - "for fd in feature_definitions:\n", - " print(f\" - {fd.feature_name}: {fd.feature_type}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Helper Function: Ingest Records" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def ingest_sample_records(feature_group, num_records=3):\n", - " \"\"\"\n", - " Ingest sample records into the Feature Group.\n", - " \n", - " Args:\n", - " feature_group: The FeatureGroup to ingest records into\n", - " num_records: Number of sample records to ingest\n", - " \"\"\"\n", - " print(f\"\\nIngesting {num_records} sample records...\")\n", - " \n", - " for i in range(num_records):\n", - " event_time = datetime.now(timezone.utc).isoformat()\n", - " record = [\n", - " FeatureValue(feature_name=\"customer_id\", value_as_string=f\"cust_{i+1}\"),\n", - " 
FeatureValue(feature_name=\"event_time\", value_as_string=event_time),\n", - " FeatureValue(feature_name=\"age\", value_as_string=str(25 + i * 5)),\n", - " FeatureValue(feature_name=\"total_purchases\", value_as_string=str(10 + i * 3)),\n", - " FeatureValue(feature_name=\"avg_order_value\", value_as_string=str(50.0 + i * 10.5)),\n", - " ]\n", - " \n", - " feature_group.put_record(record=record)\n", - " print(f\" Ingested record for customer: cust_{i+1}\")\n", - " \n", - " print(f\"Successfully ingested {num_records} records!\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Helper Function: Cleanup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def cleanup_feature_group(fg):\n", - " \"\"\"\n", - " Delete a FeatureGroup and its associated Glue table.\n", - " \n", - " Args:\n", - " fg: The FeatureGroup to delete.\n", - " \"\"\"\n", - " try:\n", - " # Delete the Glue table if it exists\n", - " if fg.offline_store_config is not None:\n", - " try:\n", - " fg.refresh() # Ensure we have latest config\n", - " data_catalog_config = fg.offline_store_config.data_catalog_config\n", - " if data_catalog_config is not None:\n", - " database_name = data_catalog_config.database\n", - " table_name = data_catalog_config.table_name\n", - "\n", - " if database_name and table_name:\n", - " glue_client = boto3.client(\"glue\")\n", - " try:\n", - " glue_client.delete_table(DatabaseName=database_name, Name=table_name)\n", - " print(f\"Deleted Glue table: {database_name}.{table_name}\")\n", - " except ClientError as e:\n", - " # Ignore if table doesn't exist\n", - " if e.response[\"Error\"][\"Code\"] != \"EntityNotFoundException\":\n", - " raise\n", - " except Exception as e:\n", - " # Don't fail cleanup if Glue table deletion fails\n", - " print(f\"Warning: Could not delete Glue table: {e}\")\n", - "\n", - " # Delete the FeatureGroup\n", - " fg.delete()\n", - " print(f\"Deleted Feature Group: {fg.feature_group_name}\")\n", - " except ClientError as e:\n", - " print(f\"Error during cleanup: {e}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "# Example 1: Create Feature Group with Lake Formation Enabled\n", - "\n", - "This example creates a Feature Group with Lake Formation governance enabled at creation time using `LakeFormationConfig`." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate unique name for example 1\n", - "timestamp = datetime.now().strftime(\"%Y%m%d%H%M%S\")\n", - "FG_NAME_WORKFLOW1 = f\"lf-demo-workflow1-{timestamp}\"\n", - "\n", - "print(f\"Example 1 Feature Group: {FG_NAME_WORKFLOW1}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Configure online and offline stores\n", - "online_store_config = OnlineStoreConfig(enable_online_store=True)\n", - "\n", - "offline_store_config_1 = OfflineStoreConfig(\n", - " s3_storage_config=S3StorageConfig(\n", - " s3_uri=f\"s3://{S3_BUCKET}/feature-store-demo/\"\n", - " )\n", - ")\n", - "\n", - "# Configure Lake Formation - enabled at creation\n", - "lake_formation_config = LakeFormationConfig()\n", - "lake_formation_config.enabled = True\n", - "lake_formation_config.use_service_linked_role = True\n", - "lake_formation_config.show_s3_policy = True\n", - "\n", - "print(\"Store Config:\")\n", - "print(f\" Online Store: enabled\")\n", - "print(f\" Offline Store S3: s3://{S3_BUCKET}/feature-store-demo/\")\n", - "print(\"\\nLake Formation Config:\")\n", - "print(f\" enabled: {lake_formation_config.enabled}\")\n", - "print(f\" use_service_linked_role: {lake_formation_config.use_service_linked_role}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Create Feature Group with Lake Formation enabled\n", - "print(\"Creating Feature Group with Lake Formation enabled...\")\n", - "print(\"This will:\")\n", - "print(\" 1. Create the Feature Group with online + offline stores\")\n", - "print(\" 2. Wait for 'Created' status\")\n", - "print(\" 3. Register S3 with Lake Formation\")\n", - "print(\" 4. Grant permissions to execution role\")\n", - "print(\" 5. 
Revoke IAMAllowedPrincipal permissions\")\n", - "print()\n", - "\n", - "fg_workflow1 = FeatureGroup.create(\n", - " feature_group_name=FG_NAME_WORKFLOW1,\n", - " record_identifier_feature_name=\"customer_id\",\n", - " event_time_feature_name=\"event_time\",\n", - " feature_definitions=feature_definitions,\n", - " online_store_config=online_store_config,\n", - " offline_store_config=offline_store_config_1,\n", - " role_arn=OFFLINE_STORE_ROLE_ARN,\n", - " description=\"Workflow 1: Lake Formation enabled at creation\",\n", - " lake_formation_config=lake_formation_config, # new field\n", - " region=REGION,\n", - ")\n", - "\n", - "print(f\"\\nFeature Group created: {fg_workflow1.feature_group_name}\")\n", - "print(f\"Status: {fg_workflow1.feature_group_status}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Verify Feature Group status\n", - "fg_workflow1.refresh()\n", - "print(f\"Feature Group: {fg_workflow1.feature_group_name}\")\n", - "print(f\"Status: {fg_workflow1.feature_group_status}\")\n", - "print(f\"ARN: {fg_workflow1.feature_group_arn}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Ingest sample records to verify everything works\n", - "ingest_sample_records(fg_workflow1, num_records=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Retrieve a sample record from the online store\n", - "print(\"Retrieving record for customer 'cust_1' from online store...\")\n", - "response = fg_workflow1.get_record(record_identifier_value_as_string=\"cust_1\")\n", - "\n", - "print(f\"\\nRecord retrieved successfully!\")\n", - "print(f\"Features:\")\n", - "for feature in response.record:\n", - " print(f\" {feature.feature_name}: {feature.value_as_string}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "# Example 2: Create Feature Group, Then Enable Lake Formation\n", - "\n", - "This example creates a Feature Group first without Lake Formation, then enables it separately using `enable_lake_formation()`." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Generate unique name for example 2\n", - "timestamp = datetime.now().strftime(\"%Y%m%d%H%M%S\")\n", - "FG_NAME_WORKFLOW2 = f\"lf-demo-workflow2-{timestamp}\"\n", - "\n", - "print(f\"Example 2 Feature Group: {FG_NAME_WORKFLOW2}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Configure online and offline stores\n", - "online_store_config_2 = OnlineStoreConfig(enable_online_store=True)\n", - "\n", - "offline_store_config_2 = OfflineStoreConfig(\n", - " s3_storage_config=S3StorageConfig(\n", - " s3_uri=f\"s3://{S3_BUCKET}/feature-store-demo/\"\n", - " ),\n", - " table_format=\"Iceberg\"\n", - ")\n", - "\n", - "# Step 1: Create Feature Group WITHOUT Lake Formation\n", - "print(\"Step 1: Creating Feature Group without Lake Formation...\")\n", - "\n", - "fg_workflow2 = FeatureGroup.create(\n", - " feature_group_name=FG_NAME_WORKFLOW2,\n", - " record_identifier_feature_name=\"customer_id\",\n", - " event_time_feature_name=\"event_time\",\n", - " feature_definitions=feature_definitions,\n", - " online_store_config=online_store_config_2,\n", - " offline_store_config=offline_store_config_2,\n", - " role_arn=OFFLINE_STORE_ROLE_ARN,\n", - " description=\"Workflow 2: Lake Formation enabled after creation\",\n", - " region=REGION,\n", - ")\n", - "\n", - "print(f\"Feature Group created: {fg_workflow2.feature_group_name}\")\n", - "print(f\"Status: {fg_workflow2.feature_group_status}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Step 2: Wait for Feature Group to be ready\n", - "print(\"Step 2: Waiting for Feature Group to reach 'Created' status...\")\n", - "fg_workflow2.wait_for_status(target_status=\"Created\", poll=10, timeout=300)\n", - "print(f\"Status: {fg_workflow2.feature_group_status}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Step 3: Enable Lake Formation governance\n", - "print(\"Step 3: Enabling Lake Formation governance...\")\n", - "print(\"This will:\")\n", - "print(\" 1. Register S3 with Lake Formation\")\n", - "print(\" 2. Grant permissions to execution role\")\n", - "print(\" 3. 
Revoke IAMAllowedPrincipal permissions\")\n", - "print()\n", - "\n", - "result = fg_workflow2.enable_lake_formation( # new method\n", - " use_service_linked_role=True\n", - ")\n", - "\n", - "print(f\"\\nLake Formation setup results:\")\n", - "print(f\" s3_registration: {result['s3_registration']}\")\n", - "print(f\" permissions_granted: {result['permissions_granted']}\")\n", - "print(f\" iam_principal_revoked: {result['iam_principal_revoked']}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Step 4: Ingest sample records to verify everything works\n", - "print(\"Step 4: Ingesting records to verify Lake Formation setup...\")\n", - "ingest_sample_records(fg_workflow2, num_records=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Step 5: Retrieve a sample record from the online store\n", - "print(\"Step 5: Retrieving record for customer 'cust_1' from online store...\")\n", - "response = fg_workflow2.get_record(record_identifier_value_as_string=\"cust_1\")\n", - "\n", - "print(f\"\\nRecord retrieved successfully!\")\n", - "print(f\"Features:\")\n", - "for feature in response.record:\n", - " print(f\" {feature.feature_name}: {feature.value_as_string}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Verify Feature Group status\n", - "fg_workflow2.refresh()\n", - "print(f\"Feature Group: {fg_workflow2.feature_group_name}\")\n", - "print(f\"Status: {fg_workflow2.feature_group_status}\")\n", - "print(f\"ARN: {fg_workflow2.feature_group_arn}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "# Cleanup\n", - "\n", - "Delete the Feature Groups and associated Glue tables created in this demo." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Uncomment to delete the Feature Groups\n", - "cleanup_feature_group(fg_workflow1)\n", - "# cleanup_feature_group(fg_workflow2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "# Summary\n", - "\n", - "This notebook demonstrated two workflows:\n", - "\n", - "**Example 1: Lake Formation at Creation**\n", - "- Use `LakeFormationConfig` with `enabled=True` in `FeatureGroup.create()`\n", - "- Lake Formation is automatically configured after Feature Group creation\n", - "- Both online and offline stores enabled\n", - "\n", - "**Example 2: Enable Lake Formation Later**\n", - "- Create Feature Group normally without Lake Formation\n", - "- Call `enable_lake_formation()` method after creation\n", - "- More control over when Lake Formation is enabled\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "# FAQ:\n", - "\n", - "## What is the S3 deny policy for?\n", - "\n", - "When you enable Lake Formation governance, you control access to data through Lake Formation permissions. 
However, **IAM roles that already have direct S3 access will continue to have access** to the underlying data files, bypassing Lake Formation entirely.\n", - "\n", - "The S3 deny policy closes this access path by explicitly denying S3 access to all principals except:\n", - "- The Lake Formation service-linked role (for data access)\n", - "- The Feature Store offline store role provided during Feature Group creation\n", - "\n", - "## Why don't we apply the S3 deny policy automatically?\n", - "\n", - "We provide the policy as a **recommendation** rather than applying it automatically for several important reasons:\n", - "\n", - "### 1. Protect existing SageMaker workflows from breaking\n", - "\n", - "Many customers already have SageMaker training and processing jobs wired directly to S3 URIs. An automatic S3 deny could cause those jobs to fail the moment governance is enabled on a table.\n", - "\n", - "### 2. Support different personas and trust levels\n", - "\n", - "Different users have different access needs:\n", - "- **Analysts / BI users** - should only see data through governed surfaces (Lake Formation tables, Athena, Redshift, etc.)\n", - "- **ML / Data engineers** - often need raw S3 access for training, feature engineering, and debugging\n", - "\n", - "### 3. Enable gradual migration to stronger governance\n", - "\n", - "Many customers want to phase in Lake Formation governance:\n", - "1. Start by governing table access only\n", - "2. Later tighten S3 access once they've refactored jobs and validated behavior\n", - "\n", - "### 4. Avoid breaking existing bucket policies\n", - "\n", - "Automatically modifying bucket policies could:\n", - "- Conflict with existing policy statements\n", - "- Lock out users or services unexpectedly\n", - "- Cause cascading failures across multiple applications sharing the bucket\n", - "\n", - "Therefore, the S3 policy is provided as a starting point that should be validated by the user. \n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -}
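The FAQ above deliberately leaves applying the recommended deny policy to the user. Assuming you have reviewed the printed policy and decided to adopt it, a sketch of merging it into an existing bucket policy with boto3 might look like the following; the bucket name is a placeholder and the naive statement-append merge should be adapted to your existing policy.

```python
# Sketch only: applying the recommended deny policy after review.
import json

import boto3
from botocore.exceptions import ClientError

s3 = boto3.client("s3")
bucket = "my-feature-store-bucket"  # placeholder: your offline store bucket
deny_policy = {"Version": "2012-10-17", "Statement": []}  # paste the printed policy here

# Fetch the current bucket policy, treating "no policy" as an empty document.
try:
    existing = json.loads(s3.get_bucket_policy(Bucket=bucket)["Policy"])
except ClientError as error:
    if error.response["Error"]["Code"] != "NoSuchBucketPolicy":
        raise
    existing = {"Version": "2012-10-17", "Statement": []}

# Append the deny statements and write the merged policy back.
existing["Statement"].extend(deny_policy["Statement"])
s3.put_bucket_policy(Bucket=bucket, Policy=json.dumps(existing))
```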