Skip to content

Commit e619b01

Browse files
authored
[Feature] Migrate U2M Code to SDKs (#2076)
## Changes Ports the SDK to this change: databricks/databricks-sdk-go#1108. This eliminates the need for the SDK to invoke itself (!) when loading an OAuth token. It also allows us to respond with more suitable errors when fetching tokens. Functionally, this should behave identically to the existing CLI from an external perspective. One concrete improvement with this change is that the error message returned when a refresh token is invalid now provides the correct command to run, including the appropriate `--profile` flag. Users will see the following message when using the CLI with a profile/host with an expired refresh token: ![Screenshot 2025-01-20 at 13 18 59](https://github.com/user-attachments/assets/bf3040cb-971d-4c04-8557-0f6749d82b5e) ## Tests Mostly this change involves refactoring existing code/tests to use/mock the Go SDK interfaces, so there aren't too many substantial test changes. `token_test.go` is rewritten using the table test approach to reduce boilerplate.
1 parent 77e4492 commit e619b01

18 files changed

Lines changed: 402 additions & 1278 deletions

cmd/auth/auth.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ Azure: https://learn.microsoft.com/azure/databricks/dev-tools/auth
2222
GCP: https://docs.gcp.databricks.com/dev-tools/auth/index.html`,
2323
}
2424

25-
var perisistentAuth auth.PersistentAuth
26-
cmd.PersistentFlags().StringVar(&perisistentAuth.Host, "host", perisistentAuth.Host, "Databricks Host")
27-
cmd.PersistentFlags().StringVar(&perisistentAuth.AccountID, "account-id", perisistentAuth.AccountID, "Databricks Account ID")
25+
var authArguments auth.AuthArguments
26+
cmd.PersistentFlags().StringVar(&authArguments.Host, "host", "", "Databricks Host")
27+
cmd.PersistentFlags().StringVar(&authArguments.AccountID, "account-id", "", "Databricks Account ID")
2828

2929
cmd.AddCommand(newEnvCommand())
30-
cmd.AddCommand(newLoginCommand(&perisistentAuth))
30+
cmd.AddCommand(newLoginCommand(&authArguments))
3131
cmd.AddCommand(newProfilesCommand())
32-
cmd.AddCommand(newTokenCommand(&perisistentAuth))
32+
cmd.AddCommand(newTokenCommand(&authArguments))
3333
cmd.AddCommand(newDescribeCommand())
3434
return cmd
3535
}

cmd/auth/in_memory_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package auth
2+
3+
import (
4+
"github.com/databricks/databricks-sdk-go/credentials/u2m/cache"
5+
"golang.org/x/oauth2"
6+
)
7+
8+
type inMemoryTokenCache struct {
9+
Tokens map[string]*oauth2.Token
10+
}
11+
12+
// Lookup implements TokenCache.
13+
func (i *inMemoryTokenCache) Lookup(key string) (*oauth2.Token, error) {
14+
token, ok := i.Tokens[key]
15+
if !ok {
16+
return nil, cache.ErrNotConfigured
17+
}
18+
return token, nil
19+
}
20+
21+
// Store implements TokenCache.
22+
func (i *inMemoryTokenCache) Store(key string, t *oauth2.Token) error {
23+
i.Tokens[key] = t
24+
return nil
25+
}
26+
27+
var _ cache.TokenCache = (*inMemoryTokenCache)(nil)

cmd/auth/login.go

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"errors"
66
"fmt"
77
"runtime"
8+
"strings"
89
"time"
910

1011
"github.com/databricks/cli/libs/auth"
@@ -14,6 +15,7 @@ import (
1415
"github.com/databricks/cli/libs/databrickscfg/profile"
1516
"github.com/databricks/databricks-sdk-go"
1617
"github.com/databricks/databricks-sdk-go/config"
18+
"github.com/databricks/databricks-sdk-go/credentials/u2m"
1719
"github.com/spf13/cobra"
1820
)
1921

