From b027581325c68c9b54c82e03115086c0b782f515 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Mon, 27 Apr 2026 10:34:27 -0700 Subject: [PATCH 1/2] feat(BRE2-915): api key as auth method --- pkg/auth/auth.go | 79 +++++++++- pkg/auth/auth_test.go | 182 +++++++++++++++++++++- pkg/cmd/completions/completions.go | 24 ++- pkg/cmd/delete/delete.go | 4 + pkg/cmd/gpucreate/gpucreate.go | 22 ++- pkg/cmd/gpucreate/gpucreate_test.go | 24 +++ pkg/cmd/invite/invite.go | 5 +- pkg/cmd/login/login.go | 46 +++++- pkg/cmd/login/login_test.go | 232 ++++++++++++++++++++++++++++ pkg/cmd/ls/ls.go | 175 +++++++++++++++------ pkg/cmd/ls/ls_test.go | 153 ++++++++++++++++-- pkg/cmd/set/set.go | 5 + pkg/cmd/set/set_test.go | 77 +++++++++ pkg/cmd/start/start.go | 70 +++++---- pkg/cmd/start/start_test.go | 15 ++ pkg/cmd/stop/stop.go | 34 ++-- pkg/cmd/util/util.go | 18 +++ pkg/cmd/util/util_test.go | 70 +++++++++ pkg/entity/entity.go | 2 + pkg/store/authtoken.go | 4 +- pkg/store/authtoken_test.go | 56 +++++++ pkg/store/memory_auth.go | 4 +- pkg/store/organization.go | 63 ++++++-- pkg/store/organization_test.go | 54 +++++++ pkg/store/workspace.go | 9 ++ 25 files changed, 1291 insertions(+), 136 deletions(-) create mode 100644 pkg/cmd/util/util_test.go diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 3138fd034..5c94bc328 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -101,6 +101,44 @@ type Auth struct { shouldLogin func() (bool, error) } +const BrevAPIKeyPrefix = "bak-" + +const MissingAPIKeyOrgIDMessage = "api key auth requires an org id; run brev login --api-key --org-id " + +type APIKeyAuthStore interface { + GetAuthTokens() (*entity.AuthTokens, error) +} + +func IsBrevAPIKey(token string) bool { + return strings.HasPrefix(strings.TrimSpace(token), BrevAPIKeyPrefix) +} + +func IsAPIKeyAuthStore(authTokensProvider APIKeyAuthStore) bool { + tokens, err := authTokensProvider.GetAuthTokens() + if err != nil { + return false + } + if tokens == nil { + return false + } + return IsBrevAPIKey(tokens.APIKey) +} + +func GetAPIKeyOrgID(authTokensProvider APIKeyAuthStore) (string, error) { + tokens, err := authTokensProvider.GetAuthTokens() + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + if tokens == nil { + return "", breverrors.NewValidationError(MissingAPIKeyOrgIDMessage) + } + orgID := strings.TrimSpace(tokens.APIKeyOrgID) + if orgID == "" { + return "", breverrors.NewValidationError(MissingAPIKeyOrgIDMessage) + } + return orgID, nil +} + func NewAuth(authStore AuthStore, oauth OAuth) *Auth { return &Auth{ authStore: authStore, @@ -146,6 +184,11 @@ func (t Auth) GetFreshAccessTokenOrNil() (string, error) { return "", nil } + apiKey := strings.TrimSpace(tokens.APIKey) + if apiKey != "" { + return apiKey, nil + } + // should always at least have access token? if tokens.AccessToken == "" { breverrors.GetDefaultErrorReporter().ReportMessage("access token is an empty string but shouldn't be") @@ -222,6 +265,36 @@ func (t Auth) LoginWithToken(token string) error { return nil } +func (t Auth) LoginWithAPIKey(apiKey string, orgID string) error { + apiKey = strings.TrimSpace(apiKey) + if apiKey == "" { + return breverrors.NewValidationError("api key is empty") + } + if !IsBrevAPIKey(apiKey) { + return breverrors.NewValidationError(fmt.Sprintf("api key must start with %s", BrevAPIKeyPrefix)) + } + orgID = strings.TrimSpace(orgID) + if orgID == "" { + return breverrors.NewValidationError(MissingAPIKeyOrgIDMessage) + } + + tokens, err := t.getSavedTokensOrNil() + if err != nil { + return breverrors.WrapAndTrace(err) + } + if tokens == nil { + tokens = &entity.AuthTokens{} + } + tokens.APIKey = apiKey + tokens.APIKeyOrgID = orgID + + err = t.authStore.SaveAuthTokens(*tokens) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + // showLoginURL displays the login link and CLI alternative for manual navigation. func showLoginURL(url string) { urlType := color.New(color.FgCyan, color.Bold).SprintFunc() @@ -313,7 +386,7 @@ func (t Auth) getSavedTokensOrNil() (*entity.AuthTokens, error) { } return nil, breverrors.WrapAndTrace(err) } - if tokens != nil && tokens.AccessToken == "" && tokens.RefreshToken == "" { + if tokens != nil && tokens.AccessToken == "" && tokens.RefreshToken == "" && tokens.APIKey == "" { return nil, nil } return tokens, nil @@ -415,7 +488,7 @@ func AuthProviderFlagToCredentialProvider(authProviderFlag string) entity.Creden func StandardLogin(authProvider string, email string, tokens *entity.AuthTokens) OAuth { // Set KAS as the default authenticator shouldPromptEmail := false - if email == "" && tokens != nil && tokens.AccessToken != "" { + if email == "" && tokens != nil && tokens.AccessToken != "" && tokens.APIKey == "" { email = GetEmailFromToken(tokens.AccessToken) shouldPromptEmail = true } @@ -445,7 +518,7 @@ func StandardLogin(authProvider string, email string, tokens *entity.AuthTokens) kasAuthenticator, }) - if tokens != nil && tokens.AccessToken != "" { + if tokens != nil && tokens.AccessToken != "" && tokens.APIKey == "" { authenticatorFromToken, errr := authRetriever.GetByToken(tokens.AccessToken) if errr != nil { fmt.Printf("%v\n", errr) diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index 992f1bdbb..bdfedbe12 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -1,11 +1,14 @@ package auth import ( + "io" + "os" "testing" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -38,10 +41,12 @@ func TestIsAccessTokenValid(t *testing.T) { type MockAuthStore struct { authTokens *entity.AuthTokens + saved entity.AuthTokens didSave bool } -func (m *MockAuthStore) SaveAuthTokens(_ entity.AuthTokens) error { +func (m *MockAuthStore) SaveAuthTokens(tokens entity.AuthTokens) error { + m.saved = tokens m.didSave = true return nil } @@ -77,7 +82,180 @@ func (m MockOauth) GetNewAuthTokensWithRefresh(_ string) (*entity.AuthTokens, er return m.authTokens, nil } -const validToken = "abc" +const ( + validToken = "abc" + testAPIKey = BrevAPIKeyPrefix + "test-key" +) + +func TestIsBrevAPIKey(t *testing.T) { + assert.True(t, IsBrevAPIKey(testAPIKey)) + assert.True(t, IsBrevAPIKey(" "+testAPIKey+" ")) + assert.False(t, IsBrevAPIKey("bakery-token")) + assert.False(t, IsBrevAPIKey("jwt-token")) + assert.False(t, IsBrevAPIKey("")) +} + +type sideEffectingTokenStore struct { + tokens *entity.AuthTokens + getAccessTokenCalled bool +} + +func (s *sideEffectingTokenStore) GetAuthTokens() (*entity.AuthTokens, error) { + return s.tokens, nil +} + +func (s *sideEffectingTokenStore) GetAccessToken() (string, error) { + s.getAccessTokenCalled = true + return testAPIKey, nil +} + +func TestIsAPIKeyAuthStore_ReadsSavedTokensWithoutAccessTokenSideEffects(t *testing.T) { + s := &sideEffectingTokenStore{ + tokens: &entity.AuthTokens{APIKey: testAPIKey}, + } + + assert.True(t, IsAPIKeyAuthStore(s)) + assert.False(t, s.getAccessTokenCalled) +} + +func TestIsAPIKeyAuthStore_LegacyCredentialsAreNotAPIKeyAuth(t *testing.T) { + s := &sideEffectingTokenStore{ + tokens: &entity.AuthTokens{ + AccessToken: validToken, + RefreshToken: "refresh", + }, + } + + assert.False(t, IsAPIKeyAuthStore(s)) + assert.False(t, s.getAccessTokenCalled) +} + +func TestGetFreshAccessTokenOrNil_APIKeySkipsJWTValidationAndRefresh(t *testing.T) { + s := MockAuthStore{authTokens: &entity.AuthTokens{ + AccessToken: "expired-jwt", + APIKey: testAPIKey, + RefreshToken: "should-not-refresh", + }} + a := Auth{ + &s, + &MockOauth{}, func(_ string) (bool, error) { + t.Fatal("api keys must not be parsed as JWTs") + return false, nil + }, + func() (bool, error) { + t.Fatal("api keys must not trigger login") + return false, nil + }, + } + + res, err := a.GetFreshAccessTokenOrNil() + assert.NoError(t, err) + assert.Equal(t, testAPIKey, res) + assert.False(t, s.didSave) +} + +func TestGetFreshAccessTokenOrNil_APIKeyOnlyCredentialReturnsAPIKey(t *testing.T) { + s := MockAuthStore{authTokens: &entity.AuthTokens{ + APIKey: testAPIKey, + }} + a := Auth{ + &s, + &MockOauth{}, func(_ string) (bool, error) { + t.Fatal("api keys must not be parsed as JWTs") + return false, nil + }, + func() (bool, error) { + t.Fatal("api keys must not trigger login") + return false, nil + }, + } + + res, err := a.GetFreshAccessTokenOrNil() + assert.NoError(t, err) + assert.Equal(t, testAPIKey, res) + assert.False(t, s.didSave) +} + +func TestLoginWithAPIKey_SavesTypedCredential(t *testing.T) { + s := MockAuthStore{} + a := Auth{ + authStore: &s, + oauth: &MockOauth{}, + } + + err := a.LoginWithAPIKey(testAPIKey, "org-test") + assert.NoError(t, err) + assert.True(t, s.didSave) + assert.Equal(t, entity.AuthTokens{ + APIKey: testAPIKey, + APIKeyOrgID: "org-test", + }, s.saved) +} + +func TestLoginWithAPIKey_PreservesExistingJWT(t *testing.T) { + s := MockAuthStore{authTokens: &entity.AuthTokens{ + AccessToken: "existing-jwt", + RefreshToken: "existing-refresh", + }} + a := Auth{ + authStore: &s, + oauth: &MockOauth{}, + } + + err := a.LoginWithAPIKey(testAPIKey, "org-test") + assert.NoError(t, err) + assert.Equal(t, entity.AuthTokens{ + AccessToken: "existing-jwt", + RefreshToken: "existing-refresh", + APIKey: testAPIKey, + APIKeyOrgID: "org-test", + }, s.saved) +} + +func TestLoginWithAPIKey_EmptyKeyReturnsError(t *testing.T) { + s := MockAuthStore{} + a := Auth{ + authStore: &s, + oauth: &MockOauth{}, + } + + err := a.LoginWithAPIKey("", "org-test") + assert.Error(t, err) + assert.False(t, s.didSave) +} + +func TestLoginWithAPIKey_EmptyOrgIDReturnsError(t *testing.T) { + s := MockAuthStore{} + a := Auth{ + authStore: &s, + oauth: &MockOauth{}, + } + + err := a.LoginWithAPIKey(testAPIKey, "") + assert.Error(t, err) + assert.False(t, s.didSave) +} + +func TestStandardLogin_APIKeyCredentialDoesNotProbeOAuthProviders(t *testing.T) { + oldStdout := os.Stdout + t.Cleanup(func() { + os.Stdout = oldStdout + }) + readPipe, writePipe, err := os.Pipe() + require.NoError(t, err) + os.Stdout = writePipe + + _ = StandardLogin("", "", &entity.AuthTokens{ + AccessToken: "existing-jwt", + APIKey: testAPIKey, + }) + + assert.NoError(t, writePipe.Close()) + os.Stdout = oldStdout + out, err := io.ReadAll(readPipe) + assert.NoError(t, err) + assert.Empty(t, string(out)) +} func TestSuccessNoRefreshGetFreshAccessTokenOrLogin(t *testing.T) { s := MockAuthStore{authTokens: &entity.AuthTokens{ diff --git a/pkg/cmd/completions/completions.go b/pkg/cmd/completions/completions.go index 5dcabd1b3..1dcf989bc 100644 --- a/pkg/cmd/completions/completions.go +++ b/pkg/cmd/completions/completions.go @@ -1,6 +1,7 @@ package completions import ( + "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/entity" "github.com/brevdev/brev-cli/pkg/store" "github.com/brevdev/brev-cli/pkg/terminal" @@ -8,6 +9,7 @@ import ( ) type CompletionStore interface { + auth.APIKeyAuthStore GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error) GetActiveOrganizationOrDefault() (*entity.Organization, error) GetCurrentUser() (*entity.User, error) @@ -18,12 +20,6 @@ type CompletionHandler func(cmd *cobra.Command, args []string, toComplete string func GetAllWorkspaceNameCompletionHandler(completionStore CompletionStore, t *terminal.Terminal) CompletionHandler { return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - user, err := completionStore.GetCurrentUser() - if err != nil { - t.Errprint(err, "") - return nil, cobra.ShellCompDirectiveError - } - org, err := completionStore.GetActiveOrganizationOrDefault() if err != nil { t.Errprint(err, "") @@ -33,7 +29,17 @@ func GetAllWorkspaceNameCompletionHandler(completionStore CompletionStore, t *te return []string{}, cobra.ShellCompDirectiveDefault } - workspaces, err := completionStore.GetWorkspaces(org.ID, &store.GetWorkspacesOptions{UserID: user.ID}) + var options *store.GetWorkspacesOptions + if !auth.IsAPIKeyAuthStore(completionStore) { + user, err := completionStore.GetCurrentUser() + if err != nil { + t.Errprint(err, "") + return nil, cobra.ShellCompDirectiveError + } + options = &store.GetWorkspacesOptions{UserID: user.ID} + } + + workspaces, err := completionStore.GetWorkspaces(org.ID, options) if err != nil { t.Errprint(err, "") return nil, cobra.ShellCompDirectiveError @@ -50,6 +56,10 @@ func GetAllWorkspaceNameCompletionHandler(completionStore CompletionStore, t *te func GetOrgsNameCompletionHandler(completionStore CompletionStore, t *terminal.Terminal) CompletionHandler { return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + if auth.IsAPIKeyAuthStore(completionStore) { + return []string{}, cobra.ShellCompDirectiveNoFileComp + } + orgs, err := completionStore.GetOrganizations(nil) if err != nil { t.Errprint(err, "") diff --git a/pkg/cmd/delete/delete.go b/pkg/cmd/delete/delete.go index 570320de2..4e56d8fcd 100644 --- a/pkg/cmd/delete/delete.go +++ b/pkg/cmd/delete/delete.go @@ -6,6 +6,7 @@ import ( "os" "strings" + "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/util" "github.com/brevdev/brev-cli/pkg/entity" @@ -99,6 +100,9 @@ func deleteWorkspace(workspaceName string, t *terminal.Terminal, deleteStore Del func handleAdminUser(err error, deleteStore DeleteStore, piped bool) error { if strings.Contains(err.Error(), "not found") { + if auth.IsAPIKeyAuthStore(deleteStore) { + return breverrors.WrapAndTrace(err) + } user, err1 := deleteStore.GetCurrentUser() if err1 != nil { return breverrors.WrapAndTrace(err1) diff --git a/pkg/cmd/gpucreate/gpucreate.go b/pkg/cmd/gpucreate/gpucreate.go index 5798c880f..9950fe86a 100644 --- a/pkg/cmd/gpucreate/gpucreate.go +++ b/pkg/cmd/gpucreate/gpucreate.go @@ -14,6 +14,7 @@ import ( "time" "unicode" + "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" "github.com/brevdev/brev-cli/pkg/cmd/util" "github.com/brevdev/brev-cli/pkg/config" @@ -97,6 +98,7 @@ type GPUCreateStore interface { util.GetWorkspaceByNameOrIDErrStore gpusearch.GPUSearchStore GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetAuthTokens() (*entity.AuthTokens, error) GetCurrentUser() (*entity.User, error) GetWorkspace(workspaceID string) (*entity.Workspace, error) CreateWorkspace(organizationID string, options *store.CreateWorkspacesOptions) (*entity.Workspace, error) @@ -720,10 +722,13 @@ func newCreateContext(t *terminal.Terminal, store GPUCreateStore, opts GPUCreate } } - // Get user - user, err := store.GetCurrentUser() - if err != nil { - return nil, breverrors.WrapAndTrace(err) + user := &entity.User{} + if !auth.IsAPIKeyAuthStore(store) { + var err error + user, err = store.GetCurrentUser() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } } ctx.user = user @@ -877,7 +882,11 @@ func (c *createContext) waitForInstances(workspaces []*entity.Workspace) { for _, ws := range workspaces { err := c.pollUntilReady(ws.ID) if err != nil { - c.logf(" %s: Timeout waiting for ready state\n", ws.Name) + if strings.Contains(err.Error(), "timeout waiting") { + c.logf(" %s: Timeout waiting for ready state\n", ws.Name) + } else { + c.logf(" %s: %s\n", ws.Name, c.colorize(err.Error(), c.t.Red)) + } } } } @@ -1220,6 +1229,9 @@ func (c *createContext) pollUntilReady(wsID string) error { } if ws.Status == entity.Failure { + if ws.StatusMessage != "" { + return breverrors.NewValidationError(fmt.Sprintf("instance %s failed: %s", ws.Name, ws.StatusMessage)) + } return breverrors.NewValidationError(fmt.Sprintf("instance %s failed", ws.Name)) } diff --git a/pkg/cmd/gpucreate/gpucreate_test.go b/pkg/cmd/gpucreate/gpucreate_test.go index 082c54565..be7919827 100644 --- a/pkg/cmd/gpucreate/gpucreate_test.go +++ b/pkg/cmd/gpucreate/gpucreate_test.go @@ -3,6 +3,7 @@ package gpucreate import ( "strings" "testing" + "time" "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" "github.com/brevdev/brev-cli/pkg/entity" @@ -44,6 +45,10 @@ func (m *MockGPUCreateStore) GetCurrentUser() (*entity.User, error) { return m.User, nil } +func (m *MockGPUCreateStore) GetAuthTokens() (*entity.AuthTokens, error) { + return nil, nil +} + func (m *MockGPUCreateStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { return m.Org, nil } @@ -641,3 +646,22 @@ func TestFormatInstanceSpecs(t *testing.T) { result := formatInstanceSpecs(specs) assert.Equal(t, "g5.xlarge (1000GB disk), p4d.24xlarge, g6.xlarge (500GB disk)", result) } + +func TestPollUntilReadyReportsWorkspaceFailureMessage(t *testing.T) { + store := NewMockGPUCreateStore() + store.Workspaces["ws-failed"] = &entity.Workspace{ + ID: "ws-failed", + Name: "test", + Status: entity.Failure, + StatusMessage: "unexpected end of JSON input", + } + + ctx := &createContext{ + store: store, + opts: GPUCreateOptions{Timeout: time.Second}, + } + + err := ctx.pollUntilReady("ws-failed") + + assert.ErrorContains(t, err, "instance test failed: unexpected end of JSON input") +} diff --git a/pkg/cmd/invite/invite.go b/pkg/cmd/invite/invite.go index 50e07df6a..901319031 100644 --- a/pkg/cmd/invite/invite.go +++ b/pkg/cmd/invite/invite.go @@ -17,12 +17,9 @@ import ( ) type InviteStore interface { - GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error) - GetActiveOrganizationOrDefault() (*entity.Organization, error) - GetCurrentUser() (*entity.User, error) + completions.CompletionStore GetUsers(queryParams map[string]string) ([]entity.User, error) GetWorkspace(workspaceID string) (*entity.Workspace, error) - GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) CreateInviteLink(organizationID string) (string, error) } diff --git a/pkg/cmd/login/login.go b/pkg/cmd/login/login.go index e96968cc9..51d044c42 100644 --- a/pkg/cmd/login/login.go +++ b/pkg/cmd/login/login.go @@ -33,6 +33,7 @@ type LoginStore interface { auth.AuthStore GetCurrentUser() (*entity.User, error) CreateUser(idToken string) (*entity.User, error) + SetDefaultOrganization(org *entity.Organization) error GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) GetActiveOrganizationOrDefault() (*entity.Organization, error) CreateOrganization(req store.CreateOrganizationRequest) (*entity.Organization, error) @@ -47,6 +48,7 @@ type LoginStore interface { type Auth interface { Login(skipBrowser bool) (*auth.LoginTokens, error) LoginWithToken(token string) error + LoginWithAPIKey(apiKey string, orgID string) error } // loginStore must be a no prompt store @@ -57,6 +59,8 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra. } var loginToken string + var apiKey string + var apiKeyOrgID string var skipBrowser bool var emailFlag string var authProviderFlag string @@ -70,7 +74,8 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra. Example: "brev login", Args: cmderrors.TransformToValidationError(cobra.NoArgs), RunE: func(cmd *cobra.Command, args []string) error { - err := opts.RunLogin(t, loginToken, skipBrowser, emailFlag, authProviderFlag) + apiKeyLogin := strings.TrimSpace(apiKey) != "" + err := opts.RunLogin(t, loginToken, apiKey, apiKeyOrgID, skipBrowser, emailFlag, authProviderFlag) if err != nil { // if err is ImportIDEConfigError, log err with sentry but continue if _, ok := err.(*importideconfig.ImportIDEConfigError); !ok { @@ -83,6 +88,9 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra. } return err //nolint:wrapcheck // we want to return the error from the login } + if apiKeyLogin { + return nil + } homeDir, homeErr := opts.LoginStore.UserHomeDir() if homeErr == nil && agentskill.IsAnyAgentInstalled(homeDir) && !agentskill.IsSkillInstalled(homeDir) { t.Vprintf("\nšŸ’” Detected an AI coding agent. Run %s to enable natural-language commands.\n", t.Yellow("brev agent-skill install")) @@ -94,6 +102,10 @@ func NewCmdLogin(t *terminal.Terminal, loginStore LoginStore, auth Auth) *cobra. }, } cmd.Flags().StringVarP(&loginToken, "token", "", "", "token provided to auto login") + cmd.Flags().StringVar(&apiKey, "api-key", "", "api key to authenticate CLI requests") + cmd.Flags().StringVar(&apiKeyOrgID, "org-id", "", "organization ID for API key auth") + _ = cmd.Flags().MarkHidden("api-key") + _ = cmd.Flags().MarkHidden("org-id") cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "print url instead of auto opening browser") cmd.Flags().StringVar(&emailFlag, "email", "", "email to use for authentication") cmd.Flags().StringVar(&authProviderFlag, "auth", "", "authentication provider to use (nvidia or legacy, default is nvidia)") @@ -145,7 +157,15 @@ func (o LoginOptions) getOrCreateOrg(username string) (*entity.Organization, err return org, nil } -func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrowser bool, emailFlag string, authProviderFlag string) error { +func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, apiKey string, apiKeyOrgID string, skipBrowser bool, emailFlag string, authProviderFlag string) error { + apiKey = strings.TrimSpace(apiKey) + if apiKey != "" { + return o.doApiKeyLogin(t, loginToken, apiKey, apiKeyOrgID, skipBrowser, emailFlag, authProviderFlag) + } + if strings.TrimSpace(apiKeyOrgID) != "" { + return breverrors.NewValidationError("org-id can only be used with api-key") + } + tokens, _ := o.LoginStore.GetAuthTokens() if authProviderFlag != "" && authProviderFlag != "nvidia" && authProviderFlag != "legacy" { @@ -188,6 +208,28 @@ func (o LoginOptions) RunLogin(t *terminal.Terminal, loginToken string, skipBrow return nil } +func (o LoginOptions) doApiKeyLogin(t *terminal.Terminal, loginToken string, apiKey string, apiKeyOrgID string, skipBrowser bool, emailFlag string, authProviderFlag string) error { + if loginToken != "" || skipBrowser || emailFlag != "" || authProviderFlag != "" { + return breverrors.NewValidationError("api-key cannot be used with token, skip-browser, email, or auth flags") + } + apiKey = strings.TrimSpace(apiKey) + orgID := strings.TrimSpace(apiKeyOrgID) + if orgID == "" { + return breverrors.NewValidationError(auth.MissingAPIKeyOrgIDMessage) + } + if err := o.Auth.LoginWithAPIKey(apiKey, orgID); err != nil { + return breverrors.WrapAndTrace(err) + } + if err := o.LoginStore.SetDefaultOrganization(&entity.Organization{ + ID: orgID, + Name: orgID, + }); err != nil { + return breverrors.WrapAndTrace(err) + } + t.Vprint(t.Green(fmt.Sprintf("API key saved for org %s", orgID))) + return nil +} + func (o LoginOptions) handleOnboarding(user *entity.User, _ *terminal.Terminal) error { // figure out if we should onboard the user currentOnboardingStatus, err := user.GetOnboardingData() diff --git a/pkg/cmd/login/login_test.go b/pkg/cmd/login/login_test.go index 43003eea9..46fcec712 100644 --- a/pkg/cmd/login/login_test.go +++ b/pkg/cmd/login/login_test.go @@ -1 +1,233 @@ package login + +import ( + "bytes" + "testing" + + authpkg "github.com/brevdev/brev-cli/pkg/auth" + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/store" + "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testAPIKey = authpkg.BrevAPIKeyPrefix + "test-key" + +type mockLoginAuth struct { + apiKeyCalls int + apiKey string + apiKeyOrgID string + tokenCalls int + loginCalls int +} + +func (m *mockLoginAuth) Login(_ bool) (*authpkg.LoginTokens, error) { + m.loginCalls++ + return &authpkg.LoginTokens{}, nil +} + +func (m *mockLoginAuth) LoginWithToken(_ string) error { + m.tokenCalls++ + return nil +} + +func (m *mockLoginAuth) LoginWithAPIKey(apiKey string, orgID string) error { + m.apiKeyCalls++ + m.apiKey = apiKey + m.apiKeyOrgID = orgID + return nil +} + +type mockLoginStore struct { + getCurrentUserCalls int + createUserCalls int + getOrCreateOrgCalls int + createOrganizationCalls int + setDefaultOrgCalls int + updateUserCalls int + userHomeDirCalls int + defaultOrg *entity.Organization +} + +func (m *mockLoginStore) SaveAuthTokens(_ entity.AuthTokens) error { return nil } +func (m *mockLoginStore) GetAuthTokens() (*entity.AuthTokens, error) { return nil, nil } +func (m *mockLoginStore) DeleteAuthTokens() error { return nil } + +func (m *mockLoginStore) GetCurrentUser() (*entity.User, error) { + m.getCurrentUserCalls++ + return &entity.User{ID: "user-1", Username: "testuser", Name: "Test User", Email: "test@example.com"}, nil +} + +func (m *mockLoginStore) CreateUser(_ string) (*entity.User, error) { + m.createUserCalls++ + return &entity.User{}, nil +} + +func (m *mockLoginStore) SetDefaultOrganization(org *entity.Organization) error { + m.setDefaultOrgCalls++ + m.defaultOrg = org + return nil +} + +func (m *mockLoginStore) GetOrganizations(_ *store.GetOrganizationsOptions) ([]entity.Organization, error) { + return []entity.Organization{{ID: "org-1", Name: "org"}}, nil +} + +func (m *mockLoginStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + m.getOrCreateOrgCalls++ + return &entity.Organization{ID: "org-1", Name: "org"}, nil +} + +func (m *mockLoginStore) CreateOrganization(_ store.CreateOrganizationRequest) (*entity.Organization, error) { + m.createOrganizationCalls++ + return &entity.Organization{ID: "org-1", Name: "org"}, nil +} + +func (m *mockLoginStore) GetServerSockFile() string { return "" } + +func (m *mockLoginStore) GetWorkspaces(_ string, _ *store.GetWorkspacesOptions) ([]entity.Workspace, error) { + return nil, nil +} + +func (m *mockLoginStore) UpdateUser(_ string, updatedUser *entity.UpdateUser) (*entity.User, error) { + m.updateUserCalls++ + return &entity.User{ + ID: "user-1", + Username: updatedUser.Username, + Name: updatedUser.Name, + Email: updatedUser.Email, + OnboardingData: updatedUser.OnboardingData, + }, nil +} + +func (m *mockLoginStore) UserHomeDir() (string, error) { + m.userHomeDirCalls++ + return "/home/testuser", nil +} + +func (m *mockLoginStore) GetAllWorkspaces(_ *store.GetWorkspacesOptions) ([]entity.Workspace, error) { + return nil, nil +} +func (m *mockLoginStore) GetCurrentWorkspaceID() (string, error) { return "", nil } +func (m *mockLoginStore) GetWindowsDir() (string, error) { return "", nil } + +func TestRunLoginWithAPIKey_SavesKeyAndOrgWithoutUserOrBackendOrgCalls(t *testing.T) { + auth := &mockLoginAuth{} + loginStore := &mockLoginStore{} + opts := LoginOptions{Auth: auth, LoginStore: loginStore} + + err := opts.RunLogin(terminal.New(), "", " "+testAPIKey+" ", " org-test ", false, "", "") + + require.NoError(t, err) + assert.Equal(t, 1, auth.apiKeyCalls) + assert.Equal(t, testAPIKey, auth.apiKey) + assert.Equal(t, "org-test", auth.apiKeyOrgID) + assert.Equal(t, 1, loginStore.setDefaultOrgCalls) + require.NotNil(t, loginStore.defaultOrg) + assert.Equal(t, "org-test", loginStore.defaultOrg.ID) + assert.Equal(t, "org-test", loginStore.defaultOrg.Name) + assert.Equal(t, 0, auth.tokenCalls) + assert.Equal(t, 0, auth.loginCalls) + assert.Equal(t, 0, loginStore.getCurrentUserCalls) + assert.Equal(t, 0, loginStore.createUserCalls) + assert.Equal(t, 0, loginStore.getOrCreateOrgCalls) + assert.Equal(t, 0, loginStore.createOrganizationCalls) + assert.Equal(t, 0, loginStore.updateUserCalls) +} + +func TestRunLoginWithAPIKey_RejectsConflictingFlags(t *testing.T) { + tests := []struct { + name string + loginToken string + skipBrowser bool + emailFlag string + authProviderFlag string + }{ + {name: "token", loginToken: "token"}, + {name: "skip browser", skipBrowser: true}, + {name: "email", emailFlag: "user@example.com"}, + {name: "auth provider", authProviderFlag: "nvidia"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := &mockLoginAuth{} + opts := LoginOptions{Auth: auth, LoginStore: &mockLoginStore{}} + + err := opts.RunLogin(terminal.New(), tt.loginToken, testAPIKey, "org-test", tt.skipBrowser, tt.emailFlag, tt.authProviderFlag) + + require.Error(t, err) + assert.Equal(t, 0, auth.apiKeyCalls) + }) + } +} + +func TestNewCmdLoginWithAPIKey_SkipsPostLoginHooks(t *testing.T) { + auth := &mockLoginAuth{} + loginStore := &mockLoginStore{} + cmd := NewCmdLogin(terminal.New(), loginStore, auth) + cmd.SetOut(&bytes.Buffer{}) + cmd.SetErr(&bytes.Buffer{}) + cmd.SetArgs([]string{"--api-key", testAPIKey, "--org-id", "org-test"}) + + err := cmd.Execute() + + require.NoError(t, err) + assert.Equal(t, 1, auth.apiKeyCalls) + assert.Equal(t, 1, loginStore.setDefaultOrgCalls) + assert.Equal(t, 0, loginStore.getCurrentUserCalls) + assert.Equal(t, 0, loginStore.updateUserCalls) + assert.Equal(t, 0, loginStore.userHomeDirCalls) +} + +func TestNewCmdLogin_HidesAPIKeyFlagsFromHelp(t *testing.T) { + cmd := NewCmdLogin(terminal.New(), &mockLoginStore{}, &mockLoginAuth{}) + var out bytes.Buffer + cmd.SetOut(&out) + cmd.SetErr(&bytes.Buffer{}) + cmd.SetArgs([]string{"--help"}) + + err := cmd.Execute() + + require.NoError(t, err) + assert.NotContains(t, out.String(), "--api-key") + assert.NotContains(t, out.String(), "--org-id") +} + +func TestRunLoginWithAPIKey_RejectsMissingOrgID(t *testing.T) { + tests := []struct { + name string + apiKey string + orgID string + }{ + {name: "missing org id", apiKey: testAPIKey, orgID: " "}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auth := &mockLoginAuth{} + loginStore := &mockLoginStore{} + opts := LoginOptions{Auth: auth, LoginStore: loginStore} + + err := opts.RunLogin(terminal.New(), "", tt.apiKey, tt.orgID, false, "", "") + + require.Error(t, err) + assert.Equal(t, 0, auth.apiKeyCalls) + assert.Equal(t, 0, loginStore.setDefaultOrgCalls) + }) + } +} + +func TestRunLoginWithOrgIDWithoutAPIKeyRejects(t *testing.T) { + auth := &mockLoginAuth{} + loginStore := &mockLoginStore{} + opts := LoginOptions{Auth: auth, LoginStore: loginStore} + + err := opts.RunLogin(terminal.New(), "", "", "org-test", false, "", "") + + require.Error(t, err) + assert.Contains(t, err.Error(), "org-id can only be used with api-key") + assert.Equal(t, 0, auth.apiKeyCalls) + assert.Equal(t, 0, loginStore.setDefaultOrgCalls) +} diff --git a/pkg/cmd/ls/ls.go b/pkg/cmd/ls/ls.go index 4d3929b1c..532595770 100644 --- a/pkg/cmd/ls/ls.go +++ b/pkg/cmd/ls/ls.go @@ -12,6 +12,7 @@ import ( "connectrpc.com/connect" "github.com/brevdev/brev-cli/pkg/analytics" + "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/externalnode" "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" @@ -41,7 +42,8 @@ type LsStore interface { GetUsers(queryParams map[string]string) ([]entity.User, error) GetWorkspace(workspaceID string) (*entity.Workspace, error) GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) - GetAccessToken() (string, error) + externalnode.TokenProvider + GetAuthTokens() (*entity.AuthTokens, error) GetInstanceTypes(includeCPU bool) (*gpusearch.InstanceTypesResponse, error) hello.HelloStore } @@ -113,9 +115,11 @@ with other commands like stop, start, or delete.`, // trackLsAnalytics sends analytics event for ls command func trackLsAnalytics(store LsStore) { userID := "" - user, err := store.GetCurrentUser() - if err == nil { - userID = user.ID + if !isAPIKeyAuthStore(store) { + user, err := store.GetCurrentUser() + if err == nil { + userID = user.ID + } } data := analytics.EventData{ EventName: "Brev ls", @@ -124,8 +128,26 @@ func trackLsAnalytics(store LsStore) { _ = analytics.TrackEvent(data) } -func getOrgForRunLs(lsStore LsStore, orgflag string) (*entity.Organization, error) { +func isAPIKeyAuthStore(lsStore LsStore) bool { + return auth.IsAPIKeyAuthStore(lsStore) +} + +func getOrgForRunLs(lsStore LsStore, orgflag string, apiKeyAuth bool) (*entity.Organization, error) { var org *entity.Organization + if apiKeyAuth { + if orgflag != "" { + return nil, breverrors.NewValidationError("api key auth is scoped to the org saved during login; --org is not supported") + } + org, err := lsStore.GetActiveOrganizationOrDefault() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if org == nil { + return nil, breverrors.NewValidationError("no orgs exist") + } + return org, nil + } + if orgflag != "" { var orgs []entity.Organization orgs, err := lsStore.GetOrganizations(&store.GetOrganizationsOptions{Name: orgflag}) @@ -155,12 +177,18 @@ func getOrgForRunLs(lsStore LsStore, orgflag string) (*entity.Organization, erro func RunLs(t *terminal.Terminal, lsStore LsStore, args []string, orgflag string, showAll bool, jsonOutput bool) error { ls := NewLs(lsStore, t, jsonOutput) - user, err := lsStore.GetCurrentUser() - if err != nil { - return breverrors.WrapAndTrace(err) + apiKeyAuth := isAPIKeyAuthStore(lsStore) + + var user *entity.User + if !apiKeyAuth { + var err error + user, err = lsStore.GetCurrentUser() + if err != nil { + return breverrors.WrapAndTrace(err) + } } - org, err := getOrgForRunLs(lsStore, orgflag) + org, err := getOrgForRunLs(lsStore, orgflag, apiKeyAuth) if err != nil { return breverrors.WrapAndTrace(err) } @@ -169,7 +197,7 @@ func RunLs(t *terminal.Terminal, lsStore LsStore, args []string, orgflag string, } if len(args) == 1 { //nolint:gocritic // don't want to switch - err = handleLsArg(ls, args[0], user, org, showAll) + err = handleLsArg(ls, args[0], user, org, showAll, apiKeyAuth) if err != nil { return breverrors.WrapAndTrace(err) } @@ -185,45 +213,77 @@ func RunLs(t *terminal.Terminal, lsStore LsStore, args []string, orgflag string, return nil } -func handleLsArg(ls *Ls, arg string, user *entity.User, org *entity.Organization, showAll bool) error { - // todo refactor this to cmd.register - //nolint:gocritic // idk how to write this as a switch - if util.IsSingularOrPlural(arg, "org") || util.IsSingularOrPlural(arg, "organization") { // handle org, orgs, and organization(s) - err := ls.RunOrgs() - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil - } else if util.IsSingularOrPlural(arg, "workspace") { - err := ls.RunWorkspaces(org, user, showAll) - if err != nil { - return breverrors.WrapAndTrace(err) - } - } else if util.IsSingularOrPlural(arg, "user") && featureflag.IsAdmin(user.GlobalUserType) { - err := ls.RunUser(showAll) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil - } else if util.IsSingularOrPlural(arg, "host") && featureflag.IsAdmin(user.GlobalUserType) { - err := ls.RunHosts(org) - if err != nil { - return breverrors.WrapAndTrace(err) - } - return nil - } else if util.IsSingularOrPlural(arg, "node") { - err := ls.RunNodes(org) - if err != nil { - return breverrors.WrapAndTrace(err) +func handleLsArg(ls *Ls, arg string, user *entity.User, org *entity.Organization, showAll bool, apiKeyAuth bool) error { + switch classifyLsArg(arg) { + case lsArgOrgs: + if apiKeyAuth { + return breverrors.NewValidationError("api key auth cannot list organizations") } + return wrapLsRun(ls.RunOrgs()) + case lsArgWorkspaces: + return wrapLsRun(ls.RunWorkspaces(org, user, showAll)) + case lsArgUsers: + return runAdminLsArg(user, apiKeyAuth, "users", func() error { + return ls.RunUser(showAll) + }) + case lsArgHosts: + return runAdminLsArg(user, apiKeyAuth, "hosts", func() error { + return ls.RunHosts(org) + }) + case lsArgNodes: + return wrapLsRun(ls.RunNodes(org)) + case lsArgInstances: + return wrapLsRun(ls.RunInstances(org, user, showAll)) + default: return nil - } else if util.IsSingularOrPlural(arg, "instance") { - err := ls.RunInstances(org, user, showAll) - if err != nil { - return breverrors.WrapAndTrace(err) - } + } +} + +type lsArgKind int + +const ( + lsArgUnknown lsArgKind = iota + lsArgOrgs + lsArgWorkspaces + lsArgUsers + lsArgHosts + lsArgNodes + lsArgInstances +) + +func classifyLsArg(arg string) lsArgKind { + switch { + case util.IsSingularOrPlural(arg, "org") || util.IsSingularOrPlural(arg, "organization"): + return lsArgOrgs + case util.IsSingularOrPlural(arg, "workspace"): + return lsArgWorkspaces + case util.IsSingularOrPlural(arg, "user"): + return lsArgUsers + case util.IsSingularOrPlural(arg, "host"): + return lsArgHosts + case util.IsSingularOrPlural(arg, "node"): + return lsArgNodes + case util.IsSingularOrPlural(arg, "instance"): + return lsArgInstances + default: + return lsArgUnknown + } +} + +func runAdminLsArg(user *entity.User, apiKeyAuth bool, resource string, run func() error) error { + if apiKeyAuth { + return breverrors.NewValidationError(fmt.Sprintf("api key auth cannot list %s", resource)) + } + if !featureflag.IsAdmin(user.GlobalUserType) { return nil } + return wrapLsRun(run()) +} + +func wrapLsRun(err error) error { + if err != nil { + return breverrors.WrapAndTrace(err) + } return nil } @@ -355,6 +415,19 @@ func (ls Ls) ShowUserWorkspaces(org *entity.Organization, otherOrgs []entity.Org ls.displayWorkspacesAndHelp(org, otherOrgs, userWorkspaces, allWorkspaces, gpuLookup) } +func (ls Ls) ShowOrgWorkspaces(org *entity.Organization, workspaces []entity.Workspace, gpuLookup map[string]string) { + if len(workspaces) == 0 { + ls.terminal.Vprint(ls.terminal.Yellow("No instances in org %s\n", org.Name)) + return + } + ls.terminal.Vprintf("Org %s has %d instances\n", ls.terminal.Yellow(org.Name), len(workspaces)) + displayWorkspacesTable(ls.terminal, workspaces, gpuLookup) + + fmt.Print("\n") + + displayLsResetBreadCrumb(ls.terminal, workspaces) +} + func (ls Ls) displayWorkspacesAndHelp(org *entity.Organization, otherOrgs []entity.Organization, userWorkspaces []entity.Workspace, allWorkspaces []entity.Workspace, gpuLookup map[string]string) { if len(userWorkspaces) == 0 { ls.terminal.Vprint(ls.terminal.Yellow("No instances in org %s\n", org.Name)) @@ -451,9 +524,12 @@ func (ls Ls) RunWorkspaces(org *entity.Organization, user *entity.User, showAll // Determine which workspaces to show var workspacesToShow []entity.Workspace - if showAll { + switch { + case showAll: workspacesToShow = allWorkspaces - } else { + case user == nil: + workspacesToShow = allWorkspaces + default: workspacesToShow = store.FilterForUserWorkspaces(allWorkspaces, user.ID) } @@ -462,6 +538,11 @@ func (ls Ls) RunWorkspaces(org *entity.Organization, user *entity.User, showAll return ls.outputWorkspacesJSON(workspacesToShow, gpuLookup, nodes) } + if user == nil { + ls.ShowOrgWorkspaces(org, workspacesToShow, gpuLookup) + return nil + } + // Table output with colors and help text orgs, err := ls.lsStore.GetOrganizations(nil) if err != nil { diff --git a/pkg/cmd/ls/ls_test.go b/pkg/cmd/ls/ls_test.go index e7a4f591c..4613cb360 100644 --- a/pkg/cmd/ls/ls_test.go +++ b/pkg/cmd/ls/ls_test.go @@ -8,36 +8,64 @@ import ( nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + authpkg "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/cmd/gpusearch" "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/store" "github.com/brevdev/brev-cli/pkg/terminal" ) +const testAPIKey = authpkg.BrevAPIKeyPrefix + "test-key" + // mockLsStore implements LsStore (including the embedded hello.HelloStore) for // testing the ls command routing without real API calls. type mockLsStore struct { - user *entity.User - org *entity.Organization - orgs []entity.Organization - workspaces []entity.Workspace + user *entity.User + org *entity.Organization + orgs []entity.Organization + workspaces []entity.Workspace + authTokens *entity.AuthTokens + workspaceOrgID string + currentUserCalls int + getOrganizationsCall int +} + +func (m *mockLsStore) GetCurrentUser() (*entity.User, error) { + m.currentUserCalls++ + return m.user, nil +} + +func (m *mockLsStore) GetAuthTokens() (*entity.AuthTokens, error) { + return m.authTokens, nil +} + +func (m *mockLsStore) GetAccessToken() (string, error) { + return "tok", nil } -func (m *mockLsStore) GetCurrentUser() (*entity.User, error) { return m.user, nil } -func (m *mockLsStore) GetAccessToken() (string, error) { return "tok", nil } func (m *mockLsStore) GetWorkspace(_ string) (*entity.Workspace, error) { return nil, nil } func (m *mockLsStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + if m.authTokens != nil && authpkg.IsBrevAPIKey(m.authTokens.APIKey) { + orgID, err := authpkg.GetAPIKeyOrgID(m) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return &entity.Organization{ID: orgID, Name: m.org.Name}, nil + } return m.org, nil } -func (m *mockLsStore) GetWorkspaces(_ string, _ *store.GetWorkspacesOptions) ([]entity.Workspace, error) { +func (m *mockLsStore) GetWorkspaces(orgID string, _ *store.GetWorkspacesOptions) ([]entity.Workspace, error) { + m.workspaceOrgID = orgID return m.workspaces, nil } func (m *mockLsStore) GetOrganizations(_ *store.GetOrganizationsOptions) ([]entity.Organization, error) { + m.getOrganizationsCall++ return m.orgs, nil } @@ -69,6 +97,113 @@ func newTestStore() *mockLsStore { } } +func TestRunLs_APIKeyJSONSkipsUserAndOrgList(t *testing.T) { + s := newTestStore() + s.authTokens = &entity.AuthTokens{APIKey: testAPIKey, APIKeyOrgID: "org1"} + s.workspaces = []entity.Workspace{ + { + ID: "ws1", + Name: "owned-by-someone", + Status: entity.Running, + CreatedByUserID: "other-user", + }, + { + ID: "ws2", + Name: "owned-by-user", + Status: entity.Stopped, + CreatedByUserID: "u1", + }, + } + term := terminal.New() + + out := captureStdout(t, func() { + err := RunLs(term, s, nil, "", false, true) + if err != nil { + t.Fatalf("RunLs returned error: %v", err) + } + }) + + var parsed struct { + Workspaces []WorkspaceInfo `json:"workspaces"` + } + if err := json.Unmarshal([]byte(out), &parsed); err != nil { + t.Fatalf("failed to parse JSON output: %v\nraw output: %s", err, out) + } + if len(parsed.Workspaces) != 2 { + t.Fatalf("expected API key ls to show all org workspaces, got %d", len(parsed.Workspaces)) + } + if s.currentUserCalls != 0 { + t.Fatalf("expected no GetCurrentUser calls, got %d", s.currentUserCalls) + } + if s.getOrganizationsCall != 0 { + t.Fatalf("expected no GetOrganizations calls, got %d", s.getOrganizationsCall) + } +} + +func TestRunLs_APIKeyUsesCredentialOrgNotCachedActiveOrg(t *testing.T) { + s := newTestStore() + s.authTokens = &entity.AuthTokens{APIKey: testAPIKey, APIKeyOrgID: "org-login"} + s.org = &entity.Organization{ID: "org-set", Name: "set-org"} + s.workspaces = []entity.Workspace{ + { + ID: "ws1", + Name: "api-key-instance", + Status: entity.Running, + CreatedByUserID: "other-user", + }, + } + term := terminal.New() + + out := captureStdout(t, func() { + err := RunLs(term, s, nil, "", false, true) + if err != nil { + t.Fatalf("RunLs returned error: %v", err) + } + }) + + if !strings.Contains(out, "api-key-instance") { + t.Fatalf("expected workspace output, got %s", out) + } + if s.workspaceOrgID != "org-login" { + t.Fatalf("expected API key credential org org-login, got %s", s.workspaceOrgID) + } +} + +func TestGetOrgForRunLs_APIKeyUsesActiveOrgDisplayName(t *testing.T) { + s := newTestStore() + s.authTokens = &entity.AuthTokens{APIKey: testAPIKey, APIKeyOrgID: "org-login"} + s.org = &entity.Organization{ID: "org-login", Name: "friendly-org"} + + org, err := getOrgForRunLs(s, "", true) + if err != nil { + t.Fatalf("getOrgForRunLs returned error: %v", err) + } + if org.ID != "org-login" { + t.Fatalf("expected org-login, got %s", org.ID) + } + if org.Name != "friendly-org" { + t.Fatalf("expected friendly org name, got %s", org.Name) + } +} + +func TestRunLs_APIKeyRequiresCredentialOrg(t *testing.T) { + s := newTestStore() + s.authTokens = &entity.AuthTokens{APIKey: testAPIKey} + term := terminal.New() + + err := RunLs(term, s, nil, "", false, true) + + if err == nil { + t.Fatal("expected missing API key org error, got nil") + } + if !strings.Contains(err.Error(), "api key auth requires an org id") { + t.Fatalf("expected API key org validation error, got %v", err) + } + if s.workspaceOrgID != "" { + t.Fatalf("expected no workspace call, got org %s", s.workspaceOrgID) + } +} + // captureStdout runs fn while capturing stdout and returns the output. func captureStdout(t *testing.T, fn func()) string { t.Helper() @@ -413,7 +548,7 @@ func TestHandleLsArg_Routing(t *testing.T) { "workspace", "workspaces", } for _, arg := range successArgs { - if err := handleLsArg(ls, arg, s.user, s.org, false); err != nil { + if err := handleLsArg(ls, arg, s.user, s.org, false, false); err != nil { t.Errorf("handleLsArg(%q) returned unexpected error: %v", arg, err) } } @@ -421,7 +556,7 @@ func TestHandleLsArg_Routing(t *testing.T) { // "node"/"nodes" route to RunNodes which calls the gRPC client — verify // it attempts the path (error expected due to no real client). for _, arg := range []string{"node", "nodes"} { - _ = handleLsArg(ls, arg, s.user, s.org, false) + _ = handleLsArg(ls, arg, s.user, s.org, false, false) } } diff --git a/pkg/cmd/set/set.go b/pkg/cmd/set/set.go index f405b8bbf..0de3ac4f3 100644 --- a/pkg/cmd/set/set.go +++ b/pkg/cmd/set/set.go @@ -4,6 +4,7 @@ package set import ( "fmt" + "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmdcontext" @@ -21,6 +22,7 @@ type SetStore interface { GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) GetServerSockFile() string GetCurrentWorkspaceID() (string, error) + GetAuthTokens() (*entity.AuthTokens, error) } func NewCmdSet(t *terminal.Terminal, loginSetStore SetStore, noLoginSetStore SetStore) *cobra.Command { @@ -59,6 +61,9 @@ func set(orgName string, setStore SetStore) error { if workspaceID != "" { return fmt.Errorf("can not set orgs in a workspace") } + if auth.IsAPIKeyAuthStore(setStore) { + return breverrors.NewValidationError("api key auth is scoped to the org saved during login; run brev login --api-key --org-id to change it") + } orgs, err := setStore.GetOrganizations(&store.GetOrganizationsOptions{Name: orgName}) if err != nil { return breverrors.WrapAndTrace(err) diff --git a/pkg/cmd/set/set_test.go b/pkg/cmd/set/set_test.go index 05e26e380..94fd0d4a9 100644 --- a/pkg/cmd/set/set_test.go +++ b/pkg/cmd/set/set_test.go @@ -1 +1,78 @@ package set + +import ( + "strings" + "testing" + + authpkg "github.com/brevdev/brev-cli/pkg/auth" + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/store" +) + +const testAPIKey = authpkg.BrevAPIKeyPrefix + "test-key" + +type mockSetStore struct { + authTokens *entity.AuthTokens + workspaceID string + orgs []entity.Organization + getOrganizations int + setDefaultOrgCalls int + defaultOrganization *entity.Organization +} + +func (m *mockSetStore) GetWorkspaces(_ string, _ *store.GetWorkspacesOptions) ([]entity.Workspace, error) { + return nil, nil +} + +func (m *mockSetStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + return nil, nil +} + +func (m *mockSetStore) GetCurrentUser() (*entity.User, error) { + return nil, nil +} + +func (m *mockSetStore) SetDefaultOrganization(org *entity.Organization) error { + m.setDefaultOrgCalls++ + m.defaultOrganization = org + return nil +} + +func (m *mockSetStore) GetOrganizations(_ *store.GetOrganizationsOptions) ([]entity.Organization, error) { + m.getOrganizations++ + return m.orgs, nil +} + +func (m *mockSetStore) GetServerSockFile() string { + return "" +} + +func (m *mockSetStore) GetCurrentWorkspaceID() (string, error) { + return m.workspaceID, nil +} + +func (m *mockSetStore) GetAuthTokens() (*entity.AuthTokens, error) { + return m.authTokens, nil +} + +func TestSetRejectsAPIKeyAuth(t *testing.T) { + s := &mockSetStore{ + authTokens: &entity.AuthTokens{APIKey: testAPIKey, APIKeyOrgID: "org-test"}, + orgs: []entity.Organization{{ID: "org-other", Name: "other-org"}}, + } + + err := set("other-org", s) + + if err == nil { + t.Fatal("expected API key auth set error, got nil") + } + if !strings.Contains(err.Error(), "api key auth is scoped") { + t.Fatalf("expected API key auth validation error, got %v", err) + } + if s.getOrganizations != 0 { + t.Fatalf("expected set to skip org lookup, got %d calls", s.getOrganizations) + } + if s.setDefaultOrgCalls != 0 { + t.Fatalf("expected set to skip default org write, got %d calls", s.setDefaultOrgCalls) + } +} diff --git a/pkg/cmd/start/start.go b/pkg/cmd/start/start.go index 708f2bffc..1110a160d 100644 --- a/pkg/cmd/start/start.go +++ b/pkg/cmd/start/start.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/cmd/completions" cmdutil "github.com/brevdev/brev-cli/pkg/cmd/util" "github.com/brevdev/brev-cli/pkg/config" @@ -36,6 +37,7 @@ var ( type StartStore interface { cmdutil.GetWorkspaceByNameOrIDErrStore + GetAuthTokens() (*entity.AuthTokens, error) GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error) GetActiveOrganizationOrDefault() (*entity.Organization, error) GetCurrentUser() (*entity.User, error) @@ -119,12 +121,18 @@ type StartOptions struct { } func runStartWorkspace(t *terminal.Terminal, options StartOptions, startStore StartStore) error { - user, err := startStore.GetCurrentUser() - if err != nil { - return breverrors.WrapAndTrace(err) + apiKeyAuth := auth.IsAPIKeyAuthStore(startStore) + + var user *entity.User + if !apiKeyAuth { + var err error + user, err = startStore.GetCurrentUser() + if err != nil { + return breverrors.WrapAndTrace(err) + } } - didStart, err := maybeStartEmpty(t, user, options, startStore) + didStart, err := maybeStartEmpty(t, user, apiKeyAuth, options, startStore) if err != nil { return breverrors.WrapAndTrace(err) } @@ -132,7 +140,7 @@ func runStartWorkspace(t *terminal.Terminal, options StartOptions, startStore St return nil } - didStart, err = maybeStartFromGitURL(t, user, options, startStore) + didStart, err = maybeStartFromGitURL(t, user, apiKeyAuth, options, startStore) if err != nil { return breverrors.WrapAndTrace(err) } @@ -140,7 +148,7 @@ func runStartWorkspace(t *terminal.Terminal, options StartOptions, startStore St return nil } - didStart, err = maybeStartStoppedOrJoin(t, user, options, startStore) + didStart, err = maybeStartStoppedOrJoin(t, user, apiKeyAuth, options, startStore) if err != nil { return breverrors.WrapAndTrace(err) } @@ -148,7 +156,7 @@ func runStartWorkspace(t *terminal.Terminal, options StartOptions, startStore St return nil } - didStart, err = maybeStartWithLocalPath(options, user, t, startStore) + didStart, err = maybeStartWithLocalPath(options, user, apiKeyAuth, t, startStore) if err != nil { return breverrors.WrapAndTrace(err) } @@ -159,9 +167,9 @@ func runStartWorkspace(t *terminal.Terminal, options StartOptions, startStore St return nil } -func maybeStartWithLocalPath(options StartOptions, user *entity.User, t *terminal.Terminal, startStore StartStore) (bool, error) { +func maybeStartWithLocalPath(options StartOptions, user *entity.User, apiKeyAuth bool, t *terminal.Terminal, startStore StartStore) (bool, error) { if util.DoesPathExist(options.RepoOrPathOrNameOrID) { - err := startWorkspaceFromPath(user, t, options, startStore) + err := startWorkspaceFromPath(user, apiKeyAuth, t, options, startStore) if err != nil { return false, breverrors.WrapAndTrace(err) } @@ -172,7 +180,7 @@ func maybeStartWithLocalPath(options StartOptions, user *entity.User, t *termina return false, nil } -func maybeStartStoppedOrJoin(t *terminal.Terminal, user *entity.User, options StartOptions, startStore StartStore) (bool, error) { +func maybeStartStoppedOrJoin(t *terminal.Terminal, user *entity.User, apiKeyAuth bool, options StartOptions, startStore StartStore) (bool, error) { org, err := startStore.GetActiveOrganizationOrDefault() if err != nil { return false, breverrors.WrapAndTrace(err) @@ -181,7 +189,10 @@ func maybeStartStoppedOrJoin(t *terminal.Terminal, user *entity.User, options St if err != nil { return false, breverrors.WrapAndTrace(err) } - userWorkspaces := store.FilterForUserWorkspaces(workspaces, user.ID) + userWorkspaces := workspaces + if !apiKeyAuth { + userWorkspaces = store.FilterForUserWorkspaces(workspaces, user.ID) + } if len(userWorkspaces) > 0 { if len(userWorkspaces) > 1 { userWorkspaces = store.FilterNonFailedWorkspaces(userWorkspaces) @@ -203,7 +214,7 @@ func maybeStartStoppedOrJoin(t *terminal.Terminal, user *entity.User, options St } if len(workspaces) > 0 { - err = joinProjectWithNewWorkspace(t, workspaces[0], org.ID, startStore, user, options) + err = joinProjectWithNewWorkspace(t, workspaces[0], org.ID, startStore, user, apiKeyAuth, options) if err != nil { return false, breverrors.WrapAndTrace(err) } @@ -212,9 +223,9 @@ func maybeStartStoppedOrJoin(t *terminal.Terminal, user *entity.User, options St return false, nil } -func maybeStartFromGitURL(t *terminal.Terminal, user *entity.User, options StartOptions, startStore StartStore) (bool, error) { +func maybeStartFromGitURL(t *terminal.Terminal, user *entity.User, apiKeyAuth bool, options StartOptions, startStore StartStore) (bool, error) { if util.IsGitURL(options.RepoOrPathOrNameOrID) { // todo this is function is not complete, some cloneable urls are not identified - err := createNewWorkspaceFromGit(user, t, options.SetupScript, options, startStore) + err := createNewWorkspaceFromGit(user, apiKeyAuth, t, options.SetupScript, options, startStore) if err != nil { return true, breverrors.WrapAndTrace(err) } @@ -223,9 +234,9 @@ func maybeStartFromGitURL(t *terminal.Terminal, user *entity.User, options Start return false, nil } -func maybeStartEmpty(t *terminal.Terminal, user *entity.User, options StartOptions, startStore StartStore) (bool, error) { +func maybeStartEmpty(t *terminal.Terminal, user *entity.User, apiKeyAuth bool, options StartOptions, startStore StartStore) (bool, error) { if options.RepoOrPathOrNameOrID == "" { - err := createEmptyWorkspace(user, t, options, startStore) + err := createEmptyWorkspace(user, apiKeyAuth, t, options, startStore) if err != nil { return true, breverrors.WrapAndTrace(err) } @@ -234,7 +245,7 @@ func maybeStartEmpty(t *terminal.Terminal, user *entity.User, options StartOptio return false, nil } -func startWorkspaceFromPath(user *entity.User, t *terminal.Terminal, options StartOptions, startStore StartStore) error { +func startWorkspaceFromPath(user *entity.User, apiKeyAuth bool, t *terminal.Terminal, options StartOptions, startStore StartStore) error { pathExists := util.DoesPathExist(options.RepoOrPathOrNameOrID) if !pathExists { return fmt.Errorf("Path: %s does not exist", options.RepoOrPathOrNameOrID) @@ -275,7 +286,7 @@ func startWorkspaceFromPath(user *entity.User, t *terminal.Terminal, options Sta // logic wants it to be the directory path, so set it only before calling // createNewWorkspaceFromGit options.RepoOrPathOrNameOrID = gitURL - err := createNewWorkspaceFromGit(user, t, localSetupPath, options, startStore) + err := createNewWorkspaceFromGit(user, apiKeyAuth, t, localSetupPath, options, startStore) if err != nil { return breverrors.WrapAndTrace(err) } @@ -283,7 +294,7 @@ func startWorkspaceFromPath(user *entity.User, t *terminal.Terminal, options Sta return err } -func createEmptyWorkspace(user *entity.User, t *terminal.Terminal, options StartOptions, startStore StartStore) error { //nolint:funlen,gocyclo // TODO refactor +func createEmptyWorkspace(user *entity.User, apiKeyAuth bool, t *terminal.Terminal, options StartOptions, startStore StartStore) error { //nolint:funlen,gocyclo // TODO refactor // ensure name if len(options.Name) == 0 { return breverrors.NewValidationError("name field is required for empty workspaces") @@ -332,7 +343,7 @@ func createEmptyWorkspace(user *entity.User, t *terminal.Terminal, options Start cwOptions.WithClassID(options.WorkspaceClass) } - cwOptions = resolveWorkspaceUserOptions(cwOptions, user) + cwOptions = resolveWorkspaceUserOptions(cwOptions, user, apiKeyAuth) if len(setupScriptContents) > 0 { cwOptions.WithStartupScript(setupScriptContents) @@ -376,16 +387,17 @@ func createEmptyWorkspace(user *entity.User, t *terminal.Terminal, options Start } } -func resolveWorkspaceUserOptions(options *store.CreateWorkspacesOptions, user *entity.User) *store.CreateWorkspacesOptions { +func resolveWorkspaceUserOptions(options *store.CreateWorkspacesOptions, user *entity.User, apiKeyAuth bool) *store.CreateWorkspacesOptions { + isAdmin := !apiKeyAuth && user != nil && featureflag.IsAdmin(user.GlobalUserType) if options.WorkspaceTemplateID == "" { - if featureflag.IsAdmin(user.GlobalUserType) { + if isAdmin { options.WorkspaceTemplateID = store.DevWorkspaceTemplateID } else { options.WorkspaceTemplateID = store.UserWorkspaceTemplateID } } if options.WorkspaceClassID == "" { - if featureflag.IsAdmin(user.GlobalUserType) { + if isAdmin { options.WorkspaceClassID = store.DevWorkspaceClassID } else { options.WorkspaceClassID = store.UserWorkspaceClassID @@ -433,7 +445,7 @@ func startStopppedWorkspace(workspace *entity.Workspace, startStore StartStore, // "https://github.com/brevdev/microservices-demo.git // "https://github.com/brevdev/microservices-demo.git" // "git@github.com:brevdev/microservices-demo.git" -func joinProjectWithNewWorkspace(t *terminal.Terminal, templateWorkspace entity.Workspace, orgID string, startStore StartStore, user *entity.User, startOptions StartOptions) error { +func joinProjectWithNewWorkspace(t *terminal.Terminal, templateWorkspace entity.Workspace, orgID string, startStore StartStore, user *entity.User, apiKeyAuth bool, startOptions StartOptions) error { clusterID := config.GlobalConfig.GetDefaultClusterID() if startOptions.WorkspaceClass == "" { startOptions.WorkspaceClass = templateWorkspace.WorkspaceClassID @@ -446,7 +458,7 @@ func joinProjectWithNewWorkspace(t *terminal.Terminal, templateWorkspace entity. t.Vprintf("Name flag omitted, using auto generated name: %s\n", t.Green(cwOptions.Name)) } - cwOptions = resolveWorkspaceUserOptions(cwOptions, user) + cwOptions = resolveWorkspaceUserOptions(cwOptions, user, apiKeyAuth) t.Vprintf("Creating instance %s in org %s\n", t.Green(cwOptions.Name), t.Green(orgID)) t.Vprintf("\tname %s\n", cwOptions.Name) @@ -481,7 +493,7 @@ func IsURL(str string) bool { return err == nil && u.Scheme != "" && u.Host != "" } -func createNewWorkspaceFromGit(user *entity.User, t *terminal.Terminal, setupScriptURLOrPath string, startOptions StartOptions, startStore StartStore) error { +func createNewWorkspaceFromGit(user *entity.User, apiKeyAuth bool, t *terminal.Terminal, setupScriptURLOrPath string, startOptions StartOptions, startStore StartStore) error { // https://gist.githubusercontent.com/naderkhalil/4a45d4d293dc3a9eb330adcd5440e148/raw/3ab4889803080c3be94a7d141c7f53e286e81592/setup.sh // fetch contents of file // todo: read contents of file @@ -540,7 +552,7 @@ func createNewWorkspaceFromGit(user *entity.User, t *terminal.Terminal, setupScr orgID = orgs[0].ID } - err = createWorkspace(user, t, newWorkspace, orgID, startStore, startOptions) + err = createWorkspace(user, apiKeyAuth, t, newWorkspace, orgID, startStore, startOptions) if err != nil { return breverrors.WrapAndTrace(err) } @@ -591,7 +603,7 @@ func MakeNewWorkspaceFromURL(url string) NewWorkspace { } } -func createWorkspace(user *entity.User, t *terminal.Terminal, workspace NewWorkspace, orgID string, startStore StartStore, startOptions StartOptions) error { +func createWorkspace(user *entity.User, apiKeyAuth bool, t *terminal.Terminal, workspace NewWorkspace, orgID string, startStore StartStore, startOptions StartOptions) error { clusterID := config.GlobalConfig.GetDefaultClusterID() options := store.NewCreateWorkspacesOptions(clusterID, workspace.Name).WithGitRepo(workspace.GitRepo) @@ -600,7 +612,7 @@ func createWorkspace(user *entity.User, t *terminal.Terminal, workspace NewWorks options = options.WithWorkspaceClassID(startOptions.WorkspaceClass) } - options = resolveWorkspaceUserOptions(options, user) + options = resolveWorkspaceUserOptions(options, user, apiKeyAuth) if startOptions.SetupRepo != "" { options.WithCustomSetupRepo(startOptions.SetupRepo, startOptions.SetupPath) diff --git a/pkg/cmd/start/start_test.go b/pkg/cmd/start/start_test.go index 8ee68cab5..18abaeec0 100644 --- a/pkg/cmd/start/start_test.go +++ b/pkg/cmd/start/start_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/store" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/stretchr/testify/assert" ) @@ -60,3 +61,17 @@ func Test_DisplayBC(t *testing.T) { NetworkID: "", }) } + +func TestResolveWorkspaceUserOptions_APIKeyAuthUsesUserDefaultsWithoutUser(t *testing.T) { + got := resolveWorkspaceUserOptions(&store.CreateWorkspacesOptions{}, nil, true) + + assert.Equal(t, store.UserWorkspaceTemplateID, got.WorkspaceTemplateID) + assert.Equal(t, store.UserWorkspaceClassID, got.WorkspaceClassID) +} + +func TestResolveWorkspaceUserOptions_AdminUserUsesDevDefaults(t *testing.T) { + got := resolveWorkspaceUserOptions(&store.CreateWorkspacesOptions{}, &entity.User{GlobalUserType: entity.Admin}, false) + + assert.Equal(t, store.DevWorkspaceTemplateID, got.WorkspaceTemplateID) + assert.Equal(t, store.DevWorkspaceClassID, got.WorkspaceClassID) +} diff --git a/pkg/cmd/stop/stop.go b/pkg/cmd/stop/stop.go index bdf855f2a..0e07b5ed2 100644 --- a/pkg/cmd/stop/stop.go +++ b/pkg/cmd/stop/stop.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" + "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/util" "github.com/brevdev/brev-cli/pkg/entity" @@ -80,15 +81,21 @@ func NewCmdStop(t *terminal.Terminal, loginStopStore StopStore, noLoginStopStore } func stopAllWorkspaces(t *terminal.Terminal, stopStore StopStore, piped bool) error { - user, err := stopStore.GetCurrentUser() - if err != nil { - return breverrors.WrapAndTrace(err) - } org, err := stopStore.GetActiveOrganizationOrDefault() if err != nil { return breverrors.WrapAndTrace(err) } - workspaces, err := stopStore.GetWorkspaces(org.ID, &store.GetWorkspacesOptions{UserID: user.ID}) + + var workspaces []entity.Workspace + if auth.IsAPIKeyAuthStore(stopStore) { + workspaces, err = stopStore.GetWorkspaces(org.ID, nil) + } else { + user, userErr := stopStore.GetCurrentUser() + if userErr != nil { + return breverrors.WrapAndTrace(userErr) + } + workspaces, err = stopStore.GetWorkspaces(org.ID, &store.GetWorkspacesOptions{UserID: user.ID}) + } if err != nil { return breverrors.WrapAndTrace(err) } @@ -119,9 +126,16 @@ func stopAllWorkspaces(t *terminal.Terminal, stopStore StopStore, piped bool) er } func stopWorkspace(workspaceName string, t *terminal.Terminal, stopStore StopStore, piped bool) error { - user, err := stopStore.GetCurrentUser() - if err != nil { - return breverrors.WrapAndTrace(err) + user := &entity.User{} + apiKeyAuth := false + var err error + if auth.IsAPIKeyAuthStore(stopStore) { + apiKeyAuth = true + } else { + user, err = stopStore.GetCurrentUser() + if err != nil { + return breverrors.WrapAndTrace(err) + } } var workspaceID string @@ -141,7 +155,7 @@ func stopWorkspace(workspaceName string, t *terminal.Terminal, stopStore StopSto if !strings.Contains(err3.Error(), "not found") { return breverrors.WrapAndTrace(err3) } else { - if user.GlobalUserType == entity.Admin { + if !apiKeyAuth && user.GlobalUserType == entity.Admin { if !piped { fmt.Println("admin trying to stop any instance") } @@ -150,7 +164,7 @@ func stopWorkspace(workspaceName string, t *terminal.Terminal, stopStore StopSto return breverrors.WrapAndTrace(err) } } else { - return breverrors.WrapAndTrace(err) + return breverrors.WrapAndTrace(err3) } } } diff --git a/pkg/cmd/util/util.go b/pkg/cmd/util/util.go index 8d8413498..3ac0f5fb2 100644 --- a/pkg/cmd/util/util.go +++ b/pkg/cmd/util/util.go @@ -3,18 +3,32 @@ package util import ( "fmt" + "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/store" ) type GetWorkspaceByNameOrIDErrStore interface { + auth.APIKeyAuthStore GetActiveOrganizationOrDefault() (*entity.Organization, error) GetWorkspaceByNameOrID(orgID string, nameOrID string) ([]entity.Workspace, error) GetCurrentUser() (*entity.User, error) } func GetUserWorkspaceByNameOrIDErr(storeQ GetWorkspaceByNameOrIDErrStore, workspaceNameOrID string) (*entity.Workspace, error) { + if auth.IsAPIKeyAuthStore(storeQ) { + org, err := storeQ.GetActiveOrganizationOrDefault() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + workspaces, err := storeQ.GetWorkspaceByNameOrID(org.ID, workspaceNameOrID) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return selectWorkspaceByNameOrID(workspaces, workspaceNameOrID) + } + user, err := storeQ.GetCurrentUser() if err != nil { return nil, breverrors.WrapAndTrace(err) @@ -45,6 +59,10 @@ func GetAnyWorkspaceByIDOrNameInActiveOrgErr(storeQ GetWorkspaceByNameOrIDErrSto return nil, breverrors.WrapAndTrace(err) } + return selectWorkspaceByNameOrID(workspaces, workspaceNameOrID) +} + +func selectWorkspaceByNameOrID(workspaces []entity.Workspace, workspaceNameOrID string) (*entity.Workspace, error) { if len(workspaces) == 0 { return nil, breverrors.NewValidationError(fmt.Sprintf("instance with id/name %s not found", workspaceNameOrID)) } diff --git a/pkg/cmd/util/util_test.go b/pkg/cmd/util/util_test.go new file mode 100644 index 000000000..820c85021 --- /dev/null +++ b/pkg/cmd/util/util_test.go @@ -0,0 +1,70 @@ +package util + +import ( + "testing" + + "github.com/brevdev/brev-cli/pkg/auth" + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockWorkspaceLookupStore struct { + t *testing.T + tokens *entity.AuthTokens + org *entity.Organization + workspaces []entity.Workspace +} + +func (m *mockWorkspaceLookupStore) GetAuthTokens() (*entity.AuthTokens, error) { + return m.tokens, nil +} + +func (m *mockWorkspaceLookupStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + return m.org, nil +} + +func (m *mockWorkspaceLookupStore) GetWorkspaceByNameOrID(_ string, _ string) ([]entity.Workspace, error) { + return m.workspaces, nil +} + +func (m *mockWorkspaceLookupStore) GetCurrentUser() (*entity.User, error) { + m.t.Fatal("api-key workspace lookup must not fetch current user") + return nil, nil +} + +func TestGetUserWorkspaceByNameOrIDErr_APIKeyRejectsAmbiguousMatches(t *testing.T) { + store := &mockWorkspaceLookupStore{ + t: t, + tokens: &entity.AuthTokens{APIKey: auth.BrevAPIKeyPrefix + "test-key"}, + org: &entity.Organization{ID: "org-test"}, + workspaces: []entity.Workspace{ + {ID: "workspace-1", Name: "dev", Status: entity.Running}, + {ID: "workspace-2", Name: "dev", Status: entity.Stopped}, + }, + } + + workspace, err := GetUserWorkspaceByNameOrIDErr(store, "dev") + + require.Error(t, err) + assert.Nil(t, workspace) + assert.Contains(t, err.Error(), "multiple instances found with id/name dev") +} + +func TestGetUserWorkspaceByNameOrIDErr_APIKeyUsesOnlyNonFailedMatch(t *testing.T) { + store := &mockWorkspaceLookupStore{ + t: t, + tokens: &entity.AuthTokens{APIKey: auth.BrevAPIKeyPrefix + "test-key"}, + org: &entity.Organization{ID: "org-test"}, + workspaces: []entity.Workspace{ + {ID: "failed-workspace", Name: "dev", Status: entity.Failure}, + {ID: "running-workspace", Name: "dev", Status: entity.Running}, + }, + } + + workspace, err := GetUserWorkspaceByNameOrIDErr(store, "dev") + + require.NoError(t, err) + require.NotNil(t, workspace) + assert.Equal(t, "running-workspace", workspace.ID) +} diff --git a/pkg/entity/entity.go b/pkg/entity/entity.go index 868f1fee9..e6c163fa6 100644 --- a/pkg/entity/entity.go +++ b/pkg/entity/entity.go @@ -27,6 +27,8 @@ var LegacyWorkspaceGroups = map[string]bool{ type AuthTokens struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` + APIKey string `json:"api_key,omitempty"` + APIKeyOrgID string `json:"api_key_org_id,omitempty"` } type IDEConfig struct { diff --git a/pkg/store/authtoken.go b/pkg/store/authtoken.go index 584dd4c89..c4bc2036a 100644 --- a/pkg/store/authtoken.go +++ b/pkg/store/authtoken.go @@ -19,8 +19,8 @@ const ( ) func (f FileStore) SaveAuthTokens(token entity.AuthTokens) error { - if token.AccessToken == "" { - return fmt.Errorf("access token is empty") + if token.AccessToken == "" && token.APIKey == "" { + return fmt.Errorf("access token and api key are empty") } brevCredentialsFile, err := f.getBrevCredentialsFile() if err != nil { diff --git a/pkg/store/authtoken_test.go b/pkg/store/authtoken_test.go index 72440ea2a..c24e70359 100644 --- a/pkg/store/authtoken_test.go +++ b/pkg/store/authtoken_test.go @@ -1 +1,57 @@ package store + +import ( + "path/filepath" + "testing" + + authpkg "github.com/brevdev/brev-cli/pkg/auth" + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testAPIKey = authpkg.BrevAPIKeyPrefix + "test-key" + +func newAuthTokenTestStore(t *testing.T) (*FileStore, afero.Fs, string) { + t.Helper() + fs := afero.NewOsFs() + home := t.TempDir() + s := NewBasicStore().WithFileSystem(fs).WithUserHomeDirGetter( + func() (string, error) { return home, nil }, + ) + return s, fs, home +} + +func TestFileStore_GetAuthTokens_LegacyCredentialsDefaultToJWT(t *testing.T) { + s, fs, home := newAuthTokenTestStore(t) + credentialsPath := filepath.Join(home, ".brev", "credentials.json") + require.NoError(t, fs.MkdirAll(filepath.Dir(credentialsPath), 0o755)) + require.NoError(t, afero.WriteFile(fs, credentialsPath, []byte(`{ + "access_token": "jwt-token", + "refresh_token": "refresh-token" +}`), 0o600)) + + got, err := s.GetAuthTokens() + + require.NoError(t, err) + assert.Equal(t, "jwt-token", got.AccessToken) + assert.Equal(t, "refresh-token", got.RefreshToken) +} + +func TestFileStore_SaveAuthTokens_WritesAPIKey(t *testing.T) { + s, _, _ := newAuthTokenTestStore(t) + + err := s.SaveAuthTokens(entity.AuthTokens{ + AccessToken: "jwt-token", + APIKey: testAPIKey, + APIKeyOrgID: "org-test", + }) + require.NoError(t, err) + + got, err := s.GetAuthTokens() + require.NoError(t, err) + assert.Equal(t, "jwt-token", got.AccessToken) + assert.Equal(t, testAPIKey, got.APIKey) + assert.Equal(t, "org-test", got.APIKeyOrgID) +} diff --git a/pkg/store/memory_auth.go b/pkg/store/memory_auth.go index 5c004df72..de930b917 100644 --- a/pkg/store/memory_auth.go +++ b/pkg/store/memory_auth.go @@ -18,8 +18,8 @@ func NewMemoryAuthStore() *MemoryAuthStore { } func (m *MemoryAuthStore) SaveAuthTokens(tokens entity.AuthTokens) error { - if tokens.AccessToken == "" { - return fmt.Errorf("access token is empty") + if tokens.AccessToken == "" && tokens.APIKey == "" { + return fmt.Errorf("access token and api key are empty") } m.tokens = &tokens return nil diff --git a/pkg/store/organization.go b/pkg/store/organization.go index 920f7efaa..6083cd903 100644 --- a/pkg/store/organization.go +++ b/pkg/store/organization.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/files" @@ -38,8 +39,54 @@ func (f FileStore) ClearDefaultOrganization() error { return nil } +func (f FileStore) GetCachedActiveOrganizationOrNil() (*entity.Organization, error) { + home, err := f.UserHomeDir() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + brevActiveOrgsFile := files.GetActiveOrgsPath(home) + + exists, err := afero.Exists(f.fs, brevActiveOrgsFile) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if !exists { + return nil, nil + } + + var activeOrg entity.Organization + err = files.ReadJSON(f.fs, brevActiveOrgsFile, &activeOrg) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return &activeOrg, nil +} + // returns the 'set'/active organization or nil if not set func (s AuthHTTPStore) GetActiveOrganizationOrNil() (*entity.Organization, error) { + if auth.IsAPIKeyAuthStore(&s) { + orgID, err := auth.GetAPIKeyOrgID(&s) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + org := &entity.Organization{ID: orgID, Name: orgID} + // Name hydration is best-effort; the command itself should surface backend auth errors. + freshOrg, err := s.GetOrganization(orgID) + if err != nil { + return org, nil + } + if freshOrg == nil { + return org, nil + } + if freshOrg.ID == "" { + freshOrg.ID = orgID + } + if freshOrg.Name == "" { + freshOrg.Name = freshOrg.ID + } + return freshOrg, nil + } + workspaceID, err := s.GetCurrentWorkspaceID() if err != nil { return nil, breverrors.WrapAndTrace(err) @@ -58,26 +105,14 @@ func (s AuthHTTPStore) GetActiveOrganizationOrNil() (*entity.Organization, error return org, nil } - home, err := s.UserHomeDir() - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } - brevActiveOrgsFile := files.GetActiveOrgsPath(home) - - exists, err := afero.Exists(s.fs, brevActiveOrgsFile) + activeOrg, err := s.GetCachedActiveOrganizationOrNil() if err != nil { return nil, breverrors.WrapAndTrace(err) } - if !exists { + if activeOrg == nil { return nil, nil } - var activeOrg entity.Organization - err = files.ReadJSON(s.fs, brevActiveOrgsFile, &activeOrg) - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } - freshOrg, err := s.GetOrganization(activeOrg.ID) if err != nil { if !IsNetwork404Or403Error(err) { // handle because can login with bad cache diff --git a/pkg/store/organization_test.go b/pkg/store/organization_test.go index c25566004..bb161e192 100644 --- a/pkg/store/organization_test.go +++ b/pkg/store/organization_test.go @@ -4,10 +4,12 @@ import ( "fmt" "testing" + authpkg "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/entity" "github.com/jarcoal/httpmock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetActiveOrganization(t *testing.T) { @@ -50,6 +52,58 @@ func TestGetOrganizations(t *testing.T) { } } +func TestGetActiveOrganization_APIKeyUsesCredentialOrg(t *testing.T) { + apiKey := authpkg.BrevAPIKeyPrefix + "test-key" + fileStore, _, _ := newAuthTokenTestStore(t) + s := fileStore.WithAuthHTTPClient(NewAuthHTTPClient(MockAuth{token: &apiKey}, "https://api.test")) + httpmock.ActivateNonDefault(s.authHTTPClient.restyClient.GetClient()) + defer httpmock.DeactivateAndReset() + + require.NoError(t, s.SaveAuthTokens(entity.AuthTokens{ + APIKey: apiKey, + APIKeyOrgID: "org-api-key", + })) + require.NoError(t, s.SetDefaultOrganization(&entity.Organization{ + ID: "org-cached", + Name: "cached-org", + })) + + org, err := s.GetActiveOrganizationOrDefault() + + require.NoError(t, err) + require.NotNil(t, org) + assert.Equal(t, "org-api-key", org.ID) + assert.Equal(t, "org-api-key", org.Name) +} + +func TestGetActiveOrganization_APIKeyUsesCredentialOrgNameWhenAvailable(t *testing.T) { + apiKey := authpkg.BrevAPIKeyPrefix + "test-key" + fileStore, _, _ := newAuthTokenTestStore(t) + s := fileStore.WithAuthHTTPClient(NewAuthHTTPClient(MockAuth{token: &apiKey}, "https://api.test")) + httpmock.ActivateNonDefault(s.authHTTPClient.restyClient.GetClient()) + defer httpmock.DeactivateAndReset() + + require.NoError(t, s.SaveAuthTokens(entity.AuthTokens{ + APIKey: apiKey, + APIKeyOrgID: "org-api-key", + })) + expected := &entity.Organization{ + ID: "org-api-key", + Name: "friendly-org", + } + res, err := httpmock.NewJsonResponder(200, expected) + require.NoError(t, err) + url := fmt.Sprintf("%s/%s", s.authHTTPClient.restyClient.BaseURL, fmt.Sprintf(orgIDPathPattern, "org-api-key")) + httpmock.RegisterResponder("GET", url, res) + + org, err := s.GetActiveOrganizationOrDefault() + + require.NoError(t, err) + require.NotNil(t, org) + assert.Equal(t, expected.ID, org.ID) + assert.Equal(t, expected.Name, org.Name) +} + func TestCreateOrganization(t *testing.T) { fs := MakeMockAuthHTTPStore() httpmock.ActivateNonDefault(fs.authHTTPClient.restyClient.GetClient()) diff --git a/pkg/store/workspace.go b/pkg/store/workspace.go index cefe77e84..35f5d4f11 100644 --- a/pkg/store/workspace.go +++ b/pkg/store/workspace.go @@ -6,6 +6,7 @@ import ( "io/ioutil" "strings" + "github.com/brevdev/brev-cli/pkg/auth" "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" @@ -343,6 +344,10 @@ func FilterNonFailedWorkspaces(workspaces []entity.Workspace) []entity.Workspace } func (s AuthHTTPStore) GetWorkspaceByNameOrID(orgID string, nameOrID string) ([]entity.Workspace, error) { + if auth.IsAPIKeyAuthStore(&s) { + return s.GetWorkspaces(orgID, &GetWorkspacesOptions{Name: nameOrID}) + } + // pretty srue we always want to filter for workspaces owned by the user user, err := s.GetCurrentUser() if err != nil { @@ -365,6 +370,10 @@ func (s AuthHTTPStore) GetContextWorkspaces() ([]entity.Workspace, error) { if err != nil { return nil, breverrors.WrapAndTrace(err) } + if auth.IsAPIKeyAuthStore(&s) { + return s.GetWorkspaces(org.ID, nil) + } + user, err := s.GetCurrentUser() if err != nil { return nil, breverrors.WrapAndTrace(err) From d71c2d41a65f3fae413706cdc0abadb9d5caea56 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Wed, 6 May 2026 13:49:11 -0700 Subject: [PATCH 2/2] review feedback --- pkg/auth/auth.go | 29 ++++++++++++++++ pkg/auth/auth_test.go | 62 +++++++++++++++++++++++++++++++++ pkg/cmd/ls/ls.go | 81 +++++++++++++++++++++---------------------- pkg/cmd/ls/ls_test.go | 44 +++++++++++++++-------- 4 files changed, 160 insertions(+), 56 deletions(-) diff --git a/pkg/auth/auth.go b/pkg/auth/auth.go index 5c94bc328..66344ceee 100644 --- a/pkg/auth/auth.go +++ b/pkg/auth/auth.go @@ -109,6 +109,35 @@ type APIKeyAuthStore interface { GetAuthTokens() (*entity.AuthTokens, error) } +type CurrentUserAuthStore interface { + APIKeyAuthStore + GetCurrentUser() (*entity.User, error) +} + +type CLIAuth struct { + apiKey bool + user *entity.User +} + +func (a CLIAuth) IsAPIKey() bool { + return a.apiKey +} + +func (a CLIAuth) User() *entity.User { + return a.user +} + +func ResolveCLIAuth(store CurrentUserAuthStore) (CLIAuth, error) { + if IsAPIKeyAuthStore(store) { + return CLIAuth{apiKey: true}, nil + } + user, err := store.GetCurrentUser() + if err != nil { + return CLIAuth{}, breverrors.WrapAndTrace(err) + } + return CLIAuth{user: user}, nil +} + func IsBrevAPIKey(token string) bool { return strings.HasPrefix(strings.TrimSpace(token), BrevAPIKeyPrefix) } diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index bdfedbe12..4c271b0b9 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -130,6 +130,68 @@ func TestIsAPIKeyAuthStore_LegacyCredentialsAreNotAPIKeyAuth(t *testing.T) { assert.False(t, s.getAccessTokenCalled) } +type cliAuthStore struct { + tokens *entity.AuthTokens + user *entity.User + currentUserErr error + currentUserCalls int +} + +func (s *cliAuthStore) GetAuthTokens() (*entity.AuthTokens, error) { + return s.tokens, nil +} + +func (s *cliAuthStore) GetCurrentUser() (*entity.User, error) { + s.currentUserCalls++ + if s.currentUserErr != nil { + return nil, s.currentUserErr + } + return s.user, nil +} + +func TestResolveCLIAuth_APIKeySkipsCurrentUser(t *testing.T) { + s := &cliAuthStore{ + tokens: &entity.AuthTokens{APIKey: testAPIKey}, + user: &entity.User{ID: "user-test"}, + } + + cliAuth, err := ResolveCLIAuth(s) + + assert.NoError(t, err) + assert.True(t, cliAuth.IsAPIKey()) + assert.Nil(t, cliAuth.User()) + assert.Equal(t, 0, s.currentUserCalls) +} + +func TestResolveCLIAuth_LegacyCredentialsFetchCurrentUser(t *testing.T) { + user := &entity.User{ID: "user-test"} + s := &cliAuthStore{ + tokens: &entity.AuthTokens{AccessToken: validToken}, + user: user, + } + + cliAuth, err := ResolveCLIAuth(s) + + assert.NoError(t, err) + assert.False(t, cliAuth.IsAPIKey()) + assert.Equal(t, user, cliAuth.User()) + assert.Equal(t, 1, s.currentUserCalls) +} + +func TestResolveCLIAuth_CurrentUserErrorReturnsError(t *testing.T) { + s := &cliAuthStore{ + tokens: &entity.AuthTokens{AccessToken: validToken}, + currentUserErr: breverrors.NewValidationError("current user failed"), + } + + cliAuth, err := ResolveCLIAuth(s) + + assert.Error(t, err) + assert.False(t, cliAuth.IsAPIKey()) + assert.Nil(t, cliAuth.User()) + assert.Equal(t, 1, s.currentUserCalls) +} + func TestGetFreshAccessTokenOrNil_APIKeySkipsJWTValidationAndRefresh(t *testing.T) { s := MockAuthStore{authTokens: &entity.AuthTokens{ AccessToken: "expired-jwt", diff --git a/pkg/cmd/ls/ls.go b/pkg/cmd/ls/ls.go index 532595770..0491215a5 100644 --- a/pkg/cmd/ls/ls.go +++ b/pkg/cmd/ls/ls.go @@ -88,12 +88,16 @@ with other commands like stop, start, or delete.`, Args: cmderrors.TransformToValidationError(cobra.MinimumNArgs(0)), ValidArgs: []string{"orgs", "workspaces", "nodes", "instances"}, RunE: func(cmd *cobra.Command, args []string) error { - err := RunLs(t, loginLsStore, args, org, showAll, jsonOutput) + cliAuth, err := auth.ResolveCLIAuth(loginLsStore) + if err != nil { + return breverrors.WrapAndTrace(err) + } + err = RunLs(t, cliAuth, loginLsStore, args, org, showAll, jsonOutput) if err != nil { return breverrors.WrapAndTrace(err) } if !jsonOutput { - trackLsAnalytics(loginLsStore) + trackLsAnalytics(cliAuth) } return nil }, @@ -113,13 +117,10 @@ with other commands like stop, start, or delete.`, } // trackLsAnalytics sends analytics event for ls command -func trackLsAnalytics(store LsStore) { +func trackLsAnalytics(cliAuth auth.CLIAuth) { userID := "" - if !isAPIKeyAuthStore(store) { - user, err := store.GetCurrentUser() - if err == nil { - userID = user.ID - } + if !cliAuth.IsAPIKey() && cliAuth.User() != nil { + userID = cliAuth.User().ID } data := analytics.EventData{ EventName: "Brev ls", @@ -128,13 +129,9 @@ func trackLsAnalytics(store LsStore) { _ = analytics.TrackEvent(data) } -func isAPIKeyAuthStore(lsStore LsStore) bool { - return auth.IsAPIKeyAuthStore(lsStore) -} - -func getOrgForRunLs(lsStore LsStore, orgflag string, apiKeyAuth bool) (*entity.Organization, error) { +func getOrgForRunLs(cliAuth auth.CLIAuth, lsStore LsStore, orgflag string) (*entity.Organization, error) { var org *entity.Organization - if apiKeyAuth { + if cliAuth.IsAPIKey() { if orgflag != "" { return nil, breverrors.NewValidationError("api key auth is scoped to the org saved during login; --org is not supported") } @@ -175,20 +172,10 @@ func getOrgForRunLs(lsStore LsStore, orgflag string, apiKeyAuth bool) (*entity.O return org, nil } -func RunLs(t *terminal.Terminal, lsStore LsStore, args []string, orgflag string, showAll bool, jsonOutput bool) error { +func RunLs(t *terminal.Terminal, cliAuth auth.CLIAuth, lsStore LsStore, args []string, orgflag string, showAll bool, jsonOutput bool) error { ls := NewLs(lsStore, t, jsonOutput) - apiKeyAuth := isAPIKeyAuthStore(lsStore) - var user *entity.User - if !apiKeyAuth { - var err error - user, err = lsStore.GetCurrentUser() - if err != nil { - return breverrors.WrapAndTrace(err) - } - } - - org, err := getOrgForRunLs(lsStore, orgflag, apiKeyAuth) + org, err := getOrgForRunLs(cliAuth, lsStore, orgflag) if err != nil { return breverrors.WrapAndTrace(err) } @@ -197,12 +184,12 @@ func RunLs(t *terminal.Terminal, lsStore LsStore, args []string, orgflag string, } if len(args) == 1 { //nolint:gocritic // don't want to switch - err = handleLsArg(ls, args[0], user, org, showAll, apiKeyAuth) + err = handleLsArg(ls, cliAuth, args[0], org, showAll) if err != nil { return breverrors.WrapAndTrace(err) } } else if len(args) == 0 { - err = ls.RunWorkspaces(org, user, showAll) + err = ls.RunWorkspaces(cliAuth, org, showAll) if err != nil { return breverrors.WrapAndTrace(err) } @@ -213,27 +200,27 @@ func RunLs(t *terminal.Terminal, lsStore LsStore, args []string, orgflag string, return nil } -func handleLsArg(ls *Ls, arg string, user *entity.User, org *entity.Organization, showAll bool, apiKeyAuth bool) error { +func handleLsArg(ls *Ls, cliAuth auth.CLIAuth, arg string, org *entity.Organization, showAll bool) error { switch classifyLsArg(arg) { case lsArgOrgs: - if apiKeyAuth { + if cliAuth.IsAPIKey() { return breverrors.NewValidationError("api key auth cannot list organizations") } return wrapLsRun(ls.RunOrgs()) case lsArgWorkspaces: - return wrapLsRun(ls.RunWorkspaces(org, user, showAll)) + return wrapLsRun(ls.RunWorkspaces(cliAuth, org, showAll)) case lsArgUsers: - return runAdminLsArg(user, apiKeyAuth, "users", func() error { + return runAdminLsArg(cliAuth, "users", func() error { return ls.RunUser(showAll) }) case lsArgHosts: - return runAdminLsArg(user, apiKeyAuth, "hosts", func() error { + return runAdminLsArg(cliAuth, "hosts", func() error { return ls.RunHosts(org) }) case lsArgNodes: return wrapLsRun(ls.RunNodes(org)) case lsArgInstances: - return wrapLsRun(ls.RunInstances(org, user, showAll)) + return wrapLsRun(ls.RunInstances(cliAuth, org, showAll)) default: return nil } @@ -270,10 +257,14 @@ func classifyLsArg(arg string) lsArgKind { } } -func runAdminLsArg(user *entity.User, apiKeyAuth bool, resource string, run func() error) error { - if apiKeyAuth { +func runAdminLsArg(cliAuth auth.CLIAuth, resource string, run func() error) error { + if cliAuth.IsAPIKey() { return breverrors.NewValidationError(fmt.Sprintf("api key auth cannot list %s", resource)) } + user := cliAuth.User() + if user == nil { + return breverrors.NewValidationError("user is required") + } if !featureflag.IsAdmin(user.GlobalUserType) { return nil } @@ -484,7 +475,7 @@ func buildGPULookup(s LsStore) map[string]string { return lookup } -func (ls Ls) RunWorkspaces(org *entity.Organization, user *entity.User, showAll bool) error { +func (ls Ls) RunWorkspaces(cliAuth auth.CLIAuth, org *entity.Organization, showAll bool) error { // Fetch workspaces and instance types concurrently var allWorkspaces []entity.Workspace var wsErr error @@ -527,9 +518,13 @@ func (ls Ls) RunWorkspaces(org *entity.Organization, user *entity.User, showAll switch { case showAll: workspacesToShow = allWorkspaces - case user == nil: + case cliAuth.IsAPIKey(): workspacesToShow = allWorkspaces default: + user := cliAuth.User() + if user == nil { + return breverrors.NewValidationError("user is required") + } workspacesToShow = store.FilterForUserWorkspaces(allWorkspaces, user.ID) } @@ -538,10 +533,14 @@ func (ls Ls) RunWorkspaces(org *entity.Organization, user *entity.User, showAll return ls.outputWorkspacesJSON(workspacesToShow, gpuLookup, nodes) } - if user == nil { + if cliAuth.IsAPIKey() { ls.ShowOrgWorkspaces(org, workspacesToShow, gpuLookup) return nil } + user := cliAuth.User() + if user == nil { + return breverrors.NewValidationError("user is required") + } // Table output with colors and help text orgs, err := ls.lsStore.GetOrganizations(nil) @@ -561,8 +560,8 @@ func (ls Ls) RunWorkspaces(org *entity.Organization, user *entity.User, showAll return nil } -func (ls Ls) RunInstances(org *entity.Organization, user *entity.User, showAll bool) error { - if err := ls.RunWorkspaces(org, user, showAll); err != nil { +func (ls Ls) RunInstances(cliAuth auth.CLIAuth, org *entity.Organization, showAll bool) error { + if err := ls.RunWorkspaces(cliAuth, org, showAll); err != nil { return err } return nil diff --git a/pkg/cmd/ls/ls_test.go b/pkg/cmd/ls/ls_test.go index 4613cb360..0e176f8c9 100644 --- a/pkg/cmd/ls/ls_test.go +++ b/pkg/cmd/ls/ls_test.go @@ -97,6 +97,20 @@ func newTestStore() *mockLsStore { } } +func resolveTestCLIAuth(t *testing.T, s *mockLsStore) authpkg.CLIAuth { + t.Helper() + cliAuth, err := authpkg.ResolveCLIAuth(s) + if err != nil { + t.Fatalf("ResolveCLIAuth returned error: %v", err) + } + return cliAuth +} + +func runLs(t *testing.T, term *terminal.Terminal, s *mockLsStore, args []string, showAll bool) error { + t.Helper() + return RunLs(term, resolveTestCLIAuth(t, s), s, args, "", showAll, true) +} + func TestRunLs_APIKeyJSONSkipsUserAndOrgList(t *testing.T) { s := newTestStore() s.authTokens = &entity.AuthTokens{APIKey: testAPIKey, APIKeyOrgID: "org1"} @@ -117,7 +131,7 @@ func TestRunLs_APIKeyJSONSkipsUserAndOrgList(t *testing.T) { term := terminal.New() out := captureStdout(t, func() { - err := RunLs(term, s, nil, "", false, true) + err := runLs(t, term, s, nil, false) if err != nil { t.Fatalf("RunLs returned error: %v", err) } @@ -155,7 +169,7 @@ func TestRunLs_APIKeyUsesCredentialOrgNotCachedActiveOrg(t *testing.T) { term := terminal.New() out := captureStdout(t, func() { - err := RunLs(term, s, nil, "", false, true) + err := runLs(t, term, s, nil, false) if err != nil { t.Fatalf("RunLs returned error: %v", err) } @@ -174,7 +188,7 @@ func TestGetOrgForRunLs_APIKeyUsesActiveOrgDisplayName(t *testing.T) { s.authTokens = &entity.AuthTokens{APIKey: testAPIKey, APIKeyOrgID: "org-login"} s.org = &entity.Organization{ID: "org-login", Name: "friendly-org"} - org, err := getOrgForRunLs(s, "", true) + org, err := getOrgForRunLs(resolveTestCLIAuth(t, s), s, "") if err != nil { t.Fatalf("getOrgForRunLs returned error: %v", err) } @@ -191,7 +205,7 @@ func TestRunLs_APIKeyRequiresCredentialOrg(t *testing.T) { s.authTokens = &entity.AuthTokens{APIKey: testAPIKey} term := terminal.New() - err := RunLs(term, s, nil, "", false, true) + err := runLs(t, term, s, nil, false) if err == nil { t.Fatal("expected missing API key org error, got nil") @@ -251,7 +265,7 @@ func TestRunLs_DefaultJSON(t *testing.T) { term := terminal.New() out := captureStdout(t, func() { - err := RunLs(term, s, nil, "", false, true) + err := runLs(t, term, s, nil, false) if err != nil { t.Fatalf("RunLs returned error: %v", err) } @@ -299,7 +313,7 @@ func TestRunLs_DefaultJSON_Empty(t *testing.T) { term := terminal.New() out := captureStdout(t, func() { - err := RunLs(term, s, nil, "", false, true) + err := runLs(t, term, s, nil, false) if err != nil { t.Fatalf("RunLs returned error: %v", err) } @@ -332,7 +346,7 @@ func TestRunLs_InstancesJSON(t *testing.T) { term := terminal.New() out := captureStdout(t, func() { - err := RunLs(term, s, []string{"instances"}, "", false, true) + err := runLs(t, term, s, []string{"instances"}, false) if err != nil { t.Fatalf("RunLs instances returned error: %v", err) } @@ -363,7 +377,7 @@ func TestRunLs_OrgsJSON(t *testing.T) { term := terminal.New() out := captureStdout(t, func() { - err := RunLs(term, s, []string{"orgs"}, "", false, true) + err := runLs(t, term, s, []string{"orgs"}, false) if err != nil { t.Fatalf("RunLs orgs returned error: %v", err) } @@ -410,7 +424,7 @@ func TestRunLs_ShowAllJSON(t *testing.T) { // Without --all: only my workspaces outMine := captureStdout(t, func() { - err := RunLs(term, s, nil, "", false, true) + err := runLs(t, term, s, nil, false) if err != nil { t.Fatalf("RunLs returned error: %v", err) } @@ -431,7 +445,7 @@ func TestRunLs_ShowAllJSON(t *testing.T) { // With --all: output is always {workspaces, nodes} object. // Nodes fetch fails gracefully (no gRPC server), so nodes will be empty. outAll := captureStdout(t, func() { - err := RunLs(term, s, nil, "", true, true) + err := runLs(t, term, s, nil, true) if err != nil { t.Fatalf("RunLs --all returned error: %v", err) } @@ -456,7 +470,7 @@ func TestRunLs_TooManyArgs(t *testing.T) { s := newTestStore() term := terminal.New() - err := RunLs(term, s, []string{"instances", "nodes"}, "", false, true) + err := runLs(t, term, s, []string{"instances", "nodes"}, false) if err == nil { t.Fatal("expected error for too many args, got nil") } @@ -480,7 +494,7 @@ func TestRunLs_WorkspaceGPULookup(t *testing.T) { // The mock returns empty InstanceTypesResponse, so GPU should be "-". out := captureStdout(t, func() { - err := RunLs(term, s, nil, "", false, true) + err := runLs(t, term, s, nil, false) if err != nil { t.Fatalf("RunLs returned error: %v", err) } @@ -517,7 +531,7 @@ func TestRunLs_UnhealthyStatus(t *testing.T) { term := terminal.New() out := captureStdout(t, func() { - err := RunLs(term, s, nil, "", false, true) + err := runLs(t, term, s, nil, false) if err != nil { t.Fatalf("RunLs returned error: %v", err) } @@ -548,7 +562,7 @@ func TestHandleLsArg_Routing(t *testing.T) { "workspace", "workspaces", } for _, arg := range successArgs { - if err := handleLsArg(ls, arg, s.user, s.org, false, false); err != nil { + if err := handleLsArg(ls, resolveTestCLIAuth(t, s), arg, s.org, false); err != nil { t.Errorf("handleLsArg(%q) returned unexpected error: %v", arg, err) } } @@ -556,7 +570,7 @@ func TestHandleLsArg_Routing(t *testing.T) { // "node"/"nodes" route to RunNodes which calls the gRPC client — verify // it attempts the path (error expected due to no real client). for _, arg := range []string{"node", "nodes"} { - _ = handleLsArg(ls, arg, s.user, s.org, false, false) + _ = handleLsArg(ls, resolveTestCLIAuth(t, s), arg, s.org, false) } }