Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions sagemaker-core/src/sagemaker/core/helper/session_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@
TAGS,
SESSION_DEFAULT_S3_BUCKET_PATH,
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
FEATURE_GROUP,
FEATURE_GROUP_ROLE_ARN_PATH,
FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH,
FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH,
)

# Setting LOGGER for backward compatibility, in case users import it...
Expand Down Expand Up @@ -1607,6 +1611,222 @@ def delete_endpoint_config(self, endpoint_config_name):
logger.info("Deleting endpoint configuration with name: %s", endpoint_config_name)
self.sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)

def delete_feature_group(self, feature_group_name):
"""Delete an Amazon SageMaker Feature Group.

Args:
feature_group_name (str): Name of the Amazon SageMaker Feature Group to delete.
"""
logger.info("Deleting feature group with name: %s", feature_group_name)
self.sagemaker_client.delete_feature_group(FeatureGroupName=feature_group_name)

def create_feature_group(
self,
feature_group_name,
record_identifier_name,
event_time_feature_name,
feature_definitions,
role_arn=None,
online_store_config=None,
offline_store_config=None,
throughput_config=None,
description=None,
tags=None,
):
"""Create an Amazon SageMaker Feature Group.

Args:
feature_group_name (str): Name of the Feature Group.
record_identifier_name (str): Name of the record identifier feature.
event_time_feature_name (str): Name of the event time feature.
feature_definitions (list): List of feature definitions.
role_arn (str): ARN of the role used to execute the API (default: None).
Resolved from SageMaker Config if not provided.
online_store_config (dict): Online store configuration (default: None).
offline_store_config (dict): Offline store configuration (default: None).
throughput_config (dict): Throughput configuration (default: None).
description (str): Description of the Feature Group (default: None).
tags (Optional[Tags]): Tags for labeling the Feature Group (default: None).

Returns:
dict: Response from the CreateFeatureGroup API.
"""
tags = format_tags(tags)
tags = _append_project_tags(tags)
tags = self._append_sagemaker_config_tags(
tags, "{}.{}.{}".format(SAGEMAKER, FEATURE_GROUP, TAGS)
)

role_arn = resolve_value_from_config(
role_arn, FEATURE_GROUP_ROLE_ARN_PATH, sagemaker_session=self
)

inferred_online_store_config = update_nested_dictionary_with_values_from_config(
online_store_config,
FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH,
sagemaker_session=self,
)
if inferred_online_store_config is not None:
# OnlineStore should be handled differently because if you set KmsKeyId, then you
# need to set EnableOnlineStore key as well
inferred_online_store_config["EnableOnlineStore"] = True

inferred_offline_store_config = update_nested_dictionary_with_values_from_config(
offline_store_config,
FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH,
sagemaker_session=self,
)

kwargs = dict(
FeatureGroupName=feature_group_name,
RecordIdentifierFeatureName=record_identifier_name,
EventTimeFeatureName=event_time_feature_name,
FeatureDefinitions=feature_definitions,
RoleArn=role_arn,
)
update_args(
kwargs,
OnlineStoreConfig=inferred_online_store_config,
OfflineStoreConfig=inferred_offline_store_config,
ThroughputConfig=throughput_config,
Description=description,
Tags=tags,
)

logger.info("Creating feature group with name: %s", feature_group_name)
return self.sagemaker_client.create_feature_group(**kwargs)

def describe_feature_group(self, feature_group_name, next_token=None):
"""Describe an Amazon SageMaker Feature Group.

Args:
feature_group_name (str): Name of the Amazon SageMaker Feature Group to describe.
next_token (str): A token for paginated results (default: None).

Returns:
dict: Response from the DescribeFeatureGroup API.
"""
args = {"FeatureGroupName": feature_group_name}
update_args(args, NextToken=next_token)
return self.sagemaker_client.describe_feature_group(**args)

def update_feature_group(
self,
feature_group_name,
feature_additions=None,
online_store_config=None,
throughput_config=None,
):
"""Update an Amazon SageMaker Feature Group.

Args:
feature_group_name (str): Name of the Amazon SageMaker Feature Group to update.
feature_additions (list): List of feature definitions to add (default: None).
online_store_config (dict): Online store configuration updates (default: None).
throughput_config (dict): Throughput configuration updates (default: None).

Returns:
dict: Response from the UpdateFeatureGroup API.
"""
args = {"FeatureGroupName": feature_group_name}
update_args(
args,
FeatureAdditions=feature_additions,
OnlineStoreConfig=online_store_config,
ThroughputConfig=throughput_config,
)
return self.sagemaker_client.update_feature_group(**args)

def list_feature_groups(
self,
name_contains=None,
feature_group_status_equals=None,
offline_store_status_equals=None,
creation_time_after=None,
creation_time_before=None,
sort_order=None,
sort_by=None,
max_results=None,
next_token=None,
):
"""List Amazon SageMaker Feature Groups.

Args:
name_contains (str): Filter by name substring (default: None).
feature_group_status_equals (str): Filter by status (default: None).
offline_store_status_equals (str): Filter by offline store status (default: None).
creation_time_after (datetime): Filter by creation time lower bound (default: None).
creation_time_before (datetime): Filter by creation time upper bound (default: None).
sort_order (str): Sort order, 'Ascending' or 'Descending' (default: None).
sort_by (str): Sort by field (default: None).
max_results (int): Maximum number of results (default: None).
next_token (str): Pagination token (default: None).

Returns:
dict: Response from the ListFeatureGroups API.
"""
args = {}
update_args(
args,
NameContains=name_contains,
FeatureGroupStatusEquals=feature_group_status_equals,
OfflineStoreStatusEquals=offline_store_status_equals,
CreationTimeAfter=creation_time_after,
CreationTimeBefore=creation_time_before,
SortOrder=sort_order,
SortBy=sort_by,
MaxResults=max_results,
NextToken=next_token,
)
return self.sagemaker_client.list_feature_groups(**args)

def update_feature_metadata(
self,
feature_group_name,
feature_name,
description=None,
parameter_additions=None,
parameter_removals=None,
):
"""Update metadata for a feature in an Amazon SageMaker Feature Group.

Args:
feature_group_name (str): Name of the Feature Group.
feature_name (str): Name of the feature to update metadata for.
description (str): Updated description for the feature (default: None).
parameter_additions (list): Parameters to add (default: None).
parameter_removals (list): Parameters to remove (default: None).

Returns:
dict: Response from the UpdateFeatureMetadata API.
"""
args = {
"FeatureGroupName": feature_group_name,
"FeatureName": feature_name,
}
update_args(
args,
Description=description,
ParameterAdditions=parameter_additions,
ParameterRemovals=parameter_removals,
)
return self.sagemaker_client.update_feature_metadata(**args)

def describe_feature_metadata(self, feature_group_name, feature_name):
"""Describe metadata for a feature in an Amazon SageMaker Feature Group.

Args:
feature_group_name (str): Name of the Feature Group.
feature_name (str): Name of the feature to describe metadata for.

Returns:
dict: Response from the DescribeFeatureMetadata API.
"""
return self.sagemaker_client.describe_feature_metadata(
FeatureGroupName=feature_group_name,
FeatureName=feature_name,
)

def wait_for_optimization_job(self, job, poll=5):
"""Wait for an Amazon SageMaker Optimization job to complete.

Expand Down
Loading