@@ -34,7 +36,7 @@ const (
3436
defaultTimeout = 1 * time.Hour
3537
)
3638

37-
func newLoginCommand(persistentAuth *auth.PersistentAuth) *cobra.Command {
39+
func newLoginCommand(authArguments *auth.AuthArguments) *cobra.Command {
3840
defaultConfigPath := "~/.databrickscfg"
3941
if runtime.GOOS == "windows" {
4042
defaultConfigPath = "%USERPROFILE%\\.databrickscfg"
@@ -98,14 +100,22 @@ depends on the existing profiles you have set in your configuration file
98100
// If the user has not specified a profile name, prompt for one.
99101
if profileName == "" {
100102
var err error
101-
profileName, err = promptForProfile(ctx, persistentAuth.ProfileName())
103+
profileName, err = promptForProfile(ctx, getProfileName(authArguments))
102104
if err != nil {
103105
return err
104106
}
105107
}
106108

107109
// Set the host and account-id based on the provided arguments and flags.
108-
err := setHostAndAccountId(ctx, profileName, persistentAuth, args)
110+
err := setHostAndAccountId(ctx, profile.DefaultProfiler, profileName, authArguments, args)
111+
if err != nil {
112+
return err
113+
}
114+
oauthArgument, err := authArguments.ToOAuthArgument()
115+
if err != nil {
116+
return err
117+
}
118+
persistentAuth, err := u2m.NewPersistentAuth(ctx, u2m.WithOAuthArgument(oauthArgument))
109119
if err != nil {
110120
return err
111121
}
@@ -114,16 +124,15 @@ depends on the existing profiles you have set in your configuration file
114124
// We need the config without the profile before it's used to initialise new workspace client below.
115125
// Otherwise it will complain about non existing profile because it was not yet saved.
116126
cfg := config.Config{
117-
Host: persistentAuth.Host,
118-
AccountID: persistentAuth.AccountID,
127+
Host: authArguments.Host,
128+
AccountID: authArguments.AccountID,
119129
AuthType: "databricks-cli",
120130
}
121131

122132
ctx, cancel := context.WithTimeout(ctx, loginTimeout)
123133
defer cancel()
124134

125-
err = persistentAuth.Challenge(ctx)
126-
if err != nil {
135+
if err = persistentAuth.Challenge(); err != nil {
127136
return err
128137
}
129138

@@ -173,53 +182,66 @@ depends on the existing profiles you have set in your configuration file
173182
// 1. --account-id flag.
174183
// 2. account-id from the specified profile, if available.
175184
// 3. Prompt the user for the account-id.
176-
func setHostAndAccountId(ctx context.Context, profileName string, persistentAuth *auth.PersistentAuth, args []string) error {
185+
func setHostAndAccountId(ctx context.Context, profiler profile.Profiler, profileName string, authArguments *auth.AuthArguments, args []string) error {
177186
// If both [HOST] and --host are provided, return an error.
178-
if len(args) > 0 && persistentAuth.Host != "" {
187+
host := authArguments.Host
188+
if len(args) > 0 && host != "" {
179189
return errors.New("please only provide a host as an argument or a flag, not both")
180190
}
181191

182-
profiler := profile.GetProfiler(ctx)
183192
// If the chosen profile has a hostname and the user hasn't specified a host, infer the host from the profile.
184193
profiles, err := profiler.LoadProfiles(ctx, profile.WithName(profileName))
185194
// Tolerate ErrNoConfiguration here, as we will write out a configuration as part of the login flow.
186195
if err != nil && !errors.Is(err, profile.ErrNoConfiguration) {
187196
return err
188197
}
189198

190-
if persistentAuth.Host == "" {
199+
if host == "" {
191200
if len(args) > 0 {
192201
// If [HOST] is provided, set the host to the provided positional argument.
193-
persistentAuth.Host = args[0]
202+
authArguments.Host = args[0]
194203
} else if len(profiles) > 0 && profiles[0].Host != "" {
195204
// If neither [HOST] nor --host are provided, and the profile has a host, use it.
196-
persistentAuth.Host = profiles[0].Host
205+
authArguments.Host = profiles[0].Host
197206
} else {
198207
// If neither [HOST] nor --host are provided, and the profile does not have a host,
199208
// then prompt the user for a host.
200209
hostName, err := promptForHost(ctx)
201210
if err != nil {
202211
return err
203212
}
204-
persistentAuth.Host = hostName
213+
authArguments.Host = hostName
205214
}
206215
}
207216

208217
// If the account-id was not provided as a cmd line flag, try to read it from
209218
// the specified profile.
210-
isAccountClient := (&config.Config{Host: persistentAuth.Host}).IsAccountClient()
211-
if isAccountClient && persistentAuth.AccountID == "" {
219+
isAccountClient := (&config.Config{Host: authArguments.Host}).IsAccountClient()
220+
accountID := authArguments.AccountID
221+
if isAccountClient && accountID == "" {
212222
if len(profiles) > 0 && profiles[0].AccountID != "" {
213-
persistentAuth.AccountID = profiles[0].AccountID
223+
authArguments.AccountID = profiles[0].AccountID
214224
} else {
215225
// Prompt user for the account-id if it we could not get it from a
216226
// profile.
217227
accountId, err := promptForAccountID(ctx)
218228
if err != nil {
219229
return err
220230
}
221-
persistentAuth.AccountID = accountId
231+
authArguments.AccountID = accountId
222232
}
223233
}
224234
return nil
225235
}
236+
237+
// getProfileName returns the default profile name for a given host/account ID.
238+
// If the account ID is provided, the profile name is "ACCOUNT-<account-id>".
239+
// Otherwise, the profile name is the first part of the host URL.
240+
func getProfileName(authArguments *auth.AuthArguments) string {
241+
if authArguments.AccountID != "" {
242+
return "ACCOUNT-" + authArguments.AccountID
243+
}
244+
host := strings.TrimPrefix(authArguments.Host, "https://")
245+
split := strings.Split(host, ".")
246+
return split[0]
247+
}

cmd/auth/login_test.go

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66

77
"github.com/databricks/cli/libs/auth"
88
"github.com/databricks/cli/libs/cmdio"
9+
"github.com/databricks/cli/libs/databrickscfg/profile"
910
"github.com/databricks/cli/libs/env"
1011
"github.com/stretchr/testify/assert"
1112
"github.com/stretchr/testify/require"
@@ -14,72 +15,72 @@ import (
1415
func TestSetHostDoesNotFailWithNoDatabrickscfg(t *testing.T) {
1516
ctx := context.Background()
1617
ctx = env.Set(ctx, "DATABRICKS_CONFIG_FILE", "./imaginary-file/databrickscfg")
17-
err := setHostAndAccountId(ctx, "foo", &auth.PersistentAuth{Host: "test"}, []string{})
18+
err := setHostAndAccountId(ctx, profile.DefaultProfiler, "foo", &auth.AuthArguments{Host: "test"}, []string{})
1819
assert.NoError(t, err)
1920
}
2021

2122
func TestSetHost(t *testing.T) {
22-
var persistentAuth auth.PersistentAuth
23+
authArguments := auth.AuthArguments{}
2324
t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg")
2425
ctx, _ := cmdio.SetupTest(context.Background())
2526

2627
// Test error when both flag and argument are provided
27-
persistentAuth.Host = "val from --host"
28-
err := setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"})
28+
authArguments.Host = "val from --host"
29+
err := setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{"val from [HOST]"})
2930
assert.EqualError(t, err, "please only provide a host as an argument or a flag, not both")
3031

3132
// Test setting host from flag
32-
persistentAuth.Host = "val from --host"
33-
err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{})
33+
authArguments.Host = "val from --host"
34+
err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{})
3435
assert.NoError(t, err)
35-
assert.Equal(t, "val from --host", persistentAuth.Host)
36+
assert.Equal(t, "val from --host", authArguments.Host)
3637

3738
// Test setting host from argument
38-
persistentAuth.Host = ""
39-
err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{"val from [HOST]"})
39+
authArguments.Host = ""
40+
err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{"val from [HOST]"})
4041
assert.NoError(t, err)
41-
assert.Equal(t, "val from [HOST]", persistentAuth.Host)
42+
assert.Equal(t, "val from [HOST]", authArguments.Host)
4243

4344
// Test setting host from profile
44-
persistentAuth.Host = ""
45-
err = setHostAndAccountId(ctx, "profile-1", &persistentAuth, []string{})
45+
authArguments.Host = ""
46+
err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-1", &authArguments, []string{})
4647
assert.NoError(t, err)
47-
assert.Equal(t, "https://www.host1.com", persistentAuth.Host)
48+
assert.Equal(t, "https://www.host1.com", authArguments.Host)
4849

4950
// Test setting host from profile
50-
persistentAuth.Host = ""
51-
err = setHostAndAccountId(ctx, "profile-2", &persistentAuth, []string{})
51+
authArguments.Host = ""
52+
err = setHostAndAccountId(ctx, profile.DefaultProfiler, "profile-2", &authArguments, []string{})
5253
assert.NoError(t, err)
53-
assert.Equal(t, "https://www.host2.com", persistentAuth.Host)
54+
assert.Equal(t, "https://www.host2.com", authArguments.Host)
5455

5556
// Test host is not set. Should prompt.
56-
persistentAuth.Host = ""
57-
err = setHostAndAccountId(ctx, "", &persistentAuth, []string{})
57+
authArguments.Host = ""
58+
err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &authArguments, []string{})
5859
assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify a host using --host")
5960
}
6061

6162
func TestSetAccountId(t *testing.T) {
62-
var persistentAuth auth.PersistentAuth
63+
var authArguments auth.AuthArguments
6364
t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg")
6465
ctx, _ := cmdio.SetupTest(context.Background())
6566

6667
// Test setting account-id from flag
67-
persistentAuth.AccountID = "val from --account-id"
68-
err := setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{})
68+
authArguments.AccountID = "val from --account-id"
69+
err := setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &authArguments, []string{})
6970
assert.NoError(t, err)
70-
assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host)
71-
assert.Equal(t, "val from --account-id", persistentAuth.AccountID)
71+
assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.Host)
72+
assert.Equal(t, "val from --account-id", authArguments.AccountID)
7273

