Skip to content

Commit 1432f4c

Browse files
Generalize CLI token source into progressive command list
Replace the explicit force_cmd/fallback_cmd fields with a CliCommand dataclass and an optional commands list on CliTokenSource. When commands is provided, refresh() delegates to _refresh_progressive() which walks the list from activeCommandIndex, falling back on unsupported-flag errors. When commands is None, refresh() delegates to _refresh_single() which preserves the original fallback behavior with zero changes for AzureCliTokenSource. DatabricksCliTokenSource._build_commands() produces the progressive list: --profile + --force-refresh first, plain --profile second, and --host as a terminal fallback. --force-refresh is only paired with --profile, never with --host. Adding future flags (e.g. --scopes) requires only adding entries to _build_commands().
1 parent cd6c876 commit 1432f4c

3 files changed

Lines changed: 162 additions & 137 deletions

File tree

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
### Internal Changes
2020
* Replace the async-disabling mechanism on token refresh failure with a 1-minute retry backoff. Previously, a single failed async refresh would disable proactive token renewal until the token expired. Now, the SDK waits a short cooldown period and retries, improving resilience to transient errors.
2121
* Extract `_resolve_profile` to simplify config file loading and improve `__settings__` error messages.
22+
* Generalize CLI token source into a progressive command list for forward-compatible flag support.
2223

