diff --git a/lean/components/api/auth0_client.py b/lean/components/api/auth0_client.py index f6b42cbf..5c90cebc 100644 --- a/lean/components/api/auth0_client.py +++ b/lean/components/api/auth0_client.py @@ -29,40 +29,50 @@ def __init__(self, api_client: 'APIClient') -> None: self._api = api_client self._cache = {} - def read(self, brokerage_id: str) -> QCAuth0Authorization: + def read(self, brokerage_id: str, user_name: str = None) -> QCAuth0Authorization: """Reads the authorization data for a brokerage. :param brokerage_id: the id of the brokerage to read the authorization data for + :param user_name: the optional login ID of the user :return: the authorization data for the specified brokerage """ try: # First check cache - if brokerage_id in self._cache.keys(): - return self._cache[brokerage_id] + if user_name: + cache_key = (brokerage_id, user_name) + else: + cache_key = brokerage_id + if cache_key in self._cache: + return self._cache[cache_key] payload = { "brokerage": brokerage_id } + if user_name: + payload["userId"] = user_name data = self._api.post("live/auth0/read", payload) # Store in cache result = QCAuth0Authorization(**data) - self._cache[brokerage_id] = result + self._cache[cache_key] = result return result except RequestFailedError as e: return QCAuth0Authorization(authorization=None) @staticmethod - def authorize(brokerage_id: str, logger: Logger, project_id: int, no_browser: bool = False) -> None: + def authorize(brokerage_id: str, logger: Logger, project_id: int, no_browser: bool = False, user_name: str = None) -> None: """Starts the authorization process for a brokerage. :param brokerage_id: the id of the brokerage to start the authorization process for :param logger: the logger instance to use :param project_id: The local or cloud project_id + :param user_name: the optional login ID of the user to pre-fill in the authorization page :param no_browser: whether to disable opening the browser """ from webbrowser import open full_url = f"{API_BASE_URL}live/auth0/authorize?brokerage={brokerage_id}&projectId={project_id}" + if user_name: + full_url += f"&userId={user_name}" logger.info(f"Please open the following URL in your browser to authorize the LEAN CLI.") logger.info(full_url) diff --git a/lean/components/util/auth0_helper.py b/lean/components/util/auth0_helper.py index 85906330..fb2d1e5f 100644 --- a/lean/components/util/auth0_helper.py +++ b/lean/components/util/auth0_helper.py @@ -16,7 +16,7 @@ from lean.components.util.logger import Logger -def get_authorization(auth0_client: Auth0Client, brokerage_id: str, logger: Logger, project_id: int, no_browser: bool = False) -> QCAuth0Authorization: +def get_authorization(auth0_client: Auth0Client, brokerage_id: str, logger: Logger, project_id: int, no_browser: bool = False, user_name: str = None) -> QCAuth0Authorization: """Gets the authorization data for a brokerage, authorizing if necessary. :param auth0_client: An instance of Auth0Client, containing methods to interact with live/auth0/* API endpoints. @@ -28,18 +28,18 @@ def get_authorization(auth0_client: Auth0Client, brokerage_id: str, logger: Logg """ from time import time, sleep - data = auth0_client.read(brokerage_id) + data = auth0_client.read(brokerage_id, user_name=user_name) if data.authorization is not None: return data start_time = time() - auth0_client.authorize(brokerage_id, logger, project_id, no_browser) + auth0_client.authorize(brokerage_id, logger, project_id, no_browser, user_name=user_name) # keep checking for new data every 5 seconds for 7 minutes while time() - start_time < 420: logger.debug("Will sleep 5 seconds and retry fetching authorization...") sleep(5) - data = auth0_client.read(brokerage_id) + data = auth0_client.read(brokerage_id, user_name=user_name) if data.authorization is None: continue return data diff --git a/lean/models/configuration.py b/lean/models/configuration.py index 50d7ce1d..67b79e05 100644 --- a/lean/models/configuration.py +++ b/lean/models/configuration.py @@ -401,6 +401,7 @@ class AuthConfiguration(InternalInputUserInput): def __init__(self, config_json_object): super().__init__(config_json_object) self.require_project_id = config_json_object.get("require-project-id", False) + self.require_user_name = config_json_object.get("require-user-name", False) def factory(config_json_object) -> 'AuthConfiguration': """Creates an instance of the child classes. diff --git a/lean/models/json_module.py b/lean/models/json_module.py index 7c5b2e55..3f83538a 100644 --- a/lean/models/json_module.py +++ b/lean/models/json_module.py @@ -175,6 +175,30 @@ def convert_variable_to_lean_key(self, variable_key: str) -> str: """ return variable_key.replace('_', '-') + def get_user_name(self, lean_config: Dict[str, Any], configuration, user_provided_options: Dict[str, Any], require_user_name: bool) -> str: + """Retrieve the user name, prompting the user if required and not already set. + + :param lean_config: The Lean config dict to read defaults from. + :param configuration: The AuthConfiguration instance. + :param user_provided_options: Options passed as command-line arguments. + :param require_user_name: Flag to determine if prompting is necessary. + :return: The user name, or None if not required. + """ + if not require_user_name: + return None + from click import prompt + user_name_key = configuration._id.replace("-oauth-token", "") + "-user-name" + user_name_variable = self.convert_lean_key_to_variable(user_name_key) + if user_name_variable in user_provided_options and user_provided_options[user_name_variable]: + return user_provided_options[user_name_variable] + if lean_config and lean_config.get(user_name_key): + return lean_config[user_name_key] + user_name = prompt("Please enter your Login ID to proceed with Auth0 authentication", + show_default=False) + if lean_config is not None: + lean_config[user_name_key] = user_name + return user_name + def get_project_id(self, default_project_id: int, require_project_id: bool) -> int: """Retrieve the project ID, prompting the user if required and default is invalid. @@ -238,8 +262,12 @@ def config_build(self, lean_config["project-id"] = self.get_project_id(lean_config["project-id"], configuration.require_project_id) logger.debug(f'project_id: {lean_config["project-id"]}') + user_name = self.get_user_name(lean_config, configuration, user_provided_options, + configuration.require_user_name) + logger.debug(f'user_name: {user_name}') auth_authorizations = get_authorization(container.api_client.auth0, self._display_name.lower(), - logger, lean_config["project-id"], no_browser=no_browser) + logger, lean_config["project-id"], no_browser=no_browser, + user_name=user_name) logger.debug(f'auth: {auth_authorizations}') configuration._value = auth_authorizations.get_authorization_config_without_account() for inner_config in self._lean_configs: @@ -255,6 +283,12 @@ def config_build(self, for account_id in api_account_ids)): raise ValueError(f"The provided account id '{user_provide_account_id}' is not valid, " f"available: {api_account_ids}") + existing_account = lean_config.get(inner_config._id) + if existing_account and (existing_account not in api_account_ids + or len(api_account_ids) > 1): + # Clear stale or ambiguous account so the user is prompted + # to select from the current API choices + lean_config.pop(inner_config._id) break continue diff --git a/tests/components/api/test_auth0_client.py b/tests/components/api/test_auth0_client.py index 5495747a..0069503b 100644 --- a/tests/components/api/test_auth0_client.py +++ b/tests/components/api/test_auth0_client.py @@ -15,6 +15,7 @@ from unittest import mock from lean.constants import API_BASE_URL from lean.components.api.api_client import APIClient +from lean.components.api.auth0_client import Auth0Client from lean.components.util.http_client import HTTPClient @@ -49,6 +50,127 @@ def test_auth0client_trade_station() -> None: assert len(result.get_account_ids()) > 0 +def test_auth0client_authorize_with_user_name() -> None: + with mock.patch("webbrowser.open") as mock_open: + Auth0Client.authorize("charles-schwab", mock.Mock(), 123, user_name="test_login") + mock_open.assert_called_once() + called_url = mock_open.call_args[0][0] + assert "&userId=test_login" in called_url + + +def test_auth0client_authorize_without_user_name() -> None: + with mock.patch("webbrowser.open") as mock_open: + Auth0Client.authorize("charles-schwab", mock.Mock(), 123) + mock_open.assert_called_once() + called_url = mock_open.call_args[0][0] + assert "userId" not in called_url + + +@responses.activate +def test_auth0client_read_with_user_name() -> None: + api_clint = APIClient(mock.Mock(), HTTPClient(mock.Mock()), user_id="123", api_token="abc") + + responses.add( + responses.POST, + f"{API_BASE_URL}live/auth0/read", + json={ + "authorization": { + "charles-schwab-access-token": "abc123", + "accounts": [{"id": "ACC001", "name": "ACC001 | Individual | USD"}] + }, + "success": "true"}, + status=200 + ) + + result = api_clint.auth0.read("charles-schwab", user_name="test_login") + + assert result + assert result.authorization + sent_body = responses.calls[0].request.body.decode() + assert "userId" in sent_body + assert "test_login" in sent_body + + +@responses.activate +def test_auth0client_read_without_user_name() -> None: + api_clint = APIClient(mock.Mock(), HTTPClient(mock.Mock()), user_id="123", api_token="abc") + + responses.add( + responses.POST, + f"{API_BASE_URL}live/auth0/read", + json={ + "authorization": { + "charles-schwab-access-token": "abc123", + "accounts": [{"id": "ACC001", "name": "ACC001 | Individual | USD"}] + }, + "success": "true"}, + status=200 + ) + + result = api_clint.auth0.read("charles-schwab") + + assert result + assert result.authorization + sent_body = responses.calls[0].request.body.decode() + assert "userId" not in sent_body + + +@responses.activate +def test_auth0client_read_caches_without_user_name() -> None: + api_clint = APIClient(mock.Mock(), HTTPClient(mock.Mock()), user_id="123", api_token="abc") + + responses.add( + responses.POST, + f"{API_BASE_URL}live/auth0/read", + json={ + "authorization": { + "charles-schwab-access-token": "abc123", + "accounts": [{"id": "ACC001", "name": "ACC001 | Individual | USD"}] + }, + "success": "true"}, + status=200 + ) + + api_clint.auth0.read("charles-schwab") + api_clint.auth0.read("charles-schwab") + + assert len(responses.calls) == 1 + + +@responses.activate +def test_auth0client_read_caches_per_user_name() -> None: + api_clint = APIClient(mock.Mock(), HTTPClient(mock.Mock()), user_id="123", api_token="abc") + + responses.add( + responses.POST, + f"{API_BASE_URL}live/auth0/read", + json={ + "authorization": { + "charles-schwab-access-token": "abc123", + "accounts": [{"id": "ACC001", "name": "ACC001 | Individual | USD"}] + }, + "success": "true"}, + status=200 + ) + responses.add( + responses.POST, + f"{API_BASE_URL}live/auth0/read", + json={ + "authorization": { + "charles-schwab-access-token": "xyz789", + "accounts": [{"id": "ACC002", "name": "ACC002 | Individual | USD"}] + }, + "success": "true"}, + status=200 + ) + + api_clint.auth0.read("charles-schwab", user_name="user_a") + api_clint.auth0.read("charles-schwab", user_name="user_a") # cache hit + api_clint.auth0.read("charles-schwab", user_name="user_b") # different user — new call + + assert len(responses.calls) == 2 + + @responses.activate def test_auth0client_alpaca() -> None: api_clint = APIClient(mock.Mock(), HTTPClient(mock.Mock()), user_id="123", api_token="abc") diff --git a/tests/components/util/test_json_modules_handler.py b/tests/components/util/test_json_modules_handler.py index 0fd95922..bb46eb8e 100644 --- a/tests/components/util/test_json_modules_handler.py +++ b/tests/components/util/test_json_modules_handler.py @@ -10,6 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock from unittest.mock import MagicMock import pytest @@ -17,9 +18,37 @@ from lean.components.util.json_modules_handler import find_module from lean.constants import MODULE_CLI_PLATFORM, MODULE_BROKERAGE from lean.models.json_module import JsonModule +from tests.conftest import initialize_container from tests.test_helpers import create_fake_lean_cli_directory +_SCHWAB_LIKE_MODULE_DATA = { + "id": "test-brokerage", + "display-id": "TestBrokerage", + "configurations": [ + { + "id": "test-oauth-token", + "type": "oauth-token" + }, + { + "id": "test-account-number", + "type": "input", + "input-method": "choice", + "prompt-info": "Select account", + "filters": [ + { + "condition": { + "dependent-config-id": "test-oauth-token", + "pattern": "^(?!\\s*$).+", + "type": "regex" + } + } + ] + } + ] +} + + @pytest.mark.parametrize("id,display,search_name", [("ads", "binAnce", "BiNAnce"), ("binAnce", "a", "BiNAnce"), ("ads", "binAnce", "QC.Brokerage.Binance.BiNAnce"), @@ -47,3 +76,119 @@ def test_is_value_in_config(searching: str, expected: bool) -> None: result = module.is_value_in_config(searching) assert expected == result + + +def test_get_user_name_returns_none_when_not_required() -> None: + module = JsonModule({"id": "test", "configurations": [], "display-id": "Test"}, + MODULE_BROKERAGE, MODULE_CLI_PLATFORM) + result = module.get_user_name({}, mock.Mock(), {}, require_user_name=False) + assert result is None + + +def test_get_user_name_from_user_provided_options() -> None: + module = JsonModule({"id": "test", "configurations": [], "display-id": "Test"}, + MODULE_BROKERAGE, MODULE_CLI_PLATFORM) + config = mock.Mock() + config._id = "charles-schwab-oauth-token" + result = module.get_user_name({}, config, + {"charles_schwab_user_name": "cli_login"}, + require_user_name=True) + assert result == "cli_login" + + +def test_get_user_name_from_lean_config() -> None: + module = JsonModule({"id": "test", "configurations": [], "display-id": "Test"}, + MODULE_BROKERAGE, MODULE_CLI_PLATFORM) + config = mock.Mock() + config._id = "charles-schwab-oauth-token" + lean_config = {"charles-schwab-user-name": "saved_login"} + result = module.get_user_name(lean_config, config, {}, require_user_name=True) + assert result == "saved_login" + + +def test_get_user_name_prompts_and_saves_to_lean_config() -> None: + module = JsonModule({"id": "test", "configurations": [], "display-id": "Test"}, + MODULE_BROKERAGE, MODULE_CLI_PLATFORM) + config = mock.Mock() + config._id = "charles-schwab-oauth-token" + lean_config = {} + with mock.patch("click.prompt", return_value="prompted_login") as mock_prompt: + result = module.get_user_name(lean_config, config, {}, require_user_name=True) + assert result == "prompted_login" + assert lean_config["charles-schwab-user-name"] == "prompted_login" + mock_prompt.assert_called_once() + + +def test_config_build_prompts_when_lean_config_has_stale_account() -> None: + create_fake_lean_cli_directory() + initialize_container() + + module = JsonModule(_SCHWAB_LIKE_MODULE_DATA, MODULE_BROKERAGE, MODULE_CLI_PLATFORM) + + lean_config = {"project-id": 123, "test-account-number": "89630725"} # stale — not returned by API + + mock_auth = mock.MagicMock() + mock_auth.get_authorization_config_without_account.return_value = {"token": "abc"} + mock_auth.get_account_ids.return_value = ["60102549"] + + with mock.patch("lean.models.json_module.get_current_context") as mock_ctx, \ + mock.patch("lean.models.json_module.get_authorization", return_value=mock_auth), \ + mock.patch("lean.models.configuration.prompt", return_value="60102549") as mock_prompt, \ + mock.patch.object(module, "_save_property"): + mock_ctx.return_value.get_parameter_source.return_value = None + + module.config_build(lean_config, mock.Mock(), interactive=True) + + mock_prompt.assert_called_once() + account_config = next(c for c in module._lean_configs if c._id == "test-account-number") + assert account_config._value == "60102549" + + +def test_config_build_prompts_when_api_returns_multiple_accounts() -> None: + create_fake_lean_cli_directory() + initialize_container() + + module = JsonModule(_SCHWAB_LIKE_MODULE_DATA, MODULE_BROKERAGE, MODULE_CLI_PLATFORM) + + lean_config = {"project-id": 123, "test-account-number": "60102549"} # valid but ambiguous + + mock_auth = mock.MagicMock() + mock_auth.get_authorization_config_without_account.return_value = {"token": "abc"} + mock_auth.get_account_ids.return_value = ["60102549", "99887766"] # multiple accounts + + with mock.patch("lean.models.json_module.get_current_context") as mock_ctx, \ + mock.patch("lean.models.json_module.get_authorization", return_value=mock_auth), \ + mock.patch("lean.models.configuration.prompt", return_value="60102549") as mock_prompt, \ + mock.patch.object(module, "_save_property"): + mock_ctx.return_value.get_parameter_source.return_value = None + + module.config_build(lean_config, mock.Mock(), interactive=True) + + mock_prompt.assert_called_once() + account_config = next(c for c in module._lean_configs if c._id == "test-account-number") + assert account_config._value == "60102549" + + +def test_config_build_uses_lean_config_account_when_valid() -> None: + create_fake_lean_cli_directory() + initialize_container() + + module = JsonModule(_SCHWAB_LIKE_MODULE_DATA, MODULE_BROKERAGE, MODULE_CLI_PLATFORM) + + lean_config = {"project-id": 123, "test-account-number": "60102549"} # valid — matches API response + + mock_auth = mock.MagicMock() + mock_auth.get_authorization_config_without_account.return_value = {"token": "abc"} + mock_auth.get_account_ids.return_value = ["60102549"] + + with mock.patch("lean.models.json_module.get_current_context") as mock_ctx, \ + mock.patch("lean.models.json_module.get_authorization", return_value=mock_auth), \ + mock.patch("lean.models.configuration.prompt") as mock_prompt, \ + mock.patch.object(module, "_save_property"): + mock_ctx.return_value.get_parameter_source.return_value = None + + module.config_build(lean_config, mock.Mock(), interactive=True) + + mock_prompt.assert_not_called() + account_config = next(c for c in module._lean_configs if c._id == "test-account-number") + assert account_config._value == "60102549"