diff --git a/.changes/next-release/bugfix-UserAgent-47668.json b/.changes/next-release/bugfix-UserAgent-47668.json
new file mode 100644
index 000000000000..f5256313ec29
--- /dev/null
+++ b/.changes/next-release/bugfix-UserAgent-47668.json
@@ -0,0 +1,5 @@
+{
+  "type": "bugfix",
+  "category": "User Agent",
+  "description": "Ensure that session IDs are added to the User-Agent HTTP header even when the local AWS CLI cache does not exist."
+}
diff --git a/awscli/clidriver.py b/awscli/clidriver.py
index 2a5d7f0d1713..9805ca066620 100644
--- a/awscli/clidriver.py
+++ b/awscli/clidriver.py
@@ -78,7 +78,7 @@
     set_stream_logger,
 )
 from awscli.plugin import load_plugins
-from awscli.telemetry import add_session_id_component_to_user_agent_extra
+from awscli.telemetry import register_session_id_event
 from awscli.utils import (
     IMDSRegionProvider,
     OutputStreamFactory,
@@ -217,7 +217,6 @@ def _set_user_agent_for_session(session):
     session.user_agent_version = __version__
     _add_distribution_source_to_user_agent(session)
     _add_linux_distribution_to_user_agent(session)
-    add_session_id_component_to_user_agent_extra(session)
 
 
 def register_no_pager_handler(event_emitter):
@@ -296,6 +295,7 @@ def __init__(self, session=None, error_handler=None, debug=False):
             _set_user_agent_for_session(self.session)
         else:
             self.session = session
+        register_session_id_event(self.session)
         self._error_handler = error_handler
         if self._error_handler is None:
             self._error_handler = construct_cli_error_handlers_chain(
diff --git a/awscli/telemetry.py b/awscli/telemetry.py
index 3cf44e83ace0..aa2e01eab84f 100644
--- a/awscli/telemetry.py
+++ b/awscli/telemetry.py
@@ -76,13 +76,14 @@ class CLISessionDatabaseConnection:
     """
     _ENABLE_WAL = 'PRAGMA journal_mode=WAL'
 
-    def __init__(self, connection=None):
+    def __init__(self, connection=None, cache_dir=None):
+        self._cache_dir = cache_dir or _CACHE_DIR
+        self._ensure_cache_dir()
         self._connection = connection or sqlite3.connect(
-            _CACHE_DIR / _DATABASE_FILENAME,
+            self._cache_dir / _DATABASE_FILENAME,
             check_same_thread=False,
             isolation_level=None,
         )
-        self._ensure_cache_dir()
         self._ensure_database_setup()
 
     def execute(self, query, *parameters):
@@ -95,7 +96,7 @@ def execute(self, query, *parameters):
             return sqlite3.Cursor(self._connection)
 
     def _ensure_cache_dir(self):
-        _CACHE_DIR.mkdir(parents=True, exist_ok=True)
+        self._cache_dir.mkdir(parents=True, exist_ok=True)
 
     def _ensure_database_setup(self):
         self._create_session_table()
@@ -295,17 +296,48 @@ def _get_cli_session_orchestrator():
     )
 
 
-def add_session_id_component_to_user_agent_extra(session, orchestrator=None):
-    try:
-        cli_session_orchestrator = (
-            orchestrator or _get_cli_session_orchestrator()
-        )
-        add_component_to_user_agent_extra(
-            session,
-            UserAgentComponent("sid", cli_session_orchestrator.session_id),
-        )
-    except Exception:
-        # Ideally, the AWS CLI should never throw if the session id
-        # can't be generated since it's not critical for users. Issues
-        # with session data should instead be caught server-side.
-        pass
+def register_session_id_event(session, orchestrator_factory=None):
+    """Register a one-shot handler that adds the sid user-agent component.
+
+    The handler runs on the first ``before-create-client`` event and then
+    unregisters itself, so the session orchestrator is only built when a
+    client is actually created.
+
+    :param session: Session providing the ``event_emitter`` component and
+        the ``user_agent_extra`` string to update.
+    :param orchestrator_factory: Zero-argument callable returning an object
+        with a ``session_id`` attribute. Defaults to
+        ``_get_cli_session_orchestrator``.
+    """
+    if orchestrator_factory is None:
+        orchestrator_factory = _get_cli_session_orchestrator
+    event_emitter = session.get_component('event_emitter')
+
+    def _inject_session_id(**kwargs):
+        try:
+            orchestrator = orchestrator_factory()
+            sid_component = UserAgentComponent(
+                "sid", orchestrator.session_id
+            ).to_string()
+            # Insert sid after md/installer to preserve original
+            # user-agent component ordering. If md/installer is absent
+            # (for example an empty user_agent_extra), append sid to
+            # the end without emitting a leading space.
+            extra = session.user_agent_extra
+            idx = extra.find('md/installer')
+            end = extra.find(' ', idx) if idx != -1 else -1
+            if end == -1:
+                end = len(extra)
+            head, tail = extra[:end], extra[end:]
+            separator = ' ' if head else ''
+            session.user_agent_extra = (
+                f'{head}{separator}{sid_component}{tail}'
+            )
+        except Exception:
+            # Ideally, the AWS CLI should never throw if the session id
+            # can't be generated since it's not critical for users. Issues
+            # with session data should instead be caught server-side.
+            pass
+        event_emitter.unregister('before-create-client', _inject_session_id)
+
+    event_emitter.register('before-create-client', _inject_session_id)
diff --git a/tests/functional/test_telemetry.py b/tests/functional/test_telemetry.py
index 698b639d0329..5e4c73f5e32e 100644
--- a/tests/functional/test_telemetry.py
+++ b/tests/functional/test_telemetry.py
@@ -11,12 +11,13 @@
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
 import sqlite3
-from unittest.mock import MagicMock, PropertyMock, patch
+from unittest.mock import MagicMock, patch
 
 import pytest
 from botocore.exceptions import MD5UnavailableError
 from botocore.session import Session
 
+from awscli.clidriver import create_clidriver
 from awscli.telemetry import (
     CLISessionData,
     CLISessionDatabaseConnection,
@@ -25,7 +26,7 @@
     CLISessionDatabaseWriter,
     CLISessionGenerator,
     CLISessionOrchestrator,
-    add_session_id_component_to_user_agent_extra,
+    register_session_id_event,
 )
 from tests.markers import skip_if_windows
 
@@ -107,6 +108,20 @@ def test_ensure_database_setup(self, session_conn):
         )
         assert cursor.fetchall() == [('session',), ('host_id',)]
 
+    def test_creates_database_when_cache_dir_does_not_exist(self, tmp_path):
+        # When the cache directory doesn't exist, the connection should still
+        # be established successfully.
+        nonexistent_dir = tmp_path / 'nonexistent' / 'nested' / 'cache'
+        assert not nonexistent_dir.exists()
+        conn = CLISessionDatabaseConnection(cache_dir=nonexistent_dir)
+        assert nonexistent_dir.exists()
+        assert (nonexistent_dir / 'session.db').exists()
+        # Verify the database is functional.
+        writer = CLISessionDatabaseWriter(conn)
+        reader = CLISessionDatabaseReader(conn)
+        writer.write(CLISessionData('key', 'sid', 1000000000))
+        assert reader.read('key').session_id == 'sid'
+
     def test_timeout_does_not_raise_exception(self, session_conn):
         test_query = """
             SELECT name
@@ -308,17 +323,59 @@ def test_cached_session_id_not_updated_if_valid(
         assert session_data_2.timestamp != session_data_1.timestamp
 
 
-def test_add_session_id_component_to_user_agent_extra():
+def test_register_session_id_event_injects_sid_on_before_create_client():
     session = MagicMock(Session)
-    session.user_agent_extra = ''
+    session.user_agent_extra = 'md/installer#source'
+    event_emitter = MagicMock()
+    session.get_component.return_value = event_emitter
     orchestrator = MagicMock(CLISessionOrchestrator)
     orchestrator.session_id = 'my-session-id'
-    add_session_id_component_to_user_agent_extra(session, orchestrator)
-    assert session.user_agent_extra == 'sid/my-session-id'
+
+    def fake_orchestrator_factory():
+        return orchestrator
+
+    register_session_id_event(
+        session, orchestrator_factory=fake_orchestrator_factory
+    )
+    handler = event_emitter.register.call_args[0][1]
+    handler()
+    assert session.user_agent_extra == 'md/installer#source sid/my-session-id'
+    event_emitter.unregister.assert_called_once_with(
+        'before-create-client', handler
+    )
 
 
-def test_entrypoint_catches_bare_exceptions():
-    mock_orchestrator = MagicMock(CLISessionOrchestrator)
-    type(mock_orchestrator).session_id = PropertyMock(side_effect=Exception)
+def test_register_session_id_event_catches_bare_exceptions():
     session = MagicMock(Session)
-    add_session_id_component_to_user_agent_extra(session, mock_orchestrator)
+    session.user_agent_extra = ''
+    event_emitter = MagicMock()
+    session.get_component.return_value = event_emitter
+    register_session_id_event(
+        session, orchestrator_factory=MagicMock(side_effect=Exception)
+    )
+    handler = event_emitter.register.call_args[0][1]
+    handler()
+    assert session.user_agent_extra == ''
+
+
+def test_register_session_id_event_appends_sid_without_installer():
+    session = MagicMock(Session)
+    session.user_agent_extra = ''
+    event_emitter = MagicMock()
+    session.get_component.return_value = event_emitter
+    orchestrator = MagicMock(CLISessionOrchestrator)
+    orchestrator.session_id = 'my-session-id'
+    register_session_id_event(
+        session, orchestrator_factory=lambda: orchestrator
+    )
+    handler = event_emitter.register.call_args[0][1]
+    handler()
+    assert session.user_agent_extra == 'sid/my-session-id'
+
+
+def test_user_agent_extra_contains_installer_component():
+    # register_session_id_event depends on md/installer being present
+    # in user_agent_extra to insert sid at the correct position. This
+    # test ensures that invariant holds after driver creation.
+    driver = create_clidriver()
+    assert 'md/installer#' in driver.session.user_agent_extra
diff --git a/tests/unit/test_clidriver.py b/tests/unit/test_clidriver.py
index bcc5a1ca7c8e..d2a280218a57 100644
--- a/tests/unit/test_clidriver.py
+++ b/tests/unit/test_clidriver.py
@@ -315,12 +315,6 @@ def _run_main(self, args, parsed_globals):
         return 0
 
 
-class FakeCLISessionOrchestrator:
-    @property
-    def session_id(self):
-        return 'mysessionid'
-
-
 class TestCliDriver:
     def setup_method(self):
         self.session = FakeSession()
@@ -849,19 +843,13 @@ def test_idempotency_token_is_not_required_in_help_text(self):
         self.assertEqual(rc, 252)
         self.assertNotIn('--idempotency-token', self.stderr.getvalue())
 
-    @mock.patch(
-        'awscli.telemetry._get_cli_session_orchestrator',
-        return_value=FakeCLISessionOrchestrator(),
-    )
     @mock.patch('awscli.clidriver.platform.system', return_value='Linux')
     @mock.patch('awscli.clidriver.platform.machine', return_value='x86_64')
     @mock.patch('awscli.clidriver.distro.id', return_value='amzn')
     @mock.patch('awscli.clidriver.distro.major_version', return_value='1')
     def test_user_agent_for_linux(self, *args):
         driver = create_clidriver()
-        expected_user_agent = (
-            'md/installer#source md/distrib#amzn.1 sid/mysessionid'
-        )
+        expected_user_agent = 'md/installer#source md/distrib#amzn.1'
         self.assertEqual(expected_user_agent, driver.session.user_agent_extra)
 
     def test_user_agent(self, *args):