2324
### API Changes
2425
* Add `create_catalog()`, `create_synced_table()`, `delete_catalog()`, `delete_synced_table()`, `get_catalog()` and `get_synced_table()` methods for [w.postgres](https://databricks-sdk-py.readthedocs.io/en/latest/workspace/postgres/postgres.html) workspace-level service.

databricks/sdk/credentials_provider.py

Lines changed: 93 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import base64
3+
import dataclasses
34
import functools
45
import io
56
import json
@@ -649,6 +650,15 @@ def refreshed_headers() -> Dict[str, str]:
649650
return OAuthCredentialsProvider(refreshed_headers, token)
650651

651652

653+
@dataclasses.dataclass
654+
class CliCommand:
655+
"""A single CLI command variant with metadata for progressive fallback."""
656+
657+
args: List[str]
658+
flags: List[str]
659+
warning: str
660+
661+
652662
class CliTokenSource(oauth.Refreshable):
653663

654664
def __init__(
@@ -659,6 +669,7 @@ def __init__(
659669
expiry_field: str,
660670
disable_async: bool = True,
661671
fallback_cmd: Optional[List[str]] = None,
672+
commands: Optional[List[CliCommand]] = None,
662673
):
663674
super().__init__(disable_async=disable_async)
664675
self._cmd = cmd
@@ -670,6 +681,8 @@ def __init__(
670681
self._token_type_field = token_type_field
671682
self._access_token_field = access_token_field
672683
self._expiry_field = expiry_field
684+
self._commands = commands
685+
self._active_command_index = -1
673686

674687
@staticmethod
675688
def _parse_expiry(expiry: str) -> datetime:
@@ -700,7 +713,18 @@ def _exec_cli_command(self, cmd: List[str]) -> oauth.Token:
700713
message = "\n".join(filter(None, [stdout, stderr]))
701714
raise IOError(f"cannot get access token: {message}") from e
702715

716+
@staticmethod
717+
def _is_unknown_flag_error(error: IOError, flags: List[str]) -> bool:
718+
"""Check if the error indicates the CLI rejected one of the given flags."""
719+
msg = str(error)
720+
return any(f"unknown flag: {flag}" in msg for flag in flags)
721+
703722
def refresh(self) -> oauth.Token:
723+
if self._commands is not None:
724+
return self._refresh_progressive()
725+
return self._refresh_single()
726+
727+
def _refresh_single(self) -> oauth.Token:
704728
try:
705729
return self._exec_cli_command(self._cmd)
706730
except IOError as e:
@@ -712,6 +736,30 @@ def refresh(self) -> oauth.Token:
712736
return self._exec_cli_command(self._fallback_cmd)
713737
raise
714738

739+
def _refresh_progressive(self) -> oauth.Token:
740+
idx = self._active_command_index
741+
if idx >= 0:
742+
return self._exec_cli_command(self._commands[idx].args)
743+
return self._probe_and_exec()
744+
745+
def _probe_and_exec(self) -> oauth.Token:
746+
"""Walk the command list to find a CLI command that succeeds.
747+
748+
When a command fails with "unknown flag" for one of its flags, log a
749+
warning and try the next. On success, store _active_command_index so
750+
future calls skip probing.
751+
"""
752+
for i, cmd in enumerate(self._commands):
753+
try:
754+
token = self._exec_cli_command(cmd.args)
755+
self._active_command_index = i
756+
return token
757+
except IOError as e:
758+
is_last = i == len(self._commands) - 1
759+
if is_last or not self._is_unknown_flag_error(e, cmd.flags):
760+
raise
761+
logger.warning(cmd.warning)
762+
715763

716764
def _run_subprocess(
717765
popenargs,
@@ -899,15 +947,7 @@ def __init__(self, cfg: "Config"):
899947
elif cli_path.count("/") == 0:
900948
cli_path = self.__class__._find_executable(cli_path)
901949

902-
fallback_cmd = None
903-
self._force_cmd = None
904-
if cfg.profile:
905-
args = ["auth", "token", "--profile", cfg.profile]
906-
self._force_cmd = [cli_path, *args, "--force-refresh"]
907-
if cfg.host:
908-
fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
909-
else:
910-
args = self.__class__._build_host_args(cfg)
950+
commands = self.__class__._build_commands(cli_path, cfg)
911951

912952
# get_scopes() defaults to ["all-apis"] when nothing is configured, which would
913953
# cause false-positive mismatches against every token that wasn't issued with
@@ -917,30 +957,57 @@ def __init__(self, cfg: "Config"):
917957
self._host = cfg.host
918958

919959
super().__init__(
920-
cmd=[cli_path, *args],
960+
cmd=commands[-1].args,
921961
token_type_field="token_type",
922962
access_token_field="access_token",
923963
expiry_field="expiry",
924964
disable_async=cfg.disable_async_token_refresh,
925-
fallback_cmd=fallback_cmd,
965+
commands=commands,
926966
)
927967

928-
def refresh(self) -> oauth.Token:
929-
if self._force_cmd is None:
930-
token = super().refresh()
931-
else:
932-
try:
933-
token = self._exec_cli_command(self._force_cmd)
934-
except IOError as e:
935-
err_msg = str(e)
936-
if "unknown flag: --force-refresh" in err_msg or "unknown flag: --profile" in err_msg:
937-
logger.warning(
938-
"Databricks CLI does not support --force-refresh. "
939-
"Please upgrade your CLI to the latest version."
968+
@staticmethod
969+
def _build_commands(cli_path: str, cfg: "Config") -> List[CliCommand]:
970+
commands: List[CliCommand] = []
971+
if cfg.profile:
972+
profile_args = [cli_path, "auth", "token", "--profile", cfg.profile]
973+
commands.append(
974+
CliCommand(
975+
args=profile_args + ["--force-refresh"],
976+
flags=["--force-refresh", "--profile"],
977+
warning="Databricks CLI does not support --force-refresh. "
978+
"Please upgrade your CLI to the latest version.",
979+
)
980+
)
981+
commands.append(
982+
CliCommand(
983+
args=profile_args,
984+
flags=["--profile"],
985+
warning="Databricks CLI does not support --profile flag. "
986+
"Falling back to --host. "
987+
"Please upgrade your CLI to the latest version.",
988+
)
989+
)
990+
if cfg.host:
991+
commands.append(
992+
CliCommand(
993+
args=[cli_path, *DatabricksCliTokenSource._build_host_args(cfg)],
994+
flags=[],
995+
warning="",
940996
)
941-
token = super().refresh()
942-
else:
943-
raise
997+
)
998+
else:
999+
host_args = [cli_path, *DatabricksCliTokenSource._build_host_args(cfg)]
1000+
commands.append(
1001+
CliCommand(
1002+
args=host_args,
1003+
flags=[],
1004+
warning="",
1005+
)
1006+
)
1007+
return commands
1008+
1009+
def refresh(self) -> oauth.Token:
1010+
token = super().refresh()
9441011
if self._requested_scopes:
9451012
self._validate_token_scopes(token)
9461013
return token

0 commit comments

Comments
 (0)