Skip to content

Commit 8821700

Browse files
Generalize CLI token source into progressive command list
Replace the three explicit command fields (force_cmd, cmd, fallback_cmd) and manual fallback methods with a CliCommand dataclass and an optional commands list on CliTokenSource. When commands is provided, refresh() delegates to _refresh_progressive() which walks the list from activeCommandIndex, falling back on unsupported-flag errors. When commands is None, refresh() delegates to _refresh_single() which preserves the original single-command behavior with zero changes for AzureCliTokenSource. DatabricksCliTokenSource._build_commands() produces the progressive list: --profile + --force-refresh first, plain --profile second, and --host as a terminal fallback. --force-refresh is only paired with --profile, never with --host. Adding future flags (e.g. --scopes) requires only adding entries to _build_commands().
1 parent 4115f37 commit 8821700

3 files changed

Lines changed: 184 additions & 138 deletions

File tree

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
### Internal Changes
2020
* Replace the async-disabling mechanism on token refresh failure with a 1-minute retry backoff. Previously, a single failed async refresh would disable proactive token renewal until the token expired. Now, the SDK waits a short cooldown period and retries, improving resilience to transient errors.
2121
* Extract `_resolve_profile` to simplify config file loading and improve `__settings__` error messages.
22+
* Generalize CLI token source into a progressive command list for forward-compatible flag support.
2223

2324
### API Changes
2425
* Add `create_catalog()`, `create_synced_table()`, `delete_catalog()`, `delete_synced_table()`, `get_catalog()` and `get_synced_table()` methods for [w.postgres](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html) workspace-level service.

databricks/sdk/credentials_provider.py

Lines changed: 85 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import base64
3+
import dataclasses
34
import functools
45
import io
56
import json
@@ -649,8 +650,16 @@ def refreshed_headers() -> Dict[str, str]:
649650
return OAuthCredentialsProvider(refreshed_headers, token)
650651

651652

652-
class CliTokenSource(oauth.Refreshable):
653+
@dataclasses.dataclass
654+
class CliCommand:
655+
"""A single CLI command variant with metadata for progressive fallback."""
656+
657+
args: List[str]
658+
flags: List[str]
659+
warning: str
653660

