diff --git a/smart_tests/__main__.py b/smart_tests/__main__.py index 08ce06bad..7e48df83c 100644 --- a/smart_tests/__main__.py +++ b/smart_tests/__main__.py @@ -16,6 +16,7 @@ from smart_tests.commands.subset import subset from smart_tests.commands.update import update from smart_tests.commands.verify import verify +from smart_tests.utils.tracking import send_command_tracking cli = Group(name="cli", callback=Application) cli.add_command(record) @@ -59,7 +60,18 @@ def _load_test_runners(): def main(): - cli.main() + argv = sys.argv[:] + exit_code = 0 + try: + cli.main() + except SystemExit as e: + exit_code = e.code if isinstance(e.code, int) else 1 + finally: + try: + send_command_tracking(argv=argv, exit_code=exit_code) + except Exception: + pass + sys.exit(exit_code) if __name__ == '__main__': diff --git a/smart_tests/utils/commands.py b/smart_tests/utils/commands.py index 6424d336c..dbfd24d6e 100644 --- a/smart_tests/utils/commands.py +++ b/smart_tests/utils/commands.py @@ -6,12 +6,19 @@ class Command(Enum): RECORD_TESTS = 'RECORD_TESTS' RECORD_BUILD = 'RECORD_BUILD' RECORD_SESSION = 'RECORD_SESSION' + RECORD_ATTACHMENT = 'RECORD_ATTACHMENT' + RECORD_DEPLOYMENT = 'RECORD_DEPLOYMENT' SUBSET = 'SUBSET' COMMIT = 'COMMIT' DETECT_FLAKE = 'DETECT_FLAKE' GATE = 'GATE' UPDATE_ALIAS = 'UPDATE_ALIAS' - RECORD_DEPLOYMENT = 'RECORD_DEPLOYMENT' + INSPECT_MODEL = 'INSPECT_MODEL' + INSPECT_SUBSET = 'INSPECT_SUBSET' + STATS_TEST_SESSIONS = 'STATS_TEST_SESSIONS' + COMPARE_SUBSETS = 'COMPARE_SUBSETS' + GET_DOCS = 'GET_DOCS' + UNKNOWN = 'UNKNOWN' # when you add a new constant here, the server also needs to get a new constant in cli_tracking.proto diff --git a/smart_tests/utils/env_keys.py b/smart_tests/utils/env_keys.py index 4076d4da8..b4ad2ff9d 100644 --- a/smart_tests/utils/env_keys.py +++ b/smart_tests/utils/env_keys.py @@ -9,6 +9,7 @@ COMMIT_TIMEOUT = "SMART_TESTS_COMMIT_TIMEOUT" SKIP_CERT_VERIFICATION = "SMART_TESTS_SKIP_CERT_VERIFICATION" SESSION_DIR_KEY = "SMART_TESTS_SESSION_DIR" +CALLER_KEY = "SMART_TESTS_CALLER" # Legacy token key for backward compatibility LEGACY_TOKEN_KEY = "LAUNCHABLE_TOKEN" @@ -17,3 +18,15 @@ def get_token(): """Get token with backward compatibility for LAUNCHABLE_TOKEN.""" return os.getenv(TOKEN_KEY) or os.getenv(LEGACY_TOKEN_KEY) + + +def detect_ci_provider() -> str: + if os.environ.get("GITHUB_ACTIONS"): + return "github-actions" + if os.environ.get("JENKINS_URL"): + return "jenkins" + if os.environ.get("CIRCLECI"): + return "circleci" + if os.environ.get("CODEBUILD_BUILD_ID"): + return "codebuild" + return "" diff --git a/smart_tests/utils/tracking.py b/smart_tests/utils/tracking.py index 8eaade13e..68429bfa1 100644 --- a/smart_tests/utils/tracking.py +++ b/smart_tests/utils/tracking.py @@ -1,21 +1,74 @@ +import os from enum import Enum +from itertools import takewhile from typing import Any, Dict, Union from requests import Session from smart_tests.app import Application from smart_tests.utils.authentication import get_org_workspace +from smart_tests.utils.env_keys import CALLER_KEY, detect_ci_provider from smart_tests.utils.http_client import _HttpClient, _join_paths from smart_tests.version import __version__ from .commands import Command +# Map CLI subcommand tokens to Command enum values. +# Longer matches are tried first so "record build" matches before "record". +_COMMAND_MAP = { + ("verify",): Command.VERIFY, + ("record", "build"): Command.RECORD_BUILD, + ("record", "session"): Command.RECORD_SESSION, + ("record", "tests"): Command.RECORD_TESTS, + ("record", "commit"): Command.COMMIT, + ("record", "attachment"): Command.RECORD_ATTACHMENT, + ("record", "deployment"): Command.RECORD_DEPLOYMENT, + ("subset",): Command.SUBSET, + ("detect-flakes",): Command.DETECT_FLAKE, + ("gate",): Command.GATE, + ("update", "alias"): Command.UPDATE_ALIAS, + ("inspect", "model"): Command.INSPECT_MODEL, + ("inspect", "subset"): Command.INSPECT_SUBSET, + ("stats", "test_sessions"): Command.STATS_TEST_SESSIONS, + ("compare", "subsets"): Command.COMPARE_SUBSETS, + ("get", "docs"): Command.GET_DOCS, +} + + +def _detect_command(argv: list[str]) -> Command: + """Best-effort detection of the Command from argv. Returns UNKNOWN for typos.""" + command_tokens = list(takewhile(lambda a: not a.startswith("-"), argv[1:])) + + for tokens, command in sorted(_COMMAND_MAP.items(), key=lambda x: -len(x[0])): + if tuple(command_tokens[:len(tokens)]) == tokens: + return command + return Command.UNKNOWN + + +def send_command_tracking(argv: list[str], exit_code: int): + """Send a single COMMAND_INVOCATION event with the full command string. Fire-and-forget.""" + client = TrackingClient(_detect_command(argv)) + metadata = { + "exitCode": str(exit_code), + } + + raw_command = " ".join(argv)[:2000] + + payload = client.construct_payload( + event_name=Tracking.Event.COMMAND_INVOCATION, + metadata=metadata, + raw_command=raw_command, + ) + + client.post_payload(payload=payload) + class Tracking: # General events class Event(Enum): SHALLOW_CLONE = 'SHALLOW_CLONE' # this event is an example PERFORMANCE = 'PERFORMANCE' + COMMAND_INVOCATION = 'COMMAND_INVOCATION' # Error events class ErrorEvent(Enum): @@ -45,14 +98,8 @@ def send_event( event_name: Tracking.Event, metadata: Dict[str, Any] | None = None ): - org, workspace = get_org_workspace() - if metadata is None: - metadata = {} - metadata["organization"] = org or "" - metadata["workspace"] = workspace or "" - self._post_payload( - event_name=event_name, - metadata=metadata, + self.post_payload( + payload=self.construct_payload(event_name=event_name, metadata=metadata), ) def send_error_event( @@ -62,34 +109,49 @@ def send_error_event( api: str = "", metadata: Dict[str, Any] | None = None ): - org, workspace = get_org_workspace() if metadata is None: metadata = {} metadata["stackTrace"] = stack_trace - metadata["organization"] = org or "" - metadata["workspace"] = workspace or "" metadata["api"] = api - self._post_payload( - event_name=event_name, - metadata=metadata, - ) - def _post_payload( + payload = self.construct_payload(event_name=event_name, metadata=metadata) + self.post_payload(payload=payload) + + def post_payload( self, - event_name: Union[Tracking.Event, Tracking.ErrorEvent], - metadata: Dict[str, Any] + payload: dict, ): - payload = { - "command": self.command.value, - "eventName": event_name.value, - "cliVersion": __version__, - "metadata": metadata, - } path = _join_paths( '/intake', 'cli_tracking' ) try: - self.http_client.request('post', payload=payload, path=path) + self.http_client.request('post', payload=payload, path=path, timeout=(2, 2)) except Exception: pass + + def construct_payload( + self, + event_name: Union[Tracking.Event, Tracking.ErrorEvent], + metadata: Dict[str, Any] | None = None, + raw_command: str | None = None + ) -> dict: + org, workspace = get_org_workspace() + + if metadata is None: + metadata = {} + + metadata["organization"] = org or "" + metadata["workspace"] = workspace or "" + metadata["caller"] = os.environ.get(CALLER_KEY) or "cli" + metadata["ciProvider"] = detect_ci_provider() + + payload = { + "command": self.command.value, + "eventName": event_name.value, + "cliVersion": __version__, + "metadata": metadata, + "rawCommand": raw_command, + } + + return payload diff --git a/tests/utils/test_http_client.py b/tests/utils/test_http_client.py index 6333c5f82..e4040320e 100644 --- a/tests/utils/test_http_client.py +++ b/tests/utils/test_http_client.py @@ -10,7 +10,12 @@ class HttpClientTest(TestCase): @mock.patch.dict( os.environ, - {"SMART_TESTS_ORGANIZATION": "launchableinc", "SMART_TESTS_WORKSPACE": "test"}, + { + "SMART_TESTS_ORGANIZATION": "launchableinc", + "SMART_TESTS_WORKSPACE": "test", + "SMART_TESTS_TOKEN": "", + "LAUNCHABLE_TOKEN": "", + }, clear=True, ) def test_header(self): diff --git a/tests/utils/test_tracking.py b/tests/utils/test_tracking.py new file mode 100644 index 000000000..6538011d3 --- /dev/null +++ b/tests/utils/test_tracking.py @@ -0,0 +1,312 @@ +import json +import os +from unittest import TestCase, mock + +import responses + +from smart_tests.utils.commands import Command +from smart_tests.utils.env_keys import detect_ci_provider +from smart_tests.utils.http_client import get_base_url +from smart_tests.utils.tracking import _COMMAND_MAP, Tracking, TrackingClient, _detect_command, send_command_tracking + + +class DetectCommandTest(TestCase): + + def test_verify(self): + self.assertEqual(_detect_command(["smart-tests", "verify"]), Command.VERIFY) + + def test_record_build(self): + self.assertEqual(_detect_command(["smart-tests", "record", "build", "--name", "foo"]), Command.RECORD_BUILD) + + def test_record_session(self): + self.assertEqual(_detect_command(["smart-tests", "record", "session", "--build", "123"]), Command.RECORD_SESSION) + + def test_subset(self): + self.assertEqual(_detect_command(["smart-tests", "subset", "pytest", "--target", "30%"]), Command.SUBSET) + + def test_detect_flakes(self): + self.assertEqual(_detect_command(["smart-tests", "detect-flakes", "pytest"]), Command.DETECT_FLAKE) + + def test_gate(self): + self.assertEqual(_detect_command(["smart-tests", "gate", "--session", "builds/1/test_sessions/2"]), Command.GATE) + + def test_update_alias(self): + self.assertEqual(_detect_command(["smart-tests", "update", "alias", "--build", "foo"]), Command.UPDATE_ALIAS) + + def test_command_token_in_flag_value_not_misdetected(self): + """Flag values that match command names must not affect detection.""" + # "verify" is a flag value, not a command + self.assertEqual( + _detect_command(["smart-tests", "record", "build", "--name", "verify"]), + Command.RECORD_BUILD, + ) + # "build" is a flag value, not a subcommand of record + self.assertEqual( + _detect_command(["smart-tests", "record", "tests", "--name", "build"]), + Command.RECORD_TESTS, + ) + + def test_flag_value_matching_subcommand_not_misdetected(self): + """'record --name build' must not be detected as RECORD_BUILD.""" + self.assertEqual( + _detect_command(["smart-tests", "record", "--name", "build"]), + Command.UNKNOWN, + ) + + def test_typo_returns_unknown(self): + self.assertEqual(_detect_command(["smart-tests", "recrd", "build"]), Command.UNKNOWN) + + def test_no_subcommand_returns_unknown(self): + self.assertEqual(_detect_command(["smart-tests"]), Command.UNKNOWN) + + def test_global_options_before_command(self): + self.assertEqual( + _detect_command(["smart-tests", "--dry-run", "record", "build", "--name", "foo"]), + Command.UNKNOWN, + ) + + def test_record_attachment(self): + self.assertEqual(_detect_command(["smart-tests", "record", "attachment", "--session", "s1"]), Command.RECORD_ATTACHMENT) + + def test_record_deployment(self): + self.assertEqual(_detect_command(["smart-tests", "record", "deployment", "--build", "b1"]), Command.RECORD_DEPLOYMENT) + + def test_inspect_model(self): + self.assertEqual(_detect_command(["smart-tests", "inspect", "model"]), Command.INSPECT_MODEL) + + def test_inspect_subset(self): + self.assertEqual(_detect_command(["smart-tests", "inspect", "subset", "--subset-id", "123"]), Command.INSPECT_SUBSET) + + def test_stats_test_sessions(self): + self.assertEqual(_detect_command(["smart-tests", "stats", "test_sessions", "--days", "7"]), Command.STATS_TEST_SESSIONS) + + def test_compare_subsets(self): + self.assertEqual(_detect_command(["smart-tests", "compare", "subsets"]), Command.COMPARE_SUBSETS) + + def test_get_docs(self): + self.assertEqual(_detect_command(["smart-tests", "get", "docs"]), Command.GET_DOCS) + + def test_command_map_covers_all_enum_values(self): + mapped_commands = set(_COMMAND_MAP.values()) + all_commands = {c for c in Command if c != Command.UNKNOWN} + self.assertEqual(mapped_commands, all_commands, + f"Commands missing from _COMMAND_MAP: {all_commands - mapped_commands}") + + +class DetectCiProviderTest(TestCase): + + def test_no_ci(self): + with mock.patch.dict(os.environ, {}, clear=True): + self.assertEqual(detect_ci_provider(), "") + + def test_github_actions(self): + with mock.patch.dict(os.environ, {"GITHUB_ACTIONS": "true"}, clear=True): + self.assertEqual(detect_ci_provider(), "github-actions") + + def test_jenkins(self): + with mock.patch.dict(os.environ, {"JENKINS_URL": "https://jenkins.example.com"}, clear=True): + self.assertEqual(detect_ci_provider(), "jenkins") + + def test_circleci(self): + with mock.patch.dict(os.environ, {"CIRCLECI": "true"}, clear=True): + self.assertEqual(detect_ci_provider(), "circleci") + + def test_codebuild(self): + with mock.patch.dict(os.environ, {"CODEBUILD_BUILD_ID": "build:123"}, clear=True): + self.assertEqual(detect_ci_provider(), "codebuild") + + +class TrackingCallerTest(TestCase): + + @mock.patch.dict( + os.environ, + {"SMART_TESTS_TOKEN": "v1:org/ws:token"}, + ) + @responses.activate + def test_default_caller_is_cli(self): + responses.add( + responses.POST, + f"{get_base_url()}/intake/cli_tracking", + json={}, + status=200, + ) + client = TrackingClient(Command.VERIFY, base_url=get_base_url()) + client.send_event(Tracking.Event.PERFORMANCE, {"duration": 100}) + + payload = json.loads(responses.calls[0].request.body) + self.assertEqual(payload["metadata"]["caller"], "cli") + + @mock.patch.dict( + os.environ, + { + "SMART_TESTS_TOKEN": "v1:org/ws:token", + "SMART_TESTS_CALLER": "github-action", + }, + ) + @responses.activate + def test_caller_from_env(self): + responses.add( + responses.POST, + f"{get_base_url()}/intake/cli_tracking", + json={}, + status=200, + ) + client = TrackingClient(Command.RECORD_BUILD, base_url=get_base_url()) + client.send_event(Tracking.Event.PERFORMANCE, {"duration": 100}) + + payload = json.loads(responses.calls[0].request.body) + self.assertEqual(payload["metadata"]["caller"], "github-action") + + @mock.patch.dict( + os.environ, + { + "SMART_TESTS_TOKEN": "v1:org/ws:token", + "GITHUB_ACTIONS": "true", + }, + ) + @responses.activate + def test_ci_provider_auto_detected(self): + responses.add( + responses.POST, + f"{get_base_url()}/intake/cli_tracking", + json={}, + status=200, + ) + client = TrackingClient(Command.SUBSET, base_url=get_base_url()) + client.send_event(Tracking.Event.PERFORMANCE, {"duration": 100}) + + payload = json.loads(responses.calls[0].request.body) + self.assertEqual(payload["metadata"]["ciProvider"], "github-actions") + self.assertEqual(payload["metadata"]["caller"], "cli") + + @mock.patch.dict( + os.environ, + { + "SMART_TESTS_TOKEN": "v1:org/ws:token", + "GITHUB_ACTIONS": "true", + "SMART_TESTS_CALLER": "github-action", + }, + ) + @responses.activate + def test_caller_and_ci_provider_together(self): + responses.add( + responses.POST, + f"{get_base_url()}/intake/cli_tracking", + json={}, + status=200, + ) + client = TrackingClient(Command.RECORD_BUILD, base_url=get_base_url()) + client.send_event(Tracking.Event.PERFORMANCE, {"duration": 100}) + + payload = json.loads(responses.calls[0].request.body) + self.assertEqual(payload["metadata"]["caller"], "github-action") + self.assertEqual(payload["metadata"]["ciProvider"], "github-actions") + + @mock.patch.dict( + os.environ, + {"SMART_TESTS_TOKEN": "v1:org/ws:token"}, + ) + @responses.activate + def test_error_event_includes_caller(self): + responses.add( + responses.POST, + f"{get_base_url()}/intake/cli_tracking", + json={}, + status=200, + ) + client = TrackingClient(Command.GATE, base_url=get_base_url()) + client.send_error_event( + event_name=Tracking.ErrorEvent.INTERNAL_CLI_ERROR, + stack_trace="some error", + ) + + payload = json.loads(responses.calls[0].request.body) + self.assertEqual(payload["metadata"]["caller"], "cli") + self.assertIn("ciProvider", payload["metadata"]) + + +class SendCommandTrackingTest(TestCase): + + @mock.patch.dict( + os.environ, + {"SMART_TESTS_TOKEN": "v1:org/ws:token"}, + ) + @responses.activate + def test_sends_command_invocation(self): + responses.add( + responses.POST, + f"{get_base_url()}/intake/cli_tracking", + json={}, + status=200, + ) + send_command_tracking( + argv=["smart-tests", "record", "build", "--name", "foo"], + exit_code=0, + ) + + self.assertEqual(len(responses.calls), 1) + payload = json.loads(responses.calls[0].request.body) + self.assertEqual(payload["command"], "RECORD_BUILD") + self.assertEqual(payload["eventName"], "COMMAND_INVOCATION") + self.assertEqual(payload["rawCommand"], "smart-tests record build --name foo") + self.assertIn("cliVersion", payload) + metadata = payload["metadata"] + self.assertEqual(metadata["exitCode"], "0") + self.assertEqual(metadata["caller"], "cli") + + @mock.patch.dict( + os.environ, + { + "SMART_TESTS_TOKEN": "v1:org/ws:token", + "SMART_TESTS_CALLER": "github-action", + "GITHUB_ACTIONS": "true", + }, + ) + @responses.activate + def test_includes_caller_and_ci_in_metadata(self): + responses.add( + responses.POST, + f"{get_base_url()}/intake/cli_tracking", + json={}, + status=200, + ) + send_command_tracking(argv=["smart-tests", "verify"], exit_code=0) + + payload = json.loads(responses.calls[0].request.body) + metadata = payload["metadata"] + self.assertEqual(metadata["caller"], "github-action") + self.assertEqual(metadata["ciProvider"], "github-actions") + + @mock.patch.dict( + os.environ, + {"SMART_TESTS_TOKEN": "v1:org/ws:token"}, + ) + @responses.activate + def test_swallows_exceptions(self): + responses.add( + responses.POST, + f"{get_base_url()}/intake/cli_tracking", + json={"error": "server error"}, + status=500, + ) + # Should not raise + send_command_tracking(argv=["smart-tests", "verify"], exit_code=0) + + @mock.patch.dict( + os.environ, + {"SMART_TESTS_TOKEN": "v1:org/ws:token"}, + ) + @responses.activate + def test_typo_maps_to_unknown(self): + responses.add( + responses.POST, + f"{get_base_url()}/intake/cli_tracking", + json={}, + status=200, + ) + send_command_tracking(argv=["smart-tests", "recrd", "build"], exit_code=1) + + payload = json.loads(responses.calls[0].request.body) + self.assertEqual(payload["command"], "UNKNOWN") + self.assertEqual(payload["metadata"]["exitCode"], "1") + self.assertEqual(payload["rawCommand"], "smart-tests recrd build")