diff --git a/awscli/customizations/configure/sso_commands.py b/awscli/customizations/configure/sso_commands.py index e9d80cc548d3..32a09ffe4412 100644 --- a/awscli/customizations/configure/sso_commands.py +++ b/awscli/customizations/configure/sso_commands.py @@ -26,6 +26,7 @@ import logging import os +from urllib.parse import urlparse import colorama from botocore import UNSIGNED @@ -38,6 +39,10 @@ profile_to_section, ) from awscli.customizations.configure.writer import ConfigFileWriter +from awscli.customizations.sso.resolve import ( + is_aws_owned_domain, + resolve_start_url, +) from awscli.customizations.sso.utils import ( LOGIN_ARGS, BaseSSOCommand, @@ -364,8 +369,13 @@ def _prompt_for_sso_registration_args(self): def _prompt_for_registration_args_with_legacy_format(self): self._store_sso_session_prompter_answers_to_profile_config() self._set_sso_session_defaults_from_profile_config() - start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region() - return {'start_url': start_url, 'sso_region': sso_region} + start_url, sso_region, resolved_url = ( + self._prompt_for_sso_start_url_and_sso_region() + ) + args = {'start_url': start_url, 'sso_region': sso_region} + if resolved_url: + args['resolved_start_url'] = resolved_url + return args def _get_sso_registration_args_from_sso_config(self, sso_session): sso_config = self._get_sso_session_config(sso_session) @@ -378,17 +388,22 @@ def _get_sso_registration_args_from_sso_config(self, sso_session): def _prompt_for_registration_args_for_new_sso_session(self, sso_session): self._set_sso_session_defaults_from_profile_config() - start_url, sso_region = self._prompt_for_sso_start_url_and_sso_region() + start_url, sso_region, resolved_url = ( + self._prompt_for_sso_start_url_and_sso_region() + ) scopes = ( self._sso_session_prompter.prompt_for_sso_registration_scopes() ) - return { + args = { 'session_name': sso_session, 'start_url': start_url, 'sso_region': sso_region, 'registration_scopes': scopes, 'force_refresh': True, } + if resolved_url: + args['resolved_start_url'] = resolved_url + return args def _store_sso_session_prompter_answers_to_profile_config(self): self._sso_session_prompter.sso_session_config = ( @@ -410,8 +425,25 @@ def _set_sso_session_defaults_from_profile_config(self): def _prompt_for_sso_start_url_and_sso_region(self): start_url = self._sso_session_prompter.prompt_for_sso_start_url() + hostname = urlparse(start_url).hostname + if hostname and not is_aws_owned_domain(hostname): + try: + resolved_url, region = resolve_start_url( + start_url, session=self._session + ) + self._sso_session_prompter.sso_session_config['sso_region'] = ( + region + ) + return start_url, region, resolved_url + except Exception as e: + logger.debug( + "Failed to resolve vanity URL '%s': %s. " + "Falling back to region prompt.", + start_url, + e, + ) sso_region = self._sso_session_prompter.prompt_for_sso_region() - return start_url, sso_region + return start_url, sso_region, None def _warn_configuring_using_legacy_format(self): uni_print( @@ -479,8 +511,26 @@ class ConfigureSSOSessionCommand(BaseSSOConfigurationCommand): def _run_main(self, parsed_args, parsed_globals): super()._run_main(parsed_args, parsed_globals) self._sso_session_prompter.prompt_for_sso_session() - self._sso_session_prompter.prompt_for_sso_start_url() - self._sso_session_prompter.prompt_for_sso_region() + start_url = self._sso_session_prompter.prompt_for_sso_start_url() + hostname = urlparse(start_url).hostname + if hostname and not is_aws_owned_domain(hostname): + try: + resolved_url, region = resolve_start_url( + start_url, session=self._session + ) + self._sso_session_prompter.sso_session_config['sso_region'] = ( + region + ) + except Exception as e: + logger.debug( + "Failed to resolve vanity URL '%s': %s. " + "Falling back to region prompt.", + start_url, + e, + ) + self._sso_session_prompter.prompt_for_sso_region() + else: + self._sso_session_prompter.prompt_for_sso_region() self._sso_session_prompter.prompt_for_sso_registration_scopes() self._write_sso_configuration() self._print_configuration_success() diff --git a/tests/unit/customizations/configure/test_sso_resolve.py b/tests/unit/customizations/configure/test_sso_resolve.py new file mode 100644 index 000000000000..1cc70503eca3 --- /dev/null +++ b/tests/unit/customizations/configure/test_sso_resolve.py @@ -0,0 +1,222 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import pytest + +from awscli.customizations.exceptions import ConfigurationError +from awscli.testutils import mock + + +class TestConfigureSSOVanityUrlResolution: + def _create_command(self): + from awscli.customizations.configure.sso_commands import ( + ConfigureSSOCommand, + ) + + session = mock.Mock() + session.full_config = {'sso_sessions': {}} + cmd = ConfigureSSOCommand(session) + cmd._sso_session_prompter = mock.Mock() + return cmd + + def test_vanity_url_skips_region_prompt(self): + cmd = self._create_command() + cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = ( + 'https://aws.mycompany.com' + ) + cmd._sso_session_prompter.sso_session_config = {} + with ( + mock.patch( + 'awscli.customizations.configure.sso_commands.is_aws_owned_domain', + return_value=False, + ), + mock.patch( + 'awscli.customizations.configure.sso_commands.resolve_start_url', + return_value=( + 'https://ssoins-abc.portal.us-east-1.app.aws', + 'us-east-1', + ), + ), + ): + start_url, region, resolved_url = ( + cmd._prompt_for_sso_start_url_and_sso_region() + ) + + assert start_url == 'https://aws.mycompany.com' + assert region == 'us-east-1' + assert resolved_url == 'https://ssoins-abc.portal.us-east-1.app.aws' + cmd._sso_session_prompter.prompt_for_sso_region.assert_not_called() + + def test_vanity_url_resolution_failure_falls_back_to_prompt(self): + cmd = self._create_command() + cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = ( + 'https://aws.mycompany.com' + ) + cmd._sso_session_prompter.prompt_for_sso_region.return_value = ( + 'eu-west-1' + ) + with ( + mock.patch( + 'awscli.customizations.configure.sso_commands.is_aws_owned_domain', + return_value=False, + ), + mock.patch( + 'awscli.customizations.configure.sso_commands.resolve_start_url', + side_effect=ConfigurationError("Failed to resolve"), + ), + ): + start_url, region, resolved_url = ( + cmd._prompt_for_sso_start_url_and_sso_region() + ) + + assert start_url == 'https://aws.mycompany.com' + assert region == 'eu-west-1' + assert resolved_url is None + cmd._sso_session_prompter.prompt_for_sso_region.assert_called_once() + + def test_direct_url_prompts_for_region(self): + cmd = self._create_command() + cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = ( + 'https://ssoins-abc.portal.us-west-2.app.aws' + ) + cmd._sso_session_prompter.prompt_for_sso_region.return_value = ( + 'us-west-2' + ) + with ( + mock.patch( + 'awscli.customizations.configure.sso_commands.is_aws_owned_domain', + return_value=True, + ), + mock.patch( + 'awscli.customizations.configure.sso_commands.resolve_start_url', + ) as mock_resolve, + ): + start_url, region, resolved_url = ( + cmd._prompt_for_sso_start_url_and_sso_region() + ) + + assert start_url == 'https://ssoins-abc.portal.us-west-2.app.aws' + assert region == 'us-west-2' + assert resolved_url is None + cmd._sso_session_prompter.prompt_for_sso_region.assert_called_once() + mock_resolve.assert_not_called() + + def test_vanity_url_persists_resolved_region_in_session_config(self): + cmd = self._create_command() + cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = ( + 'https://aws.mycompany.com' + ) + cmd._sso_session_prompter.sso_session_config = {} + with ( + mock.patch( + 'awscli.customizations.configure.sso_commands.is_aws_owned_domain', + return_value=False, + ), + mock.patch( + 'awscli.customizations.configure.sso_commands.resolve_start_url', + return_value=( + 'https://ssoins-abc.portal.us-east-1.app.aws', + 'us-east-1', + ), + ), + ): + cmd._prompt_for_sso_start_url_and_sso_region() + + assert ( + cmd._sso_session_prompter.sso_session_config['sso_region'] + == 'us-east-1' + ) + + +class TestConfigureSSOSessionVanityUrlResolution: + def _create_command(self): + from awscli.customizations.configure.sso_commands import ( + ConfigureSSOSessionCommand, + ) + + session = mock.Mock() + session.full_config = {'sso_sessions': {}} + cmd = ConfigureSSOSessionCommand(session) + cmd._sso_session_prompter = mock.Mock() + cmd._sso_session_prompter.sso_session_config = {} + cmd._init_prompt_toolkit = mock.Mock() + return cmd + + def test_vanity_url_skips_region_prompt(self): + cmd = self._create_command() + cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = ( + 'https://aws.mycompany.com' + ) + with ( + mock.patch( + 'awscli.customizations.configure.sso_commands.is_aws_owned_domain', + return_value=False, + ), + mock.patch( + 'awscli.customizations.configure.sso_commands.resolve_start_url', + return_value=( + 'https://ssoins-abc.portal.us-east-1.app.aws', + 'us-east-1', + ), + ), + mock.patch.object(cmd, '_write_sso_configuration'), + mock.patch.object(cmd, '_print_configuration_success'), + ): + cmd._run_main(mock.Mock(), mock.Mock()) + + cmd._sso_session_prompter.prompt_for_sso_region.assert_not_called() + assert ( + cmd._sso_session_prompter.sso_session_config['sso_region'] + == 'us-east-1' + ) + + def test_vanity_url_failure_falls_back_to_prompt(self): + cmd = self._create_command() + cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = ( + 'https://aws.mycompany.com' + ) + with ( + mock.patch( + 'awscli.customizations.configure.sso_commands.is_aws_owned_domain', + return_value=False, + ), + mock.patch( + 'awscli.customizations.configure.sso_commands.resolve_start_url', + side_effect=ConfigurationError("Failed to resolve"), + ), + mock.patch.object(cmd, '_write_sso_configuration'), + mock.patch.object(cmd, '_print_configuration_success'), + ): + cmd._run_main(mock.Mock(), mock.Mock()) + + cmd._sso_session_prompter.prompt_for_sso_region.assert_called_once() + + def test_direct_url_prompts_for_region(self): + cmd = self._create_command() + cmd._sso_session_prompter.prompt_for_sso_start_url.return_value = ( + 'https://ssoins-abc.portal.us-west-2.app.aws' + ) + with ( + mock.patch( + 'awscli.customizations.configure.sso_commands.is_aws_owned_domain', + return_value=True, + ), + mock.patch( + 'awscli.customizations.configure.sso_commands.resolve_start_url', + ) as mock_resolve, + mock.patch.object(cmd, '_write_sso_configuration'), + mock.patch.object(cmd, '_print_configuration_success'), + ): + cmd._run_main(mock.Mock(), mock.Mock()) + + cmd._sso_session_prompter.prompt_for_sso_region.assert_called_once() + mock_resolve.assert_not_called()