661+
662+
class CliTokenSource(oauth.Refreshable):
654663
def __init__(
655664
self,
656665
cmd: List[str],
@@ -659,6 +668,7 @@ def __init__(
659668
expiry_field: str,
660669
disable_async: bool = True,
661670
fallback_cmd: Optional[List[str]] = None,
671+
commands: Optional[List[CliCommand]] = None,
662672
):
663673
super().__init__(disable_async=disable_async)
664674
self._cmd = cmd
@@ -670,6 +680,8 @@ def __init__(
670680
self._token_type_field = token_type_field
671681
self._access_token_field = access_token_field
672682
self._expiry_field = expiry_field
683+
self._commands = commands
684+
self._active_command_index = 0
673685

674686
@staticmethod
675687
def _parse_expiry(expiry: str) -> datetime:
@@ -700,7 +712,18 @@ def _exec_cli_command(self, cmd: List[str]) -> oauth.Token:
700712
message = "\n".join(filter(None, [stdout, stderr]))
701713
raise IOError(f"cannot get access token: {message}") from e
702714

715+
@staticmethod
716+
def _is_unknown_flag_error(error: IOError, flags: List[str]) -> bool:
717+
"""Check if the error indicates the CLI rejected one of the given flags."""
718+
msg = str(error)
719+
return any(f"unknown flag: {flag}" in msg for flag in flags)
720+
703721
def refresh(self) -> oauth.Token:
722+
if self._commands is not None:
723+
return self._refresh_progressive()
724+
return self._refresh_single()
725+
726+
def _refresh_single(self) -> oauth.Token:
704727
try:
705728
return self._exec_cli_command(self._cmd)
706729
except IOError as e:
@@ -712,6 +735,22 @@ def refresh(self) -> oauth.Token:
712735
return self._exec_cli_command(self._fallback_cmd)
713736
raise
714737

738+
def _refresh_progressive(self) -> oauth.Token:
739+
last_err: Optional[IOError] = None
740+
for i in range(self._active_command_index, len(self._commands)):
741+
cmd = self._commands[i]
742+
try:
743+
token = self._exec_cli_command(cmd.args)
744+
self._active_command_index = i
745+
return token
746+
except IOError as e:
747+
is_last = i == len(self._commands) - 1
748+
if is_last or not self._is_unknown_flag_error(e, cmd.flags):
749+
raise
750+
logger.warning(cmd.warning)
751+
last_err = e
752+
raise last_err
753+
715754

716755
def _run_subprocess(
717756
popenargs,
@@ -899,15 +938,7 @@ def __init__(self, cfg: "Config"):
899938
elif cli_path.count("/") == 0:
900939
cli_path = self.__class__._find_executable(cli_path)
901940

902-
fallback_cmd = None
903-
if cfg.profile:
904-
args = ["auth", "token", "--profile", cfg.profile]
905-
if cfg.host:
906-
fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
907-
else:
908-
args = self.__class__._build_host_args(cfg)
909-
910-
self._force_cmd = [cli_path, *args, "--force-refresh"]
941+
commands = self.__class__._build_commands(cli_path, cfg)
911942

912943
# get_scopes() defaults to ["all-apis"] when nothing is configured, which would
913944
# cause false-positive mismatches against every token that wasn't issued with
@@ -917,27 +948,57 @@ def __init__(self, cfg: "Config"):
917948
self._host = cfg.host
918949

919950
super().__init__(
920-
cmd=[cli_path, *args],
951+
cmd=commands[-1].args,
921952
token_type_field="token_type",
922953
access_token_field="access_token",
923954
expiry_field="expiry",
924955
disable_async=cfg.disable_async_token_refresh,
925-
fallback_cmd=fallback_cmd,
956+
commands=commands,
926957
)
927958

928-
def refresh(self) -> oauth.Token:
929-
try:
930-
token = self._exec_cli_command(self._force_cmd)
931-
except IOError as e:
932-
err_msg = str(e)
933-
if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg:
934-
logger.warning(
935-
"Databricks CLI does not support --force-refresh. "
936-
"Please upgrade your CLI to the latest version."
959+
@staticmethod
960+
def _build_commands(cli_path: str, cfg: "Config") -> List[CliCommand]:
961+
commands: List[CliCommand] = []
962+
if cfg.profile:
963+
profile_args = [cli_path, "auth", "token", "--profile", cfg.profile]
964+
commands.append(
965+
CliCommand(
966+
args=profile_args + ["--force-refresh"],
967+
flags=["--force-refresh", "--profile"],
968+
warning="Databricks CLI does not support --force-refresh. "
969+
"Please upgrade your CLI to the latest version.",
937970
)
938-
token = super().refresh()
939-
else:
940-
raise
971+
)
972+
commands.append(
973+
CliCommand(
974+
args=profile_args,
975+
flags=["--profile"],
976+
warning="Databricks CLI does not support --profile flag. "
977+
"Falling back to --host. "
978+
"Please upgrade your CLI to the latest version.",
979+
)
980+
)
981+
if cfg.host:
982+
commands.append(
983+
CliCommand(
984+
args=[cli_path, *DatabricksCliTokenSource._build_host_args(cfg)],
985+
flags=[],
986+
warning="",
987+
)
988+
)
989+
else:
990+
host_args = [cli_path, *DatabricksCliTokenSource._build_host_args(cfg)]
991+
commands.append(
992+
CliCommand(
993+
args=host_args,
994+
flags=[],
995+
warning="",
996+
)
997+
)
998+
return commands
999+
1000+
def refresh(self) -> oauth.Token:
1001+
token = super().refresh()
9411002
if self._requested_scopes:
9421003
self._validate_token_scopes(token)
9431004
return token

0 commit comments

Comments
 (0)