Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sagemaker-core/src/sagemaker/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# Partner App
from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401
from sagemaker.core.partner_app.auth_provider import RequestsAuth # noqa: F401

# Attribution
from sagemaker.core.telemetry.attribution import Attribution, set_attribution # noqa: F401
Expand Down
4 changes: 3 additions & 1 deletion sagemaker-core/src/sagemaker/core/partner_app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""__init__ file for sagemaker.core.partner_app"""
from __future__ import absolute_import
from __future__ import annotations

from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401
from sagemaker.core.partner_app.auth_provider import RequestsAuth # noqa: F401
from sagemaker.core.partner_app.auth_utils import PartnerAppAuthUtils # noqa: F401
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# language governing permissions and limitations under the License.

"""The SageMaker partner application SDK auth module"""
from __future__ import absolute_import
from __future__ import annotations

import logging
import os
import re
from typing import Dict, Tuple
Expand All @@ -25,6 +26,8 @@
from requests.models import PreparedRequest
from sagemaker.core.partner_app.auth_utils import PartnerAppAuthUtils

logger = logging.getLogger(__name__)

SERVICE_NAME = "sagemaker"
AWS_PARTNER_APP_ARN_REGEX = r"arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:partner-app\/.*"

Expand Down Expand Up @@ -94,6 +97,7 @@ def __init__(self, credentials: Credentials = None):
credentials if credentials is not None else boto3.Session().get_credentials()
)
self.sigv4 = SigV4Auth(self.credentials, SERVICE_NAME, self.region)
logger.info("PartnerAppAuthProvider initialized for region: %s", self.region)

def get_signed_request(
self, url: str, method: str, headers: dict, body: object
Expand All @@ -109,6 +113,7 @@ def get_signed_request(
Returns:
tuple: (url, headers)
"""
logger.debug("Signing request: %s %s", method, url)
return PartnerAppAuthUtils.get_signed_request(
sigv4=self.sigv4,
app_arn=self.app_arn,
Expand Down
5 changes: 4 additions & 1 deletion sagemaker-core/src/sagemaker/core/partner_app/auth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@

"""Partner App Auth Utils Module"""

from __future__ import absolute_import
from __future__ import annotations

import logging
from hashlib import sha256
import functools
from typing import Tuple, Dict

from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest

logger = logging.getLogger(__name__)

HEADER_CONNECTION = "Connection"
HEADER_X_AMZ_TARGET = "X-Amz-Target"
HEADER_AUTHORIZATION = "Authorization"
Expand Down
Loading