Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,44 @@ type Auth struct {
shouldLogin func() (bool, error)
}

const BrevAPIKeyPrefix = "bak-"

const MissingAPIKeyOrgIDMessage = "api key auth requires an org id; run brev login --api-key <api-key> --org-id <org-id>"

type APIKeyAuthStore interface {
GetAuthTokens() (*entity.AuthTokens, error)
}

func IsBrevAPIKey(token string) bool {
return strings.HasPrefix(strings.TrimSpace(token), BrevAPIKeyPrefix)
}

func IsAPIKeyAuthStore(authTokensProvider APIKeyAuthStore) bool {
tokens, err := authTokensProvider.GetAuthTokens()
if err != nil {
return false
}
if tokens == nil {
return false
}
return IsBrevAPIKey(tokens.APIKey)
}

func GetAPIKeyOrgID(authTokensProvider APIKeyAuthStore) (string, error) {
tokens, err := authTokensProvider.GetAuthTokens()
if err != nil {
return "", breverrors.WrapAndTrace(err)
}
if tokens == nil {
return "", breverrors.NewValidationError(MissingAPIKeyOrgIDMessage)
}
orgID := strings.TrimSpace(tokens.APIKeyOrgID)
if orgID == "" {
return "", breverrors.NewValidationError(MissingAPIKeyOrgIDMessage)
}
return orgID, nil
}

