Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 57 additions & 7 deletions awscli/customizations/configure/sso_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import logging
import os
from urllib.parse import urlparse

import colorama
from botocore import UNSIGNED
Expand All @@ -38,6 +39,10 @@
profile_to_section,
)
from awscli.customizations.configure.writer import ConfigFileWriter
from awscli.customizations.sso.resolve import (
is_aws_owned_domain,
resolve_start_url,
)
from awscli.customizations.sso.utils import (
LOGIN_ARGS,
BaseSSOCommand,
Expand Down Expand Up @@ -364,8 +369,13 @@ def _prompt_for_sso_registration_args(self):
def _prompt_for_registration_args_with_legacy_format(self):
self._store_sso_session_prompter_answers_to_profile_config()
self._set_sso_session_defaults_from_profile_config()
start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region()
return {'start_url': start_url, 'sso_region': sso_region}
start_url, sso_region, resolved_url = (
self._prompt_for_sso_start_url_and_sso_region()
)
args = {'start_url': start_url, 'sso_region': sso_region}
if resolved_url:
args['resolved_start_url'] = resolved_url
return args

def _get_sso_registration_args_from_sso_config(self, sso_session):
sso_config = self._get_sso_session_config(sso_session)
Expand All @@ -378,17 +388,22 @@ def _get_sso_registration_args_from_sso_config(self, sso_session):

def _prompt_for_registration_args_for_new_sso_session(self, sso_session):
self._set_sso_session_defaults_from_profile_config()
start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region()
start_url, sso_region, resolved_url = (
self._prompt_for_sso_start_url_and_sso_region()
)
scopes = (
self._sso_session_prompter.prompt_for_sso_registration_scopes()
)
return {
args = {
'session_name': sso_session,
'start_url': start_url,
'sso_region': sso_region,
'registration_scopes': scopes,
'force_refresh': True,
}
if resolved_url:
args['resolved_start_url'] = resolved_url
return args

def _store_sso_session_prompter_answers_to_profile_config(self):
self._sso_session_prompter.sso_session_config = (
Expand All @@ -410,8 +425,25 @@ def _set_sso_session_defaults_from_profile_config(self):

def _prompt_for_sso_start_url_and_sso_region(self):
start_url = self._sso_session_prompter.prompt_for_sso_start_url()
hostname = urlparse(start_url).hostname
if hostname and not is_aws_owned_domain(hostname):
try:
resolved_url, region = resolve_start_url(
start_url, session=self._session
)
self._sso_session_prompter.sso_session_config['sso_region'] = (
region
)
return start_url, region, resolved_url
except Exception as e:
logger.debug(
"Failed to resolve vanity URL '%s': %s. "
"Falling back to region prompt.",
start_url,
e,
)
sso_region = self._sso_session_prompter.prompt_for_sso_region()
return start_url, sso_region
return start_url, sso_region, None

def _warn_configuring_using_legacy_format(self):
uni_print(
Expand Down Expand Up @@ -479,8 +511,26 @@ class ConfigureSSOSessionCommand(BaseSSOConfigurationCommand):
def _run_main(self, parsed_args, parsed_globals):
super()._run_main(parsed_args, parsed_globals)
self._sso_session_prompter.prompt_for_sso_session()
self._sso_session_prompter.prompt_for_sso_start_url()
self._sso_session_prompter.prompt_for_sso_region()
start_url = self._sso_session_prompter.prompt_for_sso_start_url()
hostname = urlparse(start_url).hostname
if hostname and not is_aws_owned_domain(hostname):
try:
resolved_url, region = resolve_start_url(
start_url, session=self._session
)
self._sso_session_prompter.sso_session_config['sso_region'] = (
region
)
except Exception as e:
logger.debug(
"Failed to resolve vanity URL '%s': %s. "
"Falling back to region prompt.",
start_url,
e,
)
self._sso_session_prompter.prompt_for_sso_region()
else:
self._sso_session_prompter.prompt_for_sso_region()
self._sso_session_prompter.prompt_for_sso_registration_scopes()
self._write_sso_configuration()
self._print_configuration_success()
Expand Down
222 changes: 222 additions & 0 deletions tests/unit/customizations/configure/test_sso_resolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import pytest

from awscli.customizations.exceptions import ConfigurationError
from awscli.testutils import mock


class TestConfigureSSOVanityUrlResolution:
def _create_command(self):
from awscli.customizations.configure.sso_commands import (
ConfigureSSOCommand,
)

session = mock.Mock()
session.full_config = {'sso_sessions': {}}
cmd = ConfigureSSOCommand(session)
cmd._sso_session_prompter = mock.Mock()
return cmd

def test_vanity_url_skips_region_prompt(self):
cmd = self._create_command()
cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = (
'https://aws.mycompany.com'
)
cmd._sso_session_prompter.sso_session_config = {}
with (
mock.patch(
'awscli.customizations.configure.sso_commands.is_aws_owned_domain',
return_value=False,
),
mock.patch(
'awscli.customizations.configure.sso_commands.resolve_start_url',
return_value=(
'https://ssoins-abc.portal.us-east-1.app.aws',
'us-east-1',
),
),
):
start_url, region, resolved_url = (
cmd._prompt_for_sso_start_url_and_sso_region()
)

assert start_url == 'https://aws.mycompany.com'
assert region == 'us-east-1'
assert resolved_url == 'https://ssoins-abc.portal.us-east-1.app.aws'
cmd._sso_session_prompter.prompt_for_sso_region.assert_not_called()

def test_vanity_url_resolution_failure_falls_back_to_prompt(self):
cmd = self._create_command()
cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = (
'https://aws.mycompany.com'
)
cmd._sso_session_prompter.prompt_for_sso_region.return_value = (
'eu-west-1'
)
with (
mock.patch(
'awscli.customizations.configure.sso_commands.is_aws_owned_domain',
return_value=False,
),
mock.patch(
'awscli.customizations.configure.sso_commands.resolve_start_url',
side_effect=ConfigurationError("Failed to resolve"),
),
):
start_url, region, resolved_url = (
cmd._prompt_for_sso_start_url_and_sso_region()
)

assert start_url == 'https://aws.mycompany.com'
assert region == 'eu-west-1'
assert resolved_url is None
cmd._sso_session_prompter.prompt_for_sso_region.assert_called_once()

def test_direct_url_prompts_for_region(self):
cmd = self._create_command()
cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = (
'https://ssoins-abc.portal.us-west-2.app.aws'
)
cmd._sso_session_prompter.prompt_for_sso_region.return_value = (
'us-west-2'
)
with (
mock.patch(
'awscli.customizations.configure.sso_commands.is_aws_owned_domain',
return_value=True,
),
mock.patch(
'awscli.customizations.configure.sso_commands.resolve_start_url',
) as mock_resolve,
):
start_url, region, resolved_url = (
cmd._prompt_for_sso_start_url_and_sso_region()
)

assert start_url == 'https://ssoins-abc.portal.us-west-2.app.aws'
assert region == 'us-west-2'
assert resolved_url is None
cmd._sso_session_prompter.prompt_for_sso_region.assert_called_once()
mock_resolve.assert_not_called()

def test_vanity_url_persists_resolved_region_in_session_config(self):
cmd = self._create_command()
cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = (
'https://aws.mycompany.com'
)
cmd._sso_session_prompter.sso_session_config = {}
with (
mock.patch(
'awscli.customizations.configure.sso_commands.is_aws_owned_domain',
return_value=False,
),
mock.patch(
'awscli.customizations.configure.sso_commands.resolve_start_url',
return_value=(
'https://ssoins-abc.portal.us-east-1.app.aws',
'us-east-1',
),
),
):
cmd._prompt_for_sso_start_url_and_sso_region()

assert (
cmd._sso_session_prompter.sso_session_config['sso_region']
== 'us-east-1'
)


class TestConfigureSSOSessionVanityUrlResolution:
def _create_command(self):
from awscli.customizations.configure.sso_commands import (
ConfigureSSOSessionCommand,
)

session = mock.Mock()
session.full_config = {'sso_sessions': {}}
cmd = ConfigureSSOSessionCommand(session)
cmd._sso_session_prompter = mock.Mock()
cmd._sso_session_prompter.sso_session_config = {}
cmd._init_prompt_toolkit = mock.Mock()
return cmd

def test_vanity_url_skips_region_prompt(self):
cmd = self._create_command()
cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = (
'https://aws.mycompany.com'
)
with (
mock.patch(
'awscli.customizations.configure.sso_commands.is_aws_owned_domain',
return_value=False,
),
mock.patch(
'awscli.customizations.configure.sso_commands.resolve_start_url',
return_value=(
'https://ssoins-abc.portal.us-east-1.app.aws',
'us-east-1',
),
),
mock.patch.object(cmd, '_write_sso_configuration'),
mock.patch.object(cmd, '_print_configuration_success'),
):
cmd._run_main(mock.Mock(), mock.Mock())

cmd._sso_session_prompter.prompt_for_sso_region.assert_not_called()
assert (
cmd._sso_session_prompter.sso_session_config['sso_region']
== 'us-east-1'
)

def test_vanity_url_failure_falls_back_to_prompt(self):
cmd = self._create_command()
cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = (
'https://aws.mycompany.com'
)
with (
mock.patch(
'awscli.customizations.configure.sso_commands.is_aws_owned_domain',
return_value=False,
),
mock.patch(
'awscli.customizations.configure.sso_commands.resolve_start_url',
side_effect=ConfigurationError("Failed to resolve"),
),
mock.patch.object(cmd, '_write_sso_configuration'),
mock.patch.object(cmd, '_print_configuration_success'),
):
cmd._run_main(mock.Mock(), mock.Mock())

cmd._sso_session_prompter.prompt_for_sso_region.assert_called_once()

def test_direct_url_prompts_for_region(self):
cmd = self._create_command()
cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = (
'https://ssoins-abc.portal.us-west-2.app.aws'
)
with (
mock.patch(
'awscli.customizations.configure.sso_commands.is_aws_owned_domain',
return_value=True,
),
mock.patch(
'awscli.customizations.configure.sso_commands.resolve_start_url',
) as mock_resolve,
mock.patch.object(cmd, '_write_sso_configuration'),
mock.patch.object(cmd, '_print_configuration_success'),
):
cmd._run_main(mock.Mock(), mock.Mock())

cmd._sso_session_prompter.prompt_for_sso_region.assert_called_once()
mock_resolve.assert_not_called()
Loading