7374
// Test setting account_id from profile
74-
persistentAuth.AccountID = ""
75-
err = setHostAndAccountId(ctx, "account-profile", &persistentAuth, []string{})
75+
authArguments.AccountID = ""
76+
err = setHostAndAccountId(ctx, profile.DefaultProfiler, "account-profile", &authArguments, []string{})
7677
require.NoError(t, err)
77-
assert.Equal(t, "https://accounts.cloud.databricks.com", persistentAuth.Host)
78-
assert.Equal(t, "id-from-profile", persistentAuth.AccountID)
78+
assert.Equal(t, "https://accounts.cloud.databricks.com", authArguments.Host)
79+
assert.Equal(t, "id-from-profile", authArguments.AccountID)
7980

8081
// Neither flag nor profile account-id is set, should prompt
81-
persistentAuth.AccountID = ""
82-
persistentAuth.Host = "https://accounts.cloud.databricks.com"
83-
err = setHostAndAccountId(ctx, "", &persistentAuth, []string{})
82+
authArguments.AccountID = ""
83+
authArguments.Host = "https://accounts.cloud.databricks.com"
84+
err = setHostAndAccountId(ctx, profile.DefaultProfiler, "", &authArguments, []string{})
8485
assert.EqualError(t, err, "the command is being run in a non-interactive environment, please specify an account ID using --account-id")
8586
}

0 commit comments

Comments
 (0)