func NewAuth(authStore AuthStore, oauth OAuth) *Auth {
return &Auth{
authStore: authStore,
Expand Down Expand Up @@ -146,6 +184,11 @@ func (t Auth) GetFreshAccessTokenOrNil() (string, error) {
return "", nil
}

apiKey := strings.TrimSpace(tokens.APIKey)
if apiKey != "" {
return apiKey, nil
}

// should always at least have access token?
if tokens.AccessToken == "" {
breverrors.GetDefaultErrorReporter().ReportMessage("access token is an empty string but shouldn't be")
Expand Down Expand Up @@ -222,6 +265,36 @@ func (t Auth) LoginWithToken(token string) error {
return nil
}

func (t Auth) LoginWithAPIKey(apiKey string, orgID string) error {
apiKey = strings.TrimSpace(apiKey)
if apiKey == "" {
return breverrors.NewValidationError("api key is empty")
}
if !IsBrevAPIKey(apiKey) {
return breverrors.NewValidationError(fmt.Sprintf("api key must start with %s", BrevAPIKeyPrefix))
}
orgID = strings.TrimSpace(orgID)
if orgID == "" {
return breverrors.NewValidationError(MissingAPIKeyOrgIDMessage)
}

tokens, err := t.getSavedTokensOrNil()
if err != nil {
return breverrors.WrapAndTrace(err)
}
if tokens == nil {
tokens = &entity.AuthTokens{}
}
tokens.APIKey = apiKey
tokens.APIKeyOrgID = orgID

err = t.authStore.SaveAuthTokens(*tokens)
if err != nil {
return breverrors.WrapAndTrace(err)
}
return nil
}

// showLoginURL displays the login link and CLI alternative for manual navigation.
func showLoginURL(url string) {
urlType := color.New(color.FgCyan, color.Bold).SprintFunc()
Expand Down Expand Up @@ -313,7 +386,7 @@ func (t Auth) getSavedTokensOrNil() (*entity.AuthTokens, error) {
}
return nil, breverrors.WrapAndTrace(err)
}
if tokens != nil && tokens.AccessToken == "" && tokens.RefreshToken == "" {
if tokens != nil && tokens.AccessToken == "" && tokens.RefreshToken == "" && tokens.APIKey == "" {
return nil, nil
}
return tokens, nil
Expand Down Expand Up @@ -415,7 +488,7 @@ func AuthProviderFlagToCredentialProvider(authProviderFlag string) entity.Creden
func StandardLogin(authProvider string, email string, tokens *entity.AuthTokens) OAuth {
// Set KAS as the default authenticator
shouldPromptEmail := false
if email == "" && tokens != nil && tokens.AccessToken != "" {
if email == "" && tokens != nil && tokens.AccessToken != "" && tokens.APIKey == "" {
email = GetEmailFromToken(tokens.AccessToken)
shouldPromptEmail = true
}
Expand Down Expand Up @@ -445,7 +518,7 @@ func StandardLogin(authProvider string, email string, tokens *entity.AuthTokens)
kasAuthenticator,
})

if tokens != nil && tokens.AccessToken != "" {
if tokens != nil && tokens.AccessToken != "" && tokens.APIKey == "" {
authenticatorFromToken, errr := authRetriever.GetByToken(tokens.AccessToken)
if errr != nil {
fmt.Printf("%v\n", errr)
Expand Down
182 changes: 180 additions & 2 deletions pkg/auth/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package auth

import (
"io"
"os"
"testing"

"github.com/brevdev/brev-cli/pkg/entity"
breverrors "github.com/brevdev/brev-cli/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)

Expand Down Expand Up @@ -38,10 +41,12 @@ func TestIsAccessTokenValid(t *testing.T) {

type MockAuthStore struct {
authTokens *entity.AuthTokens
saved entity.AuthTokens
didSave bool
}

func (m *MockAuthStore) SaveAuthTokens(_ entity.AuthTokens) error {
func (m *MockAuthStore) SaveAuthTokens(tokens entity.AuthTokens) error {
m.saved = tokens
m.didSave = true
return nil
}
Expand Down Expand Up @@ -77,7 +82,180 @@ func (m MockOauth) GetNewAuthTokensWithRefresh(_ string) (*entity.AuthTokens, er
return m.authTokens, nil
}

const validToken = "abc"
const (
validToken = "abc"
testAPIKey = BrevAPIKeyPrefix + "test-key"
)

func TestIsBrevAPIKey(t *testing.T) {
assert.True(t, IsBrevAPIKey(testAPIKey))
assert.True(t, IsBrevAPIKey(" "+testAPIKey+" "))
assert.False(t, IsBrevAPIKey("bakery-token"))
assert.False(t, IsBrevAPIKey("jwt-token"))
assert.False(t, IsBrevAPIKey(""))
}

type sideEffectingTokenStore struct {
tokens *entity.AuthTokens
getAccessTokenCalled bool
}

func (s *sideEffectingTokenStore) GetAuthTokens() (*entity.AuthTokens, error) {
return s.tokens, nil
}

func (s *sideEffectingTokenStore) GetAccessToken() (string, error) {
s.getAccessTokenCalled = true
return testAPIKey, nil
}

func TestIsAPIKeyAuthStore_ReadsSavedTokensWithoutAccessTokenSideEffects(t *testing.T) {
s := &sideEffectingTokenStore{
tokens: &entity.AuthTokens{APIKey: testAPIKey},
}

assert.True(t, IsAPIKeyAuthStore(s))
assert.False(t, s.getAccessTokenCalled)
}

func TestIsAPIKeyAuthStore_LegacyCredentialsAreNotAPIKeyAuth(t *testing.T) {
s := &sideEffectingTokenStore{
tokens: &entity.AuthTokens{
AccessToken: validToken,
RefreshToken: "refresh",
},
}

assert.False(t, IsAPIKeyAuthStore(s))
assert.False(t, s.getAccessTokenCalled)
}

func TestGetFreshAccessTokenOrNil_APIKeySkipsJWTValidationAndRefresh(t *testing.T) {
s := MockAuthStore{authTokens: &entity.AuthTokens{
AccessToken: "expired-jwt",
APIKey: testAPIKey,
RefreshToken: "should-not-refresh",
}}
a := Auth{
&s,
&MockOauth{}, func(_ string) (bool, error) {
t.Fatal("api keys must not be parsed as JWTs")
return false, nil
},
func() (bool, error) {
t.Fatal("api keys must not trigger login")
return false, nil
},
}

res, err := a.GetFreshAccessTokenOrNil()
assert.NoError(t, err)
assert.Equal(t, testAPIKey, res)
assert.False(t, s.didSave)
}

func TestGetFreshAccessTokenOrNil_APIKeyOnlyCredentialReturnsAPIKey(t *testing.T) {
s := MockAuthStore{authTokens: &entity.AuthTokens{
APIKey: testAPIKey,
}}
a := Auth{
&s,
&MockOauth{}, func(_ string) (bool, error) {
t.Fatal("api keys must not be parsed as JWTs")
return false, nil
},
func() (bool, error) {
t.Fatal("api keys must not trigger login")
return false, nil
},
}

res, err := a.GetFreshAccessTokenOrNil()
assert.NoError(t, err)
assert.Equal(t, testAPIKey, res)
assert.False(t, s.didSave)
}

func TestLoginWithAPIKey_SavesTypedCredential(t *testing.T) {
s := MockAuthStore{}
a := Auth{
authStore: &s,
oauth: &MockOauth{},
}

err := a.LoginWithAPIKey(testAPIKey, "org-test")
assert.NoError(t, err)
assert.True(t, s.didSave)
assert.Equal(t, entity.AuthTokens{
APIKey: testAPIKey,
APIKeyOrgID: "org-test",
}, s.saved)
}

func TestLoginWithAPIKey_PreservesExistingJWT(t *testing.T) {
s := MockAuthStore{authTokens: &entity.AuthTokens{
AccessToken: "existing-jwt",
RefreshToken: "existing-refresh",
}}
a := Auth{
authStore: &s,
oauth: &MockOauth{},
}

err := a.LoginWithAPIKey(testAPIKey, "org-test")
assert.NoError(t, err)
assert.Equal(t, entity.AuthTokens{
AccessToken: "existing-jwt",
RefreshToken: "existing-refresh",
APIKey: testAPIKey,
APIKeyOrgID: "org-test",
}, s.saved)
}

func TestLoginWithAPIKey_EmptyKeyReturnsError(t *testing.T) {
s := MockAuthStore{}
a := Auth{
authStore: &s,
oauth: &MockOauth{},
}

err := a.LoginWithAPIKey("", "org-test")
assert.Error(t, err)
assert.False(t, s.didSave)
}

func TestLoginWithAPIKey_EmptyOrgIDReturnsError(t *testing.T) {
s := MockAuthStore{}
a := Auth{
authStore: &s,
oauth: &MockOauth{},
}

err := a.LoginWithAPIKey(testAPIKey, "")
assert.Error(t, err)
assert.False(t, s.didSave)
}

func TestStandardLogin_APIKeyCredentialDoesNotProbeOAuthProviders(t *testing.T) {
oldStdout := os.Stdout
t.Cleanup(func() {
os.Stdout = oldStdout
})
readPipe, writePipe, err := os.Pipe()
require.NoError(t, err)
os.Stdout = writePipe

_ = StandardLogin("", "", &entity.AuthTokens{
AccessToken: "existing-jwt",
APIKey: testAPIKey,
})

assert.NoError(t, writePipe.Close())
os.Stdout = oldStdout
out, err := io.ReadAll(readPipe)
assert.NoError(t, err)
assert.Empty(t, string(out))
}

func TestSuccessNoRefreshGetFreshAccessTokenOrLogin(t *testing.T) {
s := MockAuthStore{authTokens: &entity.AuthTokens{
Expand Down
24 changes: 17 additions & 7 deletions pkg/cmd/completions/completions.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package completions

import (
"github.com/brevdev/brev-cli/pkg/auth"
"github.com/brevdev/brev-cli/pkg/entity"
"github.com/brevdev/brev-cli/pkg/store"
"github.com/brevdev/brev-cli/pkg/terminal"
"github.com/spf13/cobra"
)

type CompletionStore interface {
auth.APIKeyAuthStore
GetWorkspaces(organizationID string, options *store.GetWorkspacesOptions) ([]entity.Workspace, error)
GetActiveOrganizationOrDefault() (*entity.Organization, error)
GetCurrentUser() (*entity.User, error)
Expand All @@ -18,12 +20,6 @@ type CompletionHandler func(cmd *cobra.Command, args []string, toComplete string

func GetAllWorkspaceNameCompletionHandler(completionStore CompletionStore, t *terminal.Terminal) CompletionHandler {
return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
user, err := completionStore.GetCurrentUser()
if err != nil {
t.Errprint(err, "")
return nil, cobra.ShellCompDirectiveError
}

org, err := completionStore.GetActiveOrganizationOrDefault()
if err != nil {
t.Errprint(err, "")
Expand All @@ -33,7 +29,17 @@ func GetAllWorkspaceNameCompletionHandler(completionStore CompletionStore, t *te
return []string{}, cobra.ShellCompDirectiveDefault
}

workspaces, err := completionStore.GetWorkspaces(org.ID, &store.GetWorkspacesOptions{UserID: user.ID})
var options *store.GetWorkspacesOptions
if !auth.IsAPIKeyAuthStore(completionStore) {
user, err := completionStore.GetCurrentUser()
if err != nil {
t.Errprint(err, "")
return nil, cobra.ShellCompDirectiveError
}
options = &store.GetWorkspacesOptions{UserID: user.ID}
}

workspaces, err := completionStore.GetWorkspaces(org.ID, options)
if err != nil {
t.Errprint(err, "")
return nil, cobra.ShellCompDirectiveError
Expand All @@ -50,6 +56,10 @@ func GetAllWorkspaceNameCompletionHandler(completionStore CompletionStore, t *te

func GetOrgsNameCompletionHandler(completionStore CompletionStore, t *terminal.Terminal) CompletionHandler {
return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if auth.IsAPIKeyAuthStore(completionStore) {
return []string{}, cobra.ShellCompDirectiveNoFileComp
}

orgs, err := completionStore.GetOrganizations(nil)
if err != nil {
t.Errprint(err, "")
Expand Down
Loading
Loading