From 6c10eaa463169e898015eb476deb98897ddd4904 Mon Sep 17 00:00:00 2001 From: Qasim Date: Sat, 4 Apr 2026 22:38:59 -0400 Subject: [PATCH] TW-4787 Fix CLI auth and secret-store regressions --- .../adapters/keyring/crossplatform_test.go | 265 ++++++++- internal/adapters/keyring/file.go | 182 +++++- internal/adapters/keyring/keyring.go | 8 +- internal/adapters/keyring/keyring_test.go | 25 +- internal/adapters/mcp/proxy.go | 1 + internal/adapters/mcp/proxy_sse_test.go | 16 + internal/adapters/nylas/auth.go | 17 +- internal/adapters/nylas/auth_test.go | 37 +- internal/adapters/nylas/client.go | 29 +- internal/adapters/nylas/client_base_test.go | 10 +- .../nylas/client_mock_methods_test.go | 2 +- internal/adapters/nylas/demo/base.go | 4 +- internal/adapters/nylas/demo/base_test.go | 14 +- internal/adapters/nylas/demo_client.go | 4 +- internal/adapters/nylas/mock_client.go | 17 +- internal/adapters/nylas/mock_grants.go | 6 +- internal/adapters/nylas/security_test.go | 2 +- internal/adapters/oauth/mock.go | 4 +- internal/adapters/oauth/server.go | 38 +- .../adapters/oauth/server_integration_test.go | 73 +++ internal/adapters/oauth/server_test.go | 88 ++- internal/adapters/webhookserver/server.go | 22 +- .../adapters/webhookserver/server_test.go | 59 ++ internal/app/auth/service.go | 128 ++++- internal/app/auth/service_test.go | 116 +++- internal/cli/auth/providers.go | 88 ++- internal/cli/auth/providers_test.go | 96 ++++ internal/cli/auth/remove.go | 7 +- internal/cli/common/client.go | 144 +++-- internal/cli/common/client_test.go | 187 +++++-- internal/cli/common/errors.go | 11 + internal/cli/common/errors_test.go | 34 ++ internal/cli/doctor.go | 12 +- internal/cli/doctor_test.go | 31 ++ .../cli/integration/local_regressions_test.go | 518 ++++++++++++++++++ internal/cli/integration/test.go | 114 ++-- internal/ports/auth.go | 4 +- internal/ports/nylas.go | 2 +- 38 files changed, 2124 insertions(+), 291 deletions(-) create mode 100644 internal/adapters/oauth/server_integration_test.go create mode 100644 internal/cli/doctor_test.go create mode 100644 internal/cli/integration/local_regressions_test.go diff --git a/internal/adapters/keyring/crossplatform_test.go b/internal/adapters/keyring/crossplatform_test.go index 3e41871..ed5219c 100644 --- a/internal/adapters/keyring/crossplatform_test.go +++ b/internal/adapters/keyring/crossplatform_test.go @@ -1,17 +1,41 @@ package keyring import ( + "crypto/rand" + "encoding/base64" + "errors" "os" "path/filepath" "runtime" + "strings" "testing" "github.com/nylas/cli/internal/domain" ) +func setFileStorePassphrase(t *testing.T) string { + t.Helper() + + orig := os.Getenv(fileStorePassphraseEnv) + passphrase := "test-file-store-passphrase" + if err := os.Setenv(fileStorePassphraseEnv, passphrase); err != nil { + t.Fatalf("failed to set %s: %v", fileStorePassphraseEnv, err) + } + t.Cleanup(func() { + if orig != "" { + _ = os.Setenv(fileStorePassphraseEnv, orig) + } else { + _ = os.Unsetenv(fileStorePassphraseEnv) + } + }) + + return passphrase +} + // TestCrossPlatformEncryptedFileStore tests the encrypted file store across platforms. func TestCrossPlatformEncryptedFileStore(t *testing.T) { tmpDir := t.TempDir() + setFileStorePassphrase(t) store, err := NewEncryptedFileStore(tmpDir) if err != nil { @@ -100,6 +124,14 @@ func TestCrossPlatformEncryptedFileStore(t *testing.T) { if mode != 0600 { t.Errorf("File permissions are %o, want 0600", mode) } + + saltInfo, err := os.Stat(filepath.Join(tmpDir, ".secrets.salt")) + if err != nil { + t.Fatalf("Failed to stat salt file: %v", err) + } + if saltMode := saltInfo.Mode().Perm(); saltMode != 0600 { + t.Errorf("Salt file permissions are %o, want 0600", saltMode) + } } }) @@ -181,12 +213,12 @@ func TestCrossPlatformEncryptedFileStore(t *testing.T) { }) } -// TestDeriveKey tests key derivation across platforms. -func TestDeriveKey(t *testing.T) { +// TestDeriveLegacyKey tests the legacy key derivation kept for migration. +func TestDeriveLegacyKey(t *testing.T) { t.Run("key_is_32_bytes", func(t *testing.T) { - key, err := deriveKey() + key, err := deriveLegacyKey() if err != nil { - t.Fatalf("deriveKey failed: %v", err) + t.Fatalf("deriveLegacyKey failed: %v", err) } if len(key) != 32 { @@ -195,14 +227,14 @@ func TestDeriveKey(t *testing.T) { }) t.Run("key_is_deterministic", func(t *testing.T) { - key1, err := deriveKey() + key1, err := deriveLegacyKey() if err != nil { - t.Fatalf("deriveKey failed: %v", err) + t.Fatalf("deriveLegacyKey failed: %v", err) } - key2, err := deriveKey() + key2, err := deriveLegacyKey() if err != nil { - t.Fatalf("deriveKey failed: %v", err) + t.Fatalf("deriveLegacyKey failed: %v", err) } if string(key1) != string(key2) { @@ -211,6 +243,220 @@ func TestDeriveKey(t *testing.T) { }) } +func TestEncryptedFileStore_RequiresPassphraseForWrites(t *testing.T) { + orig := os.Getenv(fileStorePassphraseEnv) + if orig != "" { + _ = os.Unsetenv(fileStorePassphraseEnv) + t.Cleanup(func() { _ = os.Setenv(fileStorePassphraseEnv, orig) }) + } + + store, err := NewEncryptedFileStore(t.TempDir()) + if err != nil { + t.Fatalf("NewEncryptedFileStore failed: %v", err) + } + + err = store.Set("api_key", "value") + if err == nil { + t.Fatal("Set succeeded without passphrase") + } + if !strings.Contains(err.Error(), fileStorePassphraseEnv) { + t.Fatalf("Set error %q does not mention %s", err.Error(), fileStorePassphraseEnv) + } +} + +func TestEncryptedFileStore_MigratesLegacyCiphertext(t *testing.T) { + tmpDir := t.TempDir() + passphrase := setFileStorePassphrase(t) + legacyKey, err := deriveLegacyKey() + if err != nil { + t.Fatalf("deriveLegacyKey failed: %v", err) + } + + legacyCiphertext, err := encryptWithKey(legacyKey, []byte(`{"api_key":"legacy-value"}`)) + if err != nil { + t.Fatalf("encryptWithKey failed: %v", err) + } + + secretsPath := filepath.Join(tmpDir, ".secrets.enc") + if err := os.WriteFile(secretsPath, legacyCiphertext, 0600); err != nil { + t.Fatalf("failed to write legacy secrets file: %v", err) + } + + store, err := NewEncryptedFileStore(tmpDir) + if err != nil { + t.Fatalf("NewEncryptedFileStore failed: %v", err) + } + + value, err := store.Get("api_key") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if value != "legacy-value" { + t.Fatalf("Get returned %q, want %q", value, "legacy-value") + } + + if err := store.Set("new_key", "new-value"); err != nil { + t.Fatalf("Set failed: %v", err) + } + + data, err := os.ReadFile(secretsPath) + if err != nil { + t.Fatalf("failed to read rewritten secrets file: %v", err) + } + + if _, err := decryptWithKey(legacyKey, data); err == nil { + t.Fatal("rewritten secrets file should no longer use the legacy key") + } + + salt, err := os.ReadFile(filepath.Join(tmpDir, ".secrets.salt")) + if err != nil { + t.Fatalf("failed to read salt file: %v", err) + } + decodedSalt, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(salt))) + if err != nil { + t.Fatalf("failed to decode salt: %v", err) + } + plaintext, err := decryptWithKey(derivePassphraseKey([]byte(passphrase), decodedSalt), data) + if err != nil { + t.Fatalf("failed to decrypt rewritten secrets file with passphrase-derived key: %v", err) + } + if string(plaintext) == "" { + t.Fatal("rewritten secrets plaintext should not be empty") + } +} + +func TestEncryptedFileStore_MigratesLegacyMasterKeyCiphertext(t *testing.T) { + tmpDir := t.TempDir() + passphrase := setFileStorePassphrase(t) + + migrationKey := make([]byte, 32) + if _, err := rand.Read(migrationKey); err != nil { + t.Fatalf("rand.Read failed: %v", err) + } + + ciphertext, err := encryptWithKey(migrationKey, []byte(`{"api_key":"migrated-value"}`)) + if err != nil { + t.Fatalf("encryptWithKey failed: %v", err) + } + + secretsPath := filepath.Join(tmpDir, ".secrets.enc") + if err := os.WriteFile(secretsPath, ciphertext, 0600); err != nil { + t.Fatalf("failed to write secrets file: %v", err) + } + + keyPath := filepath.Join(tmpDir, ".secrets.key") + if err := os.WriteFile(keyPath, []byte(base64.StdEncoding.EncodeToString(migrationKey)), 0600); err != nil { + t.Fatalf("failed to write migration key: %v", err) + } + + store, err := NewEncryptedFileStore(tmpDir) + if err != nil { + t.Fatalf("NewEncryptedFileStore failed: %v", err) + } + + value, err := store.Get("api_key") + if err != nil { + t.Fatalf("Get failed: %v", err) + } + if value != "migrated-value" { + t.Fatalf("Get returned %q, want %q", value, "migrated-value") + } + + if err := store.Set("new_key", "new-value"); err != nil { + t.Fatalf("Set failed: %v", err) + } + + if _, err := os.Stat(keyPath); !os.IsNotExist(err) { + t.Fatalf("migration key file should be removed after rewrite, stat err = %v", err) + } + + salt, err := os.ReadFile(filepath.Join(tmpDir, ".secrets.salt")) + if err != nil { + t.Fatalf("failed to read salt file: %v", err) + } + decodedSalt, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(salt))) + if err != nil { + t.Fatalf("failed to decode salt: %v", err) + } + + data, err := os.ReadFile(secretsPath) + if err != nil { + t.Fatalf("failed to read rewritten secrets file: %v", err) + } + if _, err := decryptWithKey(migrationKey, data); err == nil { + t.Fatal("rewritten secrets file should no longer use the plaintext migration key") + } + if _, err := decryptWithKey(derivePassphraseKey([]byte(passphrase), decodedSalt), data); err != nil { + t.Fatalf("failed to decrypt rewritten secrets file with passphrase-derived key: %v", err) + } +} + +func TestEncryptedFileStore_ReopensWithSamePassphrase(t *testing.T) { + tmpDir := t.TempDir() + setFileStorePassphrase(t) + + store, err := NewEncryptedFileStore(tmpDir) + if err != nil { + t.Fatalf("NewEncryptedFileStore failed: %v", err) + } + if err := store.Set("api_key", "reopen-value"); err != nil { + t.Fatalf("Set failed: %v", err) + } + + reopened, err := NewEncryptedFileStore(tmpDir) + if err != nil { + t.Fatalf("reopen NewEncryptedFileStore failed: %v", err) + } + value, err := reopened.Get("api_key") + if err != nil { + t.Fatalf("Get after reopen failed: %v", err) + } + if value != "reopen-value" { + t.Fatalf("Get after reopen returned %q, want %q", value, "reopen-value") + } +} + +func TestEncryptedFileStore_RequiresPassphraseForReads(t *testing.T) { + tmpDir := t.TempDir() + orig := os.Getenv(fileStorePassphraseEnv) + passphrase := setFileStorePassphrase(t) + + store, err := NewEncryptedFileStore(tmpDir) + if err != nil { + t.Fatalf("NewEncryptedFileStore failed: %v", err) + } + if err := store.Set("api_key", "read-protected-value"); err != nil { + t.Fatalf("Set failed: %v", err) + } + + _ = os.Unsetenv(fileStorePassphraseEnv) + t.Cleanup(func() { + if orig != "" { + _ = os.Setenv(fileStorePassphraseEnv, orig) + } else { + _ = os.Unsetenv(fileStorePassphraseEnv) + } + }) + + reopened, err := NewEncryptedFileStore(tmpDir) + if err != nil { + t.Fatalf("reopen NewEncryptedFileStore failed: %v", err) + } + + _, err = reopened.Get("api_key") + if err == nil { + t.Fatal("Get succeeded without passphrase") + } + if !errors.Is(err, domain.ErrSecretStoreFailed) { + t.Fatalf("Get error = %v, want ErrSecretStoreFailed", err) + } + if !strings.Contains(err.Error(), fileStorePassphraseEnv) { + t.Fatalf("Get error %q does not mention %s", err.Error(), fileStorePassphraseEnv) + } + + _ = os.Setenv(fileStorePassphraseEnv, passphrase) +} + // TestGetMachineID tests machine ID retrieval across platforms. func TestGetMachineID(t *testing.T) { t.Logf("Running on %s/%s", runtime.GOOS, runtime.GOARCH) @@ -225,6 +471,8 @@ func TestGetMachineID(t *testing.T) { // TestNewSecretStore tests secret store creation with fallback. func TestNewSecretStore(t *testing.T) { tmpDir := t.TempDir() + setFileStorePassphrase(t) + t.Setenv("NYLAS_DISABLE_KEYRING", "true") store, err := NewSecretStore(tmpDir) if err != nil { @@ -257,6 +505,7 @@ func TestNewSecretStore(t *testing.T) { // TestConcurrentAccess tests concurrent access to the encrypted file store. func TestConcurrentAccess(t *testing.T) { tmpDir := t.TempDir() + setFileStorePassphrase(t) store, err := NewEncryptedFileStore(tmpDir) if err != nil { diff --git a/internal/adapters/keyring/file.go b/internal/adapters/keyring/file.go index 067f024..ea13ef6 100644 --- a/internal/adapters/keyring/file.go +++ b/internal/adapters/keyring/file.go @@ -16,28 +16,56 @@ import ( "sync" "github.com/nylas/cli/internal/domain" + "golang.org/x/crypto/argon2" +) + +const ( + fileStorePassphraseEnv = "NYLAS_FILE_STORE_PASSPHRASE" + fileStoreSaltSize = 16 ) // EncryptedFileStore implements SecretStore using an encrypted file. // This is a fallback for environments where the system keyring is unavailable. -// Uses AES-256-GCM encryption with a machine-specific key. +// Uses AES-256-GCM encryption with a key derived from user-supplied secret material. type EncryptedFileStore struct { - path string - key []byte - mu sync.RWMutex + path string + keyPath string + saltPath string + passphrase []byte + migrationKey []byte + legacyKey []byte + mu sync.RWMutex } // NewEncryptedFileStore creates a new EncryptedFileStore. // The secrets are stored in an encrypted file within the config directory. func NewEncryptedFileStore(configDir string) (*EncryptedFileStore, error) { path := filepath.Join(configDir, ".secrets.enc") - key, err := deriveKey() + keyPath := filepath.Join(configDir, ".secrets.key") + saltPath := filepath.Join(configDir, ".secrets.salt") + + legacyKey, err := deriveLegacyKey() + if err != nil { + return nil, fmt.Errorf("failed to derive legacy encryption key: %w", err) + } + + migrationKey, err := readCompatibilityMasterKey(keyPath) if err != nil { - return nil, fmt.Errorf("failed to derive encryption key: %w", err) + return nil, fmt.Errorf("failed to load legacy file-store key: %w", err) + } + + var passphrase []byte + if value := os.Getenv(fileStorePassphraseEnv); value != "" { + passphrase = []byte(value) } + return &EncryptedFileStore{ - path: path, - key: key, + path: path, + keyPath: keyPath, + saltPath: saltPath, + passphrase: passphrase, + migrationKey: migrationKey, + legacyKey: legacyKey, }, nil } @@ -149,12 +177,61 @@ func (f *EncryptedFileStore) saveSecrets(secrets map[string]string) error { } // Write with restrictive permissions - return os.WriteFile(f.path, ciphertext, 0600) + if err := os.WriteFile(f.path, ciphertext, 0600); err != nil { + return err + } + + // Remove the plaintext migration key once the store has been rewritten. + if f.keyPath != "" { + _ = os.Remove(f.keyPath) + } + + return nil } // encrypt encrypts plaintext using AES-256-GCM. func (f *EncryptedFileStore) encrypt(plaintext []byte) ([]byte, error) { - block, err := aes.NewCipher(f.key) + key, err := f.passphraseKey(true) + if err != nil { + return nil, err + } + return encryptWithKey(key, plaintext) +} + +// decrypt decrypts ciphertext using AES-256-GCM. +func (f *EncryptedFileStore) decrypt(data []byte) ([]byte, error) { + if key, err := f.passphraseKey(false); err == nil { + plaintext, err := decryptWithKey(key, data) + if err == nil { + return plaintext, nil + } + } else if !os.IsNotExist(err) && len(f.passphrase) > 0 { + return nil, err + } + + if len(f.migrationKey) > 0 { + plaintext, err := decryptWithKey(f.migrationKey, data) + if err == nil { + return plaintext, nil + } + } + + if len(f.legacyKey) > 0 { + plaintext, err := decryptWithKey(f.legacyKey, data) + if err == nil { + return plaintext, nil + } + } + + if len(f.passphrase) == 0 { + return nil, fmt.Errorf("%s must be set to unlock the encrypted file store", fileStorePassphraseEnv) + } + + return nil, fmt.Errorf("failed to decrypt encrypted file store with the configured passphrase") +} + +func encryptWithKey(key, plaintext []byte) ([]byte, error) { + block, err := aes.NewCipher(key) if err != nil { return nil, err } @@ -173,14 +250,13 @@ func (f *EncryptedFileStore) encrypt(plaintext []byte) ([]byte, error) { return []byte(base64.StdEncoding.EncodeToString(ciphertext)), nil } -// decrypt decrypts ciphertext using AES-256-GCM. -func (f *EncryptedFileStore) decrypt(data []byte) ([]byte, error) { +func decryptWithKey(key, data []byte) ([]byte, error) { ciphertext, err := base64.StdEncoding.DecodeString(string(data)) if err != nil { return nil, err } - block, err := aes.NewCipher(f.key) + block, err := aes.NewCipher(key) if err != nil { return nil, err } @@ -198,9 +274,83 @@ func (f *EncryptedFileStore) decrypt(data []byte) ([]byte, error) { return gcm.Open(nil, nonce, ciphertext, nil) } -// deriveKey derives a 32-byte encryption key from machine-specific identifiers. -// This makes the encrypted file non-portable but provides reasonable security. -func deriveKey() ([]byte, error) { +func readCompatibilityMasterKey(path string) ([]byte, error) { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + + key, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(data))) + if err != nil { + return nil, err + } + if len(key) != 32 { + return nil, fmt.Errorf("invalid master key length: %d", len(key)) + } + return key, nil +} + +func (f *EncryptedFileStore) passphraseKey(createSalt bool) ([]byte, error) { + if len(f.passphrase) == 0 { + return nil, fmt.Errorf("%s must be set to use the encrypted file secret store", fileStorePassphraseEnv) + } + + salt, err := f.loadSalt(createSalt) + if err != nil { + return nil, err + } + + return derivePassphraseKey(f.passphrase, salt), nil +} + +func (f *EncryptedFileStore) loadSalt(create bool) ([]byte, error) { + data, err := os.ReadFile(f.saltPath) + if err == nil { + salt, err := base64.StdEncoding.DecodeString(strings.TrimSpace(string(data))) + if err != nil { + return nil, err + } + if len(salt) != fileStoreSaltSize { + return nil, fmt.Errorf("invalid file-store salt length: %d", len(salt)) + } + return salt, nil + } + if !os.IsNotExist(err) { + return nil, err + } + if !create { + return nil, err + } + + if err := os.MkdirAll(filepath.Dir(f.saltPath), 0700); err != nil { + return nil, err + } + + salt := make([]byte, fileStoreSaltSize) + if _, err := io.ReadFull(rand.Reader, salt); err != nil { + return nil, err + } + + encoded := base64.StdEncoding.EncodeToString(salt) + if err := os.WriteFile(f.saltPath, []byte(encoded), 0600); err != nil { + return nil, err + } + + return salt, nil +} + +func derivePassphraseKey(passphrase, salt []byte) []byte { + // Argon2id keeps the fallback store bound to user-supplied secret material + // instead of host metadata while staying fast enough for CLI use. + return argon2.IDKey(passphrase, salt, 1, 64*1024, 4, 32) +} + +// deriveLegacyKey derives the pre-v2 machine-specific fallback key so older +// encrypted files can still be read and rewritten with a passphrase-derived key. +func deriveLegacyKey() ([]byte, error) { // Collect machine-specific identifiers var identifiers []byte diff --git a/internal/adapters/keyring/keyring.go b/internal/adapters/keyring/keyring.go index 39b7d96..003b8c8 100644 --- a/internal/adapters/keyring/keyring.go +++ b/internal/adapters/keyring/keyring.go @@ -2,6 +2,7 @@ package keyring import ( + "errors" "os" "github.com/nylas/cli/internal/domain" @@ -89,8 +90,11 @@ func NewSecretStore(configDir string) (ports.SecretStore, error) { // Check if file store has credentials apiKey, err := fileStore.Get(ports.KeyAPIKey) if err != nil { - // No credentials in file store either, use keyring for fresh setup - return kr, nil + if errors.Is(err, domain.ErrSecretNotFound) { + // No credentials in file store either, use keyring for fresh setup + return kr, nil + } + return nil, err } // Migrate credentials from file store to keyring diff --git a/internal/adapters/keyring/keyring_test.go b/internal/adapters/keyring/keyring_test.go index 8c45215..60a73d4 100644 --- a/internal/adapters/keyring/keyring_test.go +++ b/internal/adapters/keyring/keyring_test.go @@ -13,6 +13,20 @@ import ( "github.com/stretchr/testify/require" ) +func setFileStorePassphrase(t *testing.T) { + t.Helper() + + orig := os.Getenv("NYLAS_FILE_STORE_PASSPHRASE") + require.NoError(t, os.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "test-file-store-passphrase")) + t.Cleanup(func() { + if orig != "" { + _ = os.Setenv("NYLAS_FILE_STORE_PASSPHRASE", orig) + } else { + _ = os.Unsetenv("NYLAS_FILE_STORE_PASSPHRASE") + } + }) +} + func TestMockSecretStore(t *testing.T) { store := keyring.NewMockSecretStore() @@ -94,6 +108,7 @@ func TestMockSecretStore(t *testing.T) { func TestEncryptedFileStore(t *testing.T) { tmpDir := t.TempDir() + setFileStorePassphrase(t) store, err := keyring.NewEncryptedFileStore(tmpDir) require.NoError(t, err) @@ -168,6 +183,8 @@ func TestEncryptedFileStore(t *testing.T) { func TestNewSecretStore(t *testing.T) { tmpDir := t.TempDir() + setFileStorePassphrase(t) + t.Setenv("NYLAS_DISABLE_KEYRING", "true") store, err := keyring.NewSecretStore(tmpDir) require.NoError(t, err) @@ -176,13 +193,16 @@ func TestNewSecretStore(t *testing.T) { assert.True(t, store.IsAvailable()) name := store.Name() - assert.True(t, name == "system keyring" || name == "encrypted file", - "store name should be 'system keyring' or 'encrypted file', got: %s", name) + assert.Equal(t, "encrypted file", name) t.Logf("Platform: %s, Secret store: %s", runtime.GOOS, name) } func TestSystemKeyring(t *testing.T) { + if os.Getenv("NYLAS_RUN_SYSTEM_KEYRING_TESTS") != "true" { + t.Skip("set NYLAS_RUN_SYSTEM_KEYRING_TESTS=true to run live system keyring tests") + } + kr := keyring.NewSystemKeyring() require.NotNil(t, kr) @@ -214,6 +234,7 @@ func TestSystemKeyring(t *testing.T) { func TestCrossPlatformKeyDerivation(t *testing.T) { tmpDir := t.TempDir() + setFileStorePassphrase(t) store, err := keyring.NewEncryptedFileStore(tmpDir) require.NoError(t, err, "EncryptedFileStore should be creatable on %s", runtime.GOOS) diff --git a/internal/adapters/mcp/proxy.go b/internal/adapters/mcp/proxy.go index 0329c23..bc3caea 100644 --- a/internal/adapters/mcp/proxy.go +++ b/internal/adapters/mcp/proxy.go @@ -263,6 +263,7 @@ func (p *Proxy) forward(ctx context.Context, request []byte, parsed *rpcRequest) // readSSE reads Server-Sent Events and extracts JSON-RPC messages. func (p *Proxy) readSSE(reader io.Reader) ([]byte, error) { scanner := bufio.NewScanner(reader) + scanner.Buffer(make([]byte, 64*1024), 10*1024*1024) var responses []json.RawMessage for scanner.Scan() { diff --git a/internal/adapters/mcp/proxy_sse_test.go b/internal/adapters/mcp/proxy_sse_test.go index 7c3c188..b64ade2 100644 --- a/internal/adapters/mcp/proxy_sse_test.go +++ b/internal/adapters/mcp/proxy_sse_test.go @@ -105,6 +105,22 @@ func TestProxy_readSSE(t *testing.T) { } } +func TestProxy_readSSE_LargeFrame(t *testing.T) { + t.Parallel() + + proxy := NewProxy("test-key", "us") + largePayload := strings.Repeat("x", 70*1024) + + result, err := proxy.readSSE(strings.NewReader("data: {\"payload\":\"" + largePayload + "\"}\n\n")) + if err != nil { + t.Fatalf("readSSE failed: %v", err) + } + + if !json.Valid(result) { + t.Fatalf("expected valid JSON, got %d bytes", len(result)) + } +} + // mockGrantStore implements ports.GrantStore for testing. type mockGrantStore struct { grants []domain.GrantInfo diff --git a/internal/adapters/nylas/auth.go b/internal/adapters/nylas/auth.go index cafcc2c..5cde7c3 100644 --- a/internal/adapters/nylas/auth.go +++ b/internal/adapters/nylas/auth.go @@ -10,19 +10,26 @@ import ( ) // BuildAuthURL builds the OAuth authorization URL. -func (c *HTTPClient) BuildAuthURL(provider domain.Provider, redirectURI string) string { +func (c *HTTPClient) BuildAuthURL(provider domain.Provider, redirectURI, state, codeChallenge string) string { baseURL := fmt.Sprintf("%s/v3/connect/auth", c.baseURL) - return NewQueryBuilder(). + query := NewQueryBuilder(). Add("client_id", c.clientID). Add("redirect_uri", redirectURI). Add("response_type", "code"). Add("provider", string(provider)). Add("access_type", "offline"). - BuildURL(baseURL) + Add("state", state) + + if codeChallenge != "" { + query.Add("code_challenge", codeChallenge). + Add("code_challenge_method", "S256") + } + + return query.BuildURL(baseURL) } // ExchangeCode exchanges an authorization code for tokens. -func (c *HTTPClient) ExchangeCode(ctx context.Context, code, redirectURI string) (*domain.Grant, error) { +func (c *HTTPClient) ExchangeCode(ctx context.Context, code, redirectURI, codeVerifier string) (*domain.Grant, error) { // In Nylas v3, client_secret is the API key secret := c.clientSecret if secret == "" { @@ -35,7 +42,7 @@ func (c *HTTPClient) ExchangeCode(ctx context.Context, code, redirectURI string) "grant_type": "authorization_code", "client_id": c.clientID, "client_secret": secret, - "code_verifier": "nylas", + "code_verifier": codeVerifier, } resp, err := c.doJSONRequestNoAuth(ctx, "POST", c.baseURL+"/v3/connect/token", payload) diff --git a/internal/adapters/nylas/auth_test.go b/internal/adapters/nylas/auth_test.go index 328064a..4c07532 100644 --- a/internal/adapters/nylas/auth_test.go +++ b/internal/adapters/nylas/auth_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/nylas/cli/internal/domain" "github.com/stretchr/testify/assert" @@ -40,6 +41,9 @@ func TestHTTPClient_BuildAuthURL(t *testing.T) { "response_type=code", "provider=google", "access_type=offline", + "state=test-state", + "code_challenge=test-challenge", + "code_challenge_method=S256", }, }, { @@ -63,7 +67,7 @@ func TestHTTPClient_BuildAuthURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - url := client.BuildAuthURL(tt.provider, tt.redirectURI) + url := client.BuildAuthURL(tt.provider, tt.redirectURI, "test-state", "test-challenge") for _, want := range tt.wantInURL { assert.Contains(t, url, want) @@ -135,7 +139,7 @@ func TestHTTPClient_ExchangeCode(t *testing.T) { client := newTestClient("test-api-key", "test-client-id", "test-client-secret") client.SetBaseURL(server.URL) - grant, err := client.ExchangeCode(context.Background(), "test-code", "http://localhost/callback") + grant, err := client.ExchangeCode(context.Background(), "test-code", "http://localhost/callback", "test-verifier") if tt.wantErr { assert.Error(t, err) @@ -173,12 +177,13 @@ func TestHTTPClient_ExchangeCode_UsesAPIKeyAsSecret(t *testing.T) { client := newTestClient("my-api-key", "my-client-id", "") client.SetBaseURL(server.URL) - _, err := client.ExchangeCode(context.Background(), "test-code", "http://localhost/callback") + _, err := client.ExchangeCode(context.Background(), "test-code", "http://localhost/callback", "test-verifier") require.NoError(t, err) assert.Equal(t, "my-api-key", receivedBody["client_secret"]) assert.Equal(t, "my-client-id", receivedBody["client_id"]) assert.Equal(t, "test-code", receivedBody["code"]) + assert.Equal(t, "test-verifier", receivedBody["code_verifier"]) } func TestHTTPClient_ListGrants(t *testing.T) { @@ -441,7 +446,31 @@ func TestHTTPClient_ExchangeCode_ContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately - _, err := client.ExchangeCode(ctx, "test-code", "http://localhost/callback") + _, err := client.ExchangeCode(ctx, "test-code", "http://localhost/callback", "test-verifier") assert.Error(t, err) } + +func TestHTTPClient_ExchangeCode_DoesNotCancelBodyReads(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + flusher, ok := w.(http.Flusher) + require.True(t, ok) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + flusher.Flush() + + time.Sleep(50 * time.Millisecond) + _, _ = w.Write([]byte(`{"grant_id":"grant-123","email":"user@example.com","provider":"google"}`)) + })) + defer server.Close() + + client := newTestClient("test-api-key", "test-client-id", "") + client.SetBaseURL(server.URL) + client.requestTimeout = time.Second + + grant, err := client.ExchangeCode(context.Background(), "test-code", "http://localhost/callback", "test-verifier") + require.NoError(t, err) + require.NotNil(t, grant) + assert.Equal(t, "grant-123", grant.ID) +} diff --git a/internal/adapters/nylas/client.go b/internal/adapters/nylas/client.go index 36afb97..7727f3c 100644 --- a/internal/adapters/nylas/client.go +++ b/internal/adapters/nylas/client.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "strconv" + "sync" "time" "github.com/nylas/cli/internal/adapters/providers" @@ -193,9 +194,10 @@ func (c *HTTPClient) doRequest(ctx context.Context, req *http.Request) (*http.Re // Execute request resp, err := c.httpClient.Do(reqToUse) - cancel() // Cancel timeout context if err != nil { + cancel() + // Don't retry if the parent context is done if ctx.Err() != nil { return nil, ctx.Err() @@ -221,6 +223,11 @@ func (c *HTTPClient) doRequest(ctx context.Context, req *http.Request) (*http.Re return nil, lastErr } + resp.Body = &cancelOnCloseBody{ + ReadCloser: resp.Body, + cancel: cancel, + } + // Check if we should retry based on status code if c.shouldRetryStatus(resp.StatusCode) && attempt < c.maxRetries { lastResp = resp @@ -245,6 +252,26 @@ func (c *HTTPClient) doRequest(ctx context.Context, req *http.Request) (*http.Re return nil, lastErr } +type cancelOnCloseBody struct { + io.ReadCloser + cancel context.CancelFunc + once sync.Once +} + +func (b *cancelOnCloseBody) Read(p []byte) (int, error) { + n, err := b.ReadCloser.Read(p) + if errors.Is(err, io.EOF) { + b.once.Do(b.cancel) + } + return n, err +} + +func (b *cancelOnCloseBody) Close() error { + err := b.ReadCloser.Close() + b.once.Do(b.cancel) + return err +} + // shouldRetryStatus determines if a status code is retryable func (c *HTTPClient) shouldRetryStatus(statusCode int) bool { switch statusCode { diff --git a/internal/adapters/nylas/client_base_test.go b/internal/adapters/nylas/client_base_test.go index 031acc8..d24020e 100644 --- a/internal/adapters/nylas/client_base_test.go +++ b/internal/adapters/nylas/client_base_test.go @@ -19,7 +19,7 @@ func TestMockClientImplementsInterface(t *testing.T) { var _ interface { SetRegion(region string) SetCredentials(clientID, clientSecret, apiKey string) - BuildAuthURL(provider domain.Provider, redirectURI string) string + BuildAuthURL(provider domain.Provider, redirectURI, state, codeChallenge string) string } = nylas.NewMockClient() } @@ -33,13 +33,13 @@ func TestHTTPClient_SetRegion(t *testing.T) { t.Run("sets US region by default", func(t *testing.T) { client.SetRegion("us") - url := client.BuildAuthURL(domain.ProviderGoogle, "http://localhost") + url := client.BuildAuthURL(domain.ProviderGoogle, "http://localhost", "", "") assert.Contains(t, url, "api.us.nylas.com") }) t.Run("sets EU region", func(t *testing.T) { client.SetRegion("eu") - url := client.BuildAuthURL(domain.ProviderGoogle, "http://localhost") + url := client.BuildAuthURL(domain.ProviderGoogle, "http://localhost", "", "") assert.Contains(t, url, "api.eu.nylas.com") }) } @@ -48,7 +48,7 @@ func TestHTTPClient_SetCredentials(t *testing.T) { client := nylas.NewHTTPClient() client.SetCredentials("my-client-id", "my-secret", "my-api-key") - url := client.BuildAuthURL(domain.ProviderGoogle, "http://localhost") + url := client.BuildAuthURL(domain.ProviderGoogle, "http://localhost", "", "") assert.Contains(t, url, "client_id=my-client-id") } @@ -85,7 +85,7 @@ func TestHTTPClient_BuildAuthURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - url := client.BuildAuthURL(tt.provider, tt.redirectURI) + url := client.BuildAuthURL(tt.provider, tt.redirectURI, "", "") for _, want := range tt.wantContain { assert.Contains(t, url, want) } diff --git a/internal/adapters/nylas/client_mock_methods_test.go b/internal/adapters/nylas/client_mock_methods_test.go index b999db5..ee2a707 100644 --- a/internal/adapters/nylas/client_mock_methods_test.go +++ b/internal/adapters/nylas/client_mock_methods_test.go @@ -278,7 +278,7 @@ func TestMockClient_Grants(t *testing.T) { mock := nylas.NewMockClient() t.Run("ExchangeCode", func(t *testing.T) { - grant, err := mock.ExchangeCode(ctx, "auth-code", "http://localhost") + grant, err := mock.ExchangeCode(ctx, "auth-code", "http://localhost", "test-verifier") require.NoError(t, err) assert.Equal(t, "mock-grant-id", grant.ID) assert.True(t, mock.ExchangeCodeCalled) diff --git a/internal/adapters/nylas/demo/base.go b/internal/adapters/nylas/demo/base.go index 97d2476..9d90878 100644 --- a/internal/adapters/nylas/demo/base.go +++ b/internal/adapters/nylas/demo/base.go @@ -22,12 +22,12 @@ func (d *Client) SetRegion(region string) {} func (d *Client) SetCredentials(clientID, clientSecret, apiKey string) {} // BuildAuthURL returns a mock auth URL. -func (d *Client) BuildAuthURL(provider domain.Provider, redirectURI string) string { +func (d *Client) BuildAuthURL(provider domain.Provider, redirectURI, state, codeChallenge string) string { return "https://demo.nylas.com/auth" } // ExchangeCode returns a mock grant. -func (d *Client) ExchangeCode(ctx context.Context, code, redirectURI string) (*domain.Grant, error) { +func (d *Client) ExchangeCode(ctx context.Context, code, redirectURI, codeVerifier string) (*domain.Grant, error) { return &domain.Grant{ ID: "demo-grant-id", Email: "demo@example.com", diff --git a/internal/adapters/nylas/demo/base_test.go b/internal/adapters/nylas/demo/base_test.go index 6059e0c..e85c8c5 100644 --- a/internal/adapters/nylas/demo/base_test.go +++ b/internal/adapters/nylas/demo/base_test.go @@ -64,7 +64,7 @@ func TestClient_BuildAuthURL(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := client.BuildAuthURL(tt.provider, tt.redirectURI) + got := client.BuildAuthURL(tt.provider, tt.redirectURI, "", "") if got != tt.want { t.Errorf("BuildAuthURL() = %q, want %q", got, tt.want) } @@ -104,7 +104,7 @@ func TestClient_ExchangeCode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - grant, err := client.ExchangeCode(ctx, tt.code, tt.redirectURI) + grant, err := client.ExchangeCode(ctx, tt.code, tt.redirectURI, "test-verifier") if (err != nil) != tt.wantErr { t.Errorf("ExchangeCode() error = %v, wantErr %v", err, tt.wantErr) @@ -136,7 +136,7 @@ func TestClient_ExchangeCode_FieldValidation(t *testing.T) { client := New() ctx := context.Background() - grant, err := client.ExchangeCode(ctx, "test-code", "http://localhost/callback") + grant, err := client.ExchangeCode(ctx, "test-code", "http://localhost/callback", "test-verifier") if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -178,9 +178,9 @@ func TestClient_ExchangeCode_Consistency(t *testing.T) { ctx := context.Background() // Call multiple times to ensure consistency - grant1, err1 := client.ExchangeCode(ctx, "code1", "uri1") - grant2, err2 := client.ExchangeCode(ctx, "code2", "uri2") - grant3, err3 := client.ExchangeCode(ctx, "code3", "uri3") + grant1, err1 := client.ExchangeCode(ctx, "code1", "uri1", "verifier-1") + grant2, err2 := client.ExchangeCode(ctx, "code2", "uri2", "verifier-2") + grant3, err3 := client.ExchangeCode(ctx, "code3", "uri3", "verifier-3") if err1 != nil || err2 != nil || err3 != nil { t.Fatal("unexpected errors") @@ -205,7 +205,7 @@ func TestClient_ContextCancellation(t *testing.T) { // Demo client should still work even with cancelled context // because it doesn't make real API calls - grant, err := client.ExchangeCode(ctx, "code", "uri") + grant, err := client.ExchangeCode(ctx, "code", "uri", "test-verifier") if err != nil { t.Errorf("demo client should ignore context cancellation, got error: %v", err) } diff --git a/internal/adapters/nylas/demo_client.go b/internal/adapters/nylas/demo_client.go index 3e03711..b8b9873 100644 --- a/internal/adapters/nylas/demo_client.go +++ b/internal/adapters/nylas/demo_client.go @@ -22,12 +22,12 @@ func (d *DemoClient) SetRegion(region string) {} func (d *DemoClient) SetCredentials(clientID, clientSecret, apiKey string) {} // BuildAuthURL returns a mock auth URL. -func (d *DemoClient) BuildAuthURL(provider domain.Provider, redirectURI string) string { +func (d *DemoClient) BuildAuthURL(provider domain.Provider, redirectURI, state, codeChallenge string) string { return "https://demo.nylas.com/auth" } // ExchangeCode returns a mock grant. -func (d *DemoClient) ExchangeCode(ctx context.Context, code, redirectURI string) (*domain.Grant, error) { +func (d *DemoClient) ExchangeCode(ctx context.Context, code, redirectURI, codeVerifier string) (*domain.Grant, error) { return &domain.Grant{ ID: "demo-grant-id", Email: "demo@example.com", diff --git a/internal/adapters/nylas/mock_client.go b/internal/adapters/nylas/mock_client.go index f87be63..b06dd69 100644 --- a/internal/adapters/nylas/mock_client.go +++ b/internal/adapters/nylas/mock_client.go @@ -15,6 +15,7 @@ type MockClient struct { APIKey string // Call tracking + BuildAuthURLCalled bool ExchangeCodeCalled bool ListGrantsCalled bool GetGrantCalled bool @@ -44,6 +45,10 @@ type MockClient struct { GetAttachmentCalled bool DownloadAttachmentCalled bool LastGrantID string + LastRedirectURI string + LastAuthState string + LastCodeChallenge string + LastCodeVerifier string LastMessageID string LastThreadID string LastDraftID string @@ -51,7 +56,8 @@ type MockClient struct { LastAttachmentID string // Custom functions - ExchangeCodeFunc func(ctx context.Context, code, redirectURI string) (*domain.Grant, error) + BuildAuthURLFunc func(provider domain.Provider, redirectURI, state, codeChallenge string) string + ExchangeCodeFunc func(ctx context.Context, code, redirectURI, codeVerifier string) (*domain.Grant, error) ListGrantsFunc func(ctx context.Context) ([]domain.Grant, error) GetGrantFunc func(ctx context.Context, grantID string) (*domain.Grant, error) RevokeGrantFunc func(ctx context.Context, grantID string) error @@ -111,6 +117,13 @@ func (m *MockClient) SetCredentials(clientID, clientSecret, apiKey string) { } // BuildAuthURL returns a mock auth URL. -func (m *MockClient) BuildAuthURL(provider domain.Provider, redirectURI string) string { +func (m *MockClient) BuildAuthURL(provider domain.Provider, redirectURI, state, codeChallenge string) string { + m.BuildAuthURLCalled = true + m.LastRedirectURI = redirectURI + m.LastAuthState = state + m.LastCodeChallenge = codeChallenge + if m.BuildAuthURLFunc != nil { + return m.BuildAuthURLFunc(provider, redirectURI, state, codeChallenge) + } return "https://mock.nylas.com/auth" } diff --git a/internal/adapters/nylas/mock_grants.go b/internal/adapters/nylas/mock_grants.go index 426eaaf..325ac4e 100644 --- a/internal/adapters/nylas/mock_grants.go +++ b/internal/adapters/nylas/mock_grants.go @@ -6,10 +6,12 @@ import ( "github.com/nylas/cli/internal/domain" ) -func (m *MockClient) ExchangeCode(ctx context.Context, code, redirectURI string) (*domain.Grant, error) { +func (m *MockClient) ExchangeCode(ctx context.Context, code, redirectURI, codeVerifier string) (*domain.Grant, error) { m.ExchangeCodeCalled = true + m.LastRedirectURI = redirectURI + m.LastCodeVerifier = codeVerifier if m.ExchangeCodeFunc != nil { - return m.ExchangeCodeFunc(ctx, code, redirectURI) + return m.ExchangeCodeFunc(ctx, code, redirectURI, codeVerifier) } return &domain.Grant{ ID: "mock-grant-id", diff --git a/internal/adapters/nylas/security_test.go b/internal/adapters/nylas/security_test.go index a0e64ee..15ae620 100644 --- a/internal/adapters/nylas/security_test.go +++ b/internal/adapters/nylas/security_test.go @@ -178,7 +178,7 @@ func TestInputValidation(t *testing.T) { // These should handle empty strings gracefully // Not crash or panic - _ = client.BuildAuthURL("google", "") + _ = client.BuildAuthURL("google", "", "", "") t.Log("Empty redirect URI handled") }) diff --git a/internal/adapters/oauth/mock.go b/internal/adapters/oauth/mock.go index 2da445f..711213b 100644 --- a/internal/adapters/oauth/mock.go +++ b/internal/adapters/oauth/mock.go @@ -11,6 +11,7 @@ import ( type MockServer struct { Port int AuthCode string + ExpectedState string StartCalled bool StopCalled bool WaitForCallbackCalled bool @@ -38,8 +39,9 @@ func (m *MockServer) Stop() error { } // WaitForCallback waits for the OAuth callback. -func (m *MockServer) WaitForCallback(ctx context.Context) (string, error) { +func (m *MockServer) WaitForCallback(ctx context.Context, expectedState string) (string, error) { m.WaitForCallbackCalled = true + m.ExpectedState = expectedState if m.TimeoutAfter > 0 { select { diff --git a/internal/adapters/oauth/server.go b/internal/adapters/oauth/server.go index 752fd06..af9898d 100644 --- a/internal/adapters/oauth/server.go +++ b/internal/adapters/oauth/server.go @@ -3,6 +3,7 @@ package oauth import ( "context" + "crypto/subtle" "fmt" "net" "net/http" @@ -20,6 +21,8 @@ type CallbackServer struct { codeChan chan string errChan chan error once sync.Once + mu sync.RWMutex + state string } // NewCallbackServer creates a new callback server. @@ -33,6 +36,11 @@ func NewCallbackServer(port int) *CallbackServer { // Start starts the callback server. func (s *CallbackServer) Start() error { + s.once = sync.Once{} + s.codeChan = make(chan string, 1) + s.errChan = make(chan error, 1) + s.setExpectedState("") + mux := http.NewServeMux() mux.HandleFunc("/callback", s.handleCallback) @@ -68,7 +76,9 @@ func (s *CallbackServer) Stop() error { } // WaitForCallback waits for the OAuth callback and returns the auth code. -func (s *CallbackServer) WaitForCallback(ctx context.Context) (string, error) { +func (s *CallbackServer) WaitForCallback(ctx context.Context, expectedState string) (string, error) { + s.setExpectedState(expectedState) + select { case code := <-s.codeChan: return code, nil @@ -98,6 +108,14 @@ func (s *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) return } + if !s.matchesExpectedState(r.URL.Query().Get("state")) { + s.once.Do(func() { + s.errChan <- fmt.Errorf("%w: invalid OAuth state", domain.ErrAuthFailed) + }) + http.Error(w, "Authentication failed: invalid OAuth state", http.StatusBadRequest) + return + } + s.once.Do(func() { s.codeChan <- code }) @@ -126,3 +144,21 @@ func (s *CallbackServer) handleCallback(w http.ResponseWriter, r *http.Request) `) } + +func (s *CallbackServer) setExpectedState(state string) { + s.mu.Lock() + defer s.mu.Unlock() + s.state = state +} + +func (s *CallbackServer) matchesExpectedState(state string) bool { + s.mu.RLock() + expected := s.state + s.mu.RUnlock() + + if expected == "" || state == "" || len(expected) != len(state) { + return false + } + + return subtle.ConstantTimeCompare([]byte(state), []byte(expected)) == 1 +} diff --git a/internal/adapters/oauth/server_integration_test.go b/internal/adapters/oauth/server_integration_test.go new file mode 100644 index 0000000..a88e529 --- /dev/null +++ b/internal/adapters/oauth/server_integration_test.go @@ -0,0 +1,73 @@ +//go:build integration + +package oauth + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "testing" + "time" + + "github.com/nylas/cli/internal/domain" +) + +func TestIntegration_CallbackServer_InvalidStateFailsImmediately(t *testing.T) { + server := NewCallbackServer(0) + if err := server.Start(); err != nil { + t.Fatalf("failed to start callback server: %v", err) + } + defer func() { _ = server.Stop() }() + + port := server.listener.Addr().(*net.TCPAddr).Port + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + resultCh := make(chan error, 1) + go func() { + _, err := server.WaitForCallback(ctx, "expected-state") + resultCh <- err + }() + + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + if server.matchesExpectedState("expected-state") { + break + } + time.Sleep(5 * time.Millisecond) + } + + callbackURL := fmt.Sprintf("http://127.0.0.1:%d/callback?code=test-code-123&state=wrong-state", port) + var resp *http.Response + var err error + requestDeadline := time.Now().Add(time.Second) + for time.Now().Before(requestDeadline) { + resp, err = http.Get(callbackURL) + if err == nil { + break + } + time.Sleep(10 * time.Millisecond) + } + if err != nil { + t.Fatalf("failed to send callback request: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusBadRequest) + } + + select { + case err := <-resultCh: + if !errors.Is(err, domain.ErrAuthFailed) { + t.Fatalf("error = %v, want %v", err, domain.ErrAuthFailed) + } + if err == nil || err.Error() == domain.ErrAuthFailed.Error() { + t.Fatalf("error = %v, want invalid state details", err) + } + case <-time.After(500 * time.Millisecond): + t.Fatal("WaitForCallback did not fail after invalid state callback") + } +} diff --git a/internal/adapters/oauth/server_test.go b/internal/adapters/oauth/server_test.go index 07ca78c..3be99a9 100644 --- a/internal/adapters/oauth/server_test.go +++ b/internal/adapters/oauth/server_test.go @@ -2,6 +2,7 @@ package oauth import ( "context" + "errors" "net/http" "net/http/httptest" "testing" @@ -60,9 +61,10 @@ func TestCallbackServer_GetRedirectURI(t *testing.T) { func TestCallbackServer_handleCallback_Success(t *testing.T) { server := NewCallbackServer(8080) + server.setExpectedState("test-state-123") // Create request with auth code - req := httptest.NewRequest(http.MethodGet, "/callback?code=test-code-123", nil) + req := httptest.NewRequest(http.MethodGet, "/callback?code=test-code-123&state=test-state-123", nil) w := httptest.NewRecorder() // Handle callback @@ -148,6 +150,79 @@ func TestCallbackServer_handleCallback_MissingCode(t *testing.T) { } } +func TestCallbackServer_handleCallback_InvalidState(t *testing.T) { + server := NewCallbackServer(8080) + server.setExpectedState("expected-state") + + req := httptest.NewRequest(http.MethodGet, "/callback?code=test-code-123&state=wrong-state", nil) + w := httptest.NewRecorder() + + server.handleCallback(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Status code = %d, want %d", w.Code, http.StatusBadRequest) + } + + select { + case code := <-server.codeChan: + t.Errorf("unexpected code received: %q", code) + default: + } + + select { + case err := <-server.errChan: + if !errors.Is(err, domain.ErrAuthFailed) { + t.Fatalf("error = %v, want %v", err, domain.ErrAuthFailed) + } + if !contains(err.Error(), "invalid OAuth state") { + t.Fatalf("error = %q, want invalid state message", err.Error()) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("expected invalid state error to be sent") + } +} + +func TestCallbackServer_WaitForCallback_InvalidState(t *testing.T) { + server := NewCallbackServer(8080) + resultCh := make(chan error, 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + go func() { + _, err := server.WaitForCallback(ctx, "expected-state") + resultCh <- err + }() + + deadline := time.Now().Add(100 * time.Millisecond) + for time.Now().Before(deadline) { + if server.matchesExpectedState("expected-state") { + break + } + time.Sleep(5 * time.Millisecond) + } + + req := httptest.NewRequest(http.MethodGet, "/callback?code=test-code-123&state=wrong-state", nil) + w := httptest.NewRecorder() + server.handleCallback(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Status code = %d, want %d", w.Code, http.StatusBadRequest) + } + + select { + case err := <-resultCh: + if !errors.Is(err, domain.ErrAuthFailed) { + t.Fatalf("error = %v, want %v", err, domain.ErrAuthFailed) + } + if !contains(err.Error(), "invalid OAuth state") { + t.Fatalf("error = %q, want invalid state message", err.Error()) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("WaitForCallback did not return after invalid state") + } +} + func TestCallbackServer_WaitForCallback_Success(t *testing.T) { server := NewCallbackServer(8080) @@ -158,7 +233,7 @@ func TestCallbackServer_WaitForCallback_Success(t *testing.T) { }() ctx := context.Background() - code, err := server.WaitForCallback(ctx) + code, err := server.WaitForCallback(ctx, "expected-state") if err != nil { t.Errorf("WaitForCallback() error = %v, want nil", err) @@ -179,7 +254,7 @@ func TestCallbackServer_WaitForCallback_Error(t *testing.T) { }() ctx := context.Background() - code, err := server.WaitForCallback(ctx) + code, err := server.WaitForCallback(ctx, "expected-state") if err == nil { t.Error("WaitForCallback() error = nil, want error") @@ -199,7 +274,7 @@ func TestCallbackServer_WaitForCallback_Timeout(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) defer cancel() - code, err := server.WaitForCallback(ctx) + code, err := server.WaitForCallback(ctx, "expected-state") if err != domain.ErrAuthTimeout { t.Errorf("error = %v, want %v", err, domain.ErrAuthTimeout) @@ -232,12 +307,13 @@ func TestCallbackServer_handleCallback_OnlyOnce(t *testing.T) { server := NewCallbackServer(8080) // First callback - should succeed - req1 := httptest.NewRequest(http.MethodGet, "/callback?code=first", nil) + server.setExpectedState("test-state") + req1 := httptest.NewRequest(http.MethodGet, "/callback?code=first&state=test-state", nil) w1 := httptest.NewRecorder() server.handleCallback(w1, req1) // Second callback - should not overwrite - req2 := httptest.NewRequest(http.MethodGet, "/callback?code=second", nil) + req2 := httptest.NewRequest(http.MethodGet, "/callback?code=second&state=test-state", nil) w2 := httptest.NewRecorder() server.handleCallback(w2, req2) diff --git a/internal/adapters/webhookserver/server.go b/internal/adapters/webhookserver/server.go index 5036ba6..62924fa 100644 --- a/internal/adapters/webhookserver/server.go +++ b/internal/adapters/webhookserver/server.go @@ -203,12 +203,26 @@ func (s *Server) handleWebhook(w http.ResponseWriter, r *http.Request) { } defer func() { _ = r.Body.Close() }() + signature := r.Header.Get("X-Nylas-Signature") + if s.config.WebhookSecret != "" { + if signature == "" { + http.Error(w, "Missing webhook signature", http.StatusUnauthorized) + return + } + if !s.verifySignature(body, signature) { + http.Error(w, "Invalid webhook signature", http.StatusForbidden) + return + } + } + // Parse webhook event event := &ports.WebhookEvent{ Timestamp: time.Now(), ReceivedAt: time.Now(), Headers: make(map[string]string), RawBody: body, + Signature: signature, + Verified: s.config.WebhookSecret != "", } // Copy relevant headers @@ -218,14 +232,6 @@ func (s *Server) handleWebhook(w http.ResponseWriter, r *http.Request) { } } - // Get signature header - event.Signature = r.Header.Get("X-Nylas-Signature") - - // Verify signature if secret is configured - if s.config.WebhookSecret != "" && event.Signature != "" { - event.Verified = s.verifySignature(body, event.Signature) - } - // Parse JSON body var payload map[string]any if err := json.Unmarshal(body, &payload); err == nil { diff --git a/internal/adapters/webhookserver/server_test.go b/internal/adapters/webhookserver/server_test.go index bef3b94..dad3bda 100644 --- a/internal/adapters/webhookserver/server_test.go +++ b/internal/adapters/webhookserver/server_test.go @@ -3,6 +3,9 @@ package webhookserver import ( "bytes" "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "encoding/json" "net/http" "net/http/httptest" @@ -125,6 +128,56 @@ func TestServer_HandleWebhook(t *testing.T) { }) } +func TestServer_HandleWebhook_RejectsInvalidSignatures(t *testing.T) { + secret := "test-webhook-secret" + server := NewServer(ports.WebhookServerConfig{ + Port: 3008, + Path: "/webhook", + WebhookSecret: secret, + }) + handler := http.HandlerFunc(server.handleWebhook) + payload := []byte(`{"type":"message.created","id":"event-123"}`) + + t.Run("missing signature", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(payload)) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusUnauthorized, rec.Code) + assert.Equal(t, 0, server.GetStats().EventsReceived) + select { + case event := <-server.Events(): + t.Fatalf("unexpected event received: %+v", event) + default: + } + }) + + t.Run("invalid signature", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(payload)) + req.Header.Set("X-Nylas-Signature", "invalid-signature") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusForbidden, rec.Code) + assert.Equal(t, 0, server.GetStats().EventsReceived) + }) + + t.Run("valid signature", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/webhook", bytes.NewReader(payload)) + req.Header.Set("X-Nylas-Signature", signWebhookPayload(secret, payload)) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, 1, server.GetStats().EventsReceived) + event := <-server.Events() + assert.True(t, event.Verified) + }) +} + func TestServer_HandleHealth(t *testing.T) { server := NewServer(ports.WebhookServerConfig{ Port: 3003, @@ -259,3 +312,9 @@ func TestServer_GetPublicURL(t *testing.T) { url := server.GetPublicURL() assert.Equal(t, "http://localhost:8080/webhook", url) } + +func signWebhookPayload(secret string, payload []byte) string { + mac := hmac.New(sha256.New, []byte(secret)) + _, _ = mac.Write(payload) + return hex.EncodeToString(mac.Sum(nil)) +} diff --git a/internal/app/auth/service.go b/internal/app/auth/service.go index 5fa37a1..05fd915 100644 --- a/internal/app/auth/service.go +++ b/internal/app/auth/service.go @@ -3,6 +3,9 @@ package auth import ( "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/ports" @@ -42,20 +45,39 @@ func (s *Service) Login(ctx context.Context, provider domain.Provider) (*domain. } defer func() { _ = s.server.Stop() }() + state, err := generateOAuthState() + if err != nil { + return nil, err + } + codeVerifier, codeChallenge, err := generatePKCEPair() + if err != nil { + return nil, err + } + + redirectURI := s.server.GetRedirectURI() + callbackCh := make(chan oauthCallbackResult, 1) + waitCtx, waitCancel := context.WithCancel(ctx) + defer waitCancel() + + go func() { + code, waitErr := s.server.WaitForCallback(waitCtx, state) + callbackCh <- oauthCallbackResult{code: code, err: waitErr} + }() + // Build auth URL and open browser - authURL := s.client.BuildAuthURL(provider, s.server.GetRedirectURI()) + authURL := s.client.BuildAuthURL(provider, redirectURI, state, codeChallenge) if err := s.browser.Open(authURL); err != nil { return nil, err } // Wait for callback - code, err := s.server.WaitForCallback(ctx) - if err != nil { - return nil, err + callback := <-callbackCh + if callback.err != nil { + return nil, callback.err } // Exchange code for tokens - grant, err := s.client.ExchangeCode(ctx, code, s.server.GetRedirectURI()) + grant, err := s.client.ExchangeCode(ctx, callback.code, redirectURI, codeVerifier) if err != nil { return nil, err } @@ -70,20 +92,11 @@ func (s *Service) Login(ctx context.Context, provider domain.Provider) (*domain. return nil, err } - // Set as default if no default exists or this is the first grant - isFirstGrant := false + // Set as default if no default exists. if _, err := s.grantStore.GetDefaultGrant(); err == domain.ErrNoDefaultGrant { _ = s.grantStore.SetDefaultGrant(grant.ID) - isFirstGrant = true - } - - // Update config with grant (only update default if this is the first grant) - cfg, _ := s.config.Load() - cfg.Grants = append(cfg.Grants, grantInfo) - if isFirstGrant { - cfg.DefaultGrant = grant.ID } - _ = s.config.Save(cfg) + s.syncConfigWithGrantStore() return grant, nil } @@ -130,6 +143,26 @@ func (s *Service) LogoutGrant(ctx context.Context, grantID string) error { // Auto-switch to another grant if we deleted the default if isDefault { s.autoSwitchDefault() + } else { + s.syncConfigWithGrantStore() + } + + return nil +} + +// RemoveLocalGrant removes a grant from local storage without revoking it on Nylas. +func (s *Service) RemoveLocalGrant(grantID string) error { + defaultID, _ := s.grantStore.GetDefaultGrant() + isDefault := grantID == defaultID + + if err := s.grantStore.DeleteGrant(grantID); err != nil { + return err + } + + if isDefault { + s.autoSwitchDefault() + } else { + s.syncConfigWithGrantStore() } return nil @@ -138,11 +171,70 @@ func (s *Service) LogoutGrant(ctx context.Context, grantID string) error { // autoSwitchDefault sets a new default grant from remaining grants. func (s *Service) autoSwitchDefault() { grants, err := s.grantStore.ListGrants() - if err != nil || len(grants) == 0 { + if err != nil { + return + } + if len(grants) == 0 { // No remaining grants - clear the default _ = s.grantStore.ClearGrants() + s.syncConfigWithGrantStore() return } // Set the first remaining grant as default - _ = s.grantStore.SetDefaultGrant(grants[0].ID) + if err := s.grantStore.SetDefaultGrant(grants[0].ID); err != nil { + return + } + s.syncConfigWithGrantStore() +} + +func (s *Service) syncConfigWithGrantStore() { + grants, err := s.grantStore.ListGrants() + if err != nil { + return + } + + defaultGrant, err := s.grantStore.GetDefaultGrant() + if err == domain.ErrNoDefaultGrant { + defaultGrant = "" + } else if err != nil { + return + } + + cfg, err := s.config.Load() + if err != nil { + return + } + + cfg.Grants = append([]domain.GrantInfo(nil), grants...) + cfg.DefaultGrant = defaultGrant + _ = s.config.Save(cfg) +} + +func generateOAuthState() (string, error) { + return generateOAuthToken(32) +} + +func generatePKCEPair() (string, string, error) { + verifier, err := generateOAuthToken(32) + if err != nil { + return "", "", err + } + + hash := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(hash[:]) + + return verifier, challenge, nil +} + +func generateOAuthToken(size int) (string, error) { + token := make([]byte, size) + if _, err := rand.Read(token); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(token), nil +} + +type oauthCallbackResult struct { + code string + err error } diff --git a/internal/app/auth/service_test.go b/internal/app/auth/service_test.go index 04ab502..5f98786 100644 --- a/internal/app/auth/service_test.go +++ b/internal/app/auth/service_test.go @@ -2,6 +2,8 @@ package auth import ( "context" + "crypto/sha256" + "encoding/base64" "errors" "testing" @@ -107,6 +109,8 @@ func TestService_autoSwitchDefault(t *testing.T) { t.Run("clears default when no grants remain", func(t *testing.T) { grantStore := newMockGrantStore() configStore := newMockConfigStore() + configStore.config.Grants = []domain.GrantInfo{{ID: "old-deleted-grant", Email: "user@example.com"}} + configStore.config.DefaultGrant = "old-deleted-grant" // Set up a default grant that doesn't exist in grants list grantStore.defaultGrant = "old-deleted-grant" @@ -122,6 +126,8 @@ func TestService_autoSwitchDefault(t *testing.T) { // Default should be cleared _, err := grantStore.GetDefaultGrant() assert.ErrorIs(t, err, domain.ErrNoDefaultGrant) + assert.Empty(t, configStore.config.DefaultGrant) + assert.Empty(t, configStore.config.Grants) }) t.Run("sets first remaining grant as default", func(t *testing.T) { @@ -144,6 +150,8 @@ func TestService_autoSwitchDefault(t *testing.T) { defaultID, err := grantStore.GetDefaultGrant() require.NoError(t, err) assert.Contains(t, []string{"grant-1", "grant-2"}, defaultID) + assert.Equal(t, defaultID, configStore.config.DefaultGrant) + assert.Len(t, configStore.config.Grants, 2) }) } @@ -238,12 +246,13 @@ func TestService_FirstGrantBecomesDefault(t *testing.T) { // mockOAuthServer implements ports.OAuthServer for testing type mockOAuthServer struct { - redirectURI string - code string - startErr error - waitErr error - startCalled bool - stopCalled bool + redirectURI string + code string + expectedState string + startErr error + waitErr error + startCalled bool + stopCalled bool } func (m *mockOAuthServer) Start() error { @@ -260,7 +269,8 @@ func (m *mockOAuthServer) GetRedirectURI() string { return m.redirectURI } -func (m *mockOAuthServer) WaitForCallback(ctx context.Context) (string, error) { +func (m *mockOAuthServer) WaitForCallback(ctx context.Context, expectedState string) (string, error) { + m.expectedState = expectedState if m.waitErr != nil { return "", m.waitErr } @@ -298,7 +308,16 @@ func TestNewService(t *testing.T) { func TestService_Login(t *testing.T) { t.Run("successful login sets grant as default", func(t *testing.T) { client := nylas.NewMockClient() - client.ExchangeCodeFunc = func(ctx context.Context, code, redirectURI string) (*domain.Grant, error) { + var capturedState string + var capturedChallenge string + client.BuildAuthURLFunc = func(provider domain.Provider, redirectURI, state, codeChallenge string) string { + capturedState = state + capturedChallenge = codeChallenge + return "https://mock.nylas.com/auth?state=" + state + } + client.ExchangeCodeFunc = func(ctx context.Context, code, redirectURI, codeVerifier string) (*domain.Grant, error) { + assert.Equal(t, "auth-code-123", code) + assert.Equal(t, capturedChallenge, pkceChallenge(codeVerifier)) return &domain.Grant{ ID: "grant-123", Email: "user@example.com", @@ -327,8 +346,11 @@ func TestService_Login(t *testing.T) { assert.True(t, server.startCalled) assert.True(t, server.stopCalled) - // Verify browser was used (MockClient returns default auth URL) - assert.NotEmpty(t, browser.openedURL) + // Verify browser and callback state/PKCE values were wired through. + assert.Equal(t, "https://mock.nylas.com/auth?state="+capturedState, browser.openedURL) + assert.Equal(t, capturedState, server.expectedState) + assert.NotEmpty(t, capturedState) + assert.NotEmpty(t, capturedChallenge) // Verify grant was saved savedGrant, err := grantStore.GetGrant("grant-123") @@ -339,6 +361,8 @@ func TestService_Login(t *testing.T) { defaultID, err := grantStore.GetDefaultGrant() require.NoError(t, err) assert.Equal(t, "grant-123", defaultID) + assert.Equal(t, "grant-123", configStore.config.DefaultGrant) + assert.Len(t, configStore.config.Grants, 1) }) t.Run("server start failure returns error", func(t *testing.T) { @@ -392,7 +416,7 @@ func TestService_Login(t *testing.T) { t.Run("code exchange failure returns error", func(t *testing.T) { client := nylas.NewMockClient() - client.ExchangeCodeFunc = func(ctx context.Context, code, redirectURI string) (*domain.Grant, error) { + client.ExchangeCodeFunc = func(ctx context.Context, code, redirectURI, codeVerifier string) (*domain.Grant, error) { return nil, errors.New("code exchange failed") } @@ -415,8 +439,11 @@ func TestService_Logout(t *testing.T) { grantStore := newMockGrantStore() grantStore.grants["grant-123"] = domain.GrantInfo{ID: "grant-123", Email: "user@example.com"} grantStore.defaultGrant = "grant-123" + configStore := newMockConfigStore() + configStore.config.Grants = []domain.GrantInfo{{ID: "grant-123", Email: "user@example.com"}} + configStore.config.DefaultGrant = "grant-123" - svc := NewService(client, grantStore, newMockConfigStore(), &mockOAuthServer{}, &mockBrowser{}) + svc := NewService(client, grantStore, configStore, &mockOAuthServer{}, &mockBrowser{}) err := svc.Logout(context.Background()) @@ -432,6 +459,8 @@ func TestService_Logout(t *testing.T) { // Verify default was cleared _, err = grantStore.GetDefaultGrant() assert.ErrorIs(t, err, domain.ErrNoDefaultGrant) + assert.Empty(t, configStore.config.DefaultGrant) + assert.Empty(t, configStore.config.Grants) }) t.Run("no default grant returns error", func(t *testing.T) { @@ -488,8 +517,14 @@ func TestService_Logout(t *testing.T) { grantStore.grants["grant-1"] = domain.GrantInfo{ID: "grant-1", Email: "user1@example.com"} grantStore.grants["grant-2"] = domain.GrantInfo{ID: "grant-2", Email: "user2@example.com"} grantStore.defaultGrant = "grant-1" + configStore := newMockConfigStore() + configStore.config.Grants = []domain.GrantInfo{ + {ID: "grant-1", Email: "user1@example.com"}, + {ID: "grant-2", Email: "user2@example.com"}, + } + configStore.config.DefaultGrant = "grant-1" - svc := NewService(client, grantStore, newMockConfigStore(), &mockOAuthServer{}, &mockBrowser{}) + svc := NewService(client, grantStore, configStore, &mockOAuthServer{}, &mockBrowser{}) err := svc.Logout(context.Background()) @@ -503,6 +538,9 @@ func TestService_Logout(t *testing.T) { defaultID, err := grantStore.GetDefaultGrant() require.NoError(t, err) assert.Equal(t, "grant-2", defaultID) + assert.Equal(t, "grant-2", configStore.config.DefaultGrant) + assert.Len(t, configStore.config.Grants, 1) + assert.Equal(t, "grant-2", configStore.config.Grants[0].ID) }) } @@ -513,8 +551,14 @@ func TestService_LogoutGrant(t *testing.T) { grantStore.grants["grant-1"] = domain.GrantInfo{ID: "grant-1", Email: "user1@example.com"} grantStore.grants["grant-2"] = domain.GrantInfo{ID: "grant-2", Email: "user2@example.com"} grantStore.defaultGrant = "grant-1" + configStore := newMockConfigStore() + configStore.config.Grants = []domain.GrantInfo{ + {ID: "grant-1", Email: "user1@example.com"}, + {ID: "grant-2", Email: "user2@example.com"}, + } + configStore.config.DefaultGrant = "grant-1" - svc := NewService(client, grantStore, newMockConfigStore(), &mockOAuthServer{}, &mockBrowser{}) + svc := NewService(client, grantStore, configStore, &mockOAuthServer{}, &mockBrowser{}) err := svc.LogoutGrant(context.Background(), "grant-2") @@ -528,6 +572,9 @@ func TestService_LogoutGrant(t *testing.T) { defaultID, err := grantStore.GetDefaultGrant() require.NoError(t, err) assert.Equal(t, "grant-1", defaultID) + assert.Equal(t, "grant-1", configStore.config.DefaultGrant) + assert.Len(t, configStore.config.Grants, 1) + assert.Equal(t, "grant-1", configStore.config.Grants[0].ID) }) t.Run("logging out default grant switches to another", func(t *testing.T) { @@ -536,8 +583,14 @@ func TestService_LogoutGrant(t *testing.T) { grantStore.grants["grant-1"] = domain.GrantInfo{ID: "grant-1", Email: "user1@example.com"} grantStore.grants["grant-2"] = domain.GrantInfo{ID: "grant-2", Email: "user2@example.com"} grantStore.defaultGrant = "grant-1" + configStore := newMockConfigStore() + configStore.config.Grants = []domain.GrantInfo{ + {ID: "grant-1", Email: "user1@example.com"}, + {ID: "grant-2", Email: "user2@example.com"}, + } + configStore.config.DefaultGrant = "grant-1" - svc := NewService(client, grantStore, newMockConfigStore(), &mockOAuthServer{}, &mockBrowser{}) + svc := NewService(client, grantStore, configStore, &mockOAuthServer{}, &mockBrowser{}) err := svc.LogoutGrant(context.Background(), "grant-1") @@ -551,6 +604,9 @@ func TestService_LogoutGrant(t *testing.T) { defaultID, err := grantStore.GetDefaultGrant() require.NoError(t, err) assert.Equal(t, "grant-2", defaultID) + assert.Equal(t, "grant-2", configStore.config.DefaultGrant) + assert.Len(t, configStore.config.Grants, 1) + assert.Equal(t, "grant-2", configStore.config.Grants[0].ID) }) t.Run("grant not found on revoke is ignored", func(t *testing.T) { @@ -572,3 +628,33 @@ func TestService_LogoutGrant(t *testing.T) { assert.ErrorIs(t, err, domain.ErrGrantNotFound) }) } + +func TestService_RemoveLocalGrant(t *testing.T) { + grantStore := newMockGrantStore() + grantStore.grants["grant-1"] = domain.GrantInfo{ID: "grant-1", Email: "user1@example.com"} + grantStore.grants["grant-2"] = domain.GrantInfo{ID: "grant-2", Email: "user2@example.com"} + grantStore.defaultGrant = "grant-1" + configStore := newMockConfigStore() + configStore.config.Grants = []domain.GrantInfo{ + {ID: "grant-1", Email: "user1@example.com"}, + {ID: "grant-2", Email: "user2@example.com"}, + } + configStore.config.DefaultGrant = "grant-1" + + svc := NewService(nylas.NewMockClient(), grantStore, configStore, &mockOAuthServer{}, &mockBrowser{}) + + err := svc.RemoveLocalGrant("grant-1") + require.NoError(t, err) + + defaultID, err := grantStore.GetDefaultGrant() + require.NoError(t, err) + assert.Equal(t, "grant-2", defaultID) + assert.Equal(t, "grant-2", configStore.config.DefaultGrant) + assert.Len(t, configStore.config.Grants, 1) + assert.Equal(t, "grant-2", configStore.config.Grants[0].ID) +} + +func pkceChallenge(verifier string) string { + hash := sha256.Sum256([]byte(verifier)) + return base64.RawURLEncoding.EncodeToString(hash[:]) +} diff --git a/internal/cli/auth/providers.go b/internal/cli/auth/providers.go index 49088c8..b74b718 100644 --- a/internal/cli/auth/providers.go +++ b/internal/cli/auth/providers.go @@ -3,10 +3,13 @@ package auth import ( "encoding/json" "fmt" + "io" + "strings" "github.com/spf13/cobra" "github.com/nylas/cli/internal/cli/common" + "github.com/nylas/cli/internal/domain" ) func newProvidersCmd() *cobra.Command { @@ -50,26 +53,7 @@ This command shows connectors configured for your Nylas application.`, return enc.Encode(connectors) } - // Display as table - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Available Authentication Providers:") - _, _ = fmt.Fprintln(cmd.OutOrStdout()) - - if len(connectors) == 0 { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "No providers configured.") - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "\nTo add a provider, use: nylas admin connectors create") - return nil - } - - for _, connector := range connectors { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), " %s\n", connector.Provider) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), " Name: %s\n", connector.Name) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), " ID: %s\n", connector.ID) - if len(connector.Scopes) > 0 { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), " Scopes: %d configured\n", len(connector.Scopes)) - } - _, _ = fmt.Fprintln(cmd.OutOrStdout()) - } - + renderProviders(cmd.OutOrStdout(), connectors) return nil }, } @@ -78,3 +62,67 @@ This command shows connectors configured for your Nylas application.`, return cmd } + +func renderProviders(w io.Writer, connectors []domain.Connector) { + _, _ = fmt.Fprintln(w, "Available Authentication Providers:") + _, _ = fmt.Fprintln(w) + + if len(connectors) == 0 { + _, _ = fmt.Fprintln(w, "No providers configured.") + _, _ = fmt.Fprintln(w, "\nTo add a provider, use: nylas admin connectors create") + return + } + + for _, connector := range connectors { + title := connector.Name + if title == "" { + title = providerDisplayName(connector.Provider) + } + + _, _ = fmt.Fprintf(w, " %s\n", title) + _, _ = fmt.Fprintf(w, " Provider: %s\n", connector.Provider) + if connector.Name != "" && connector.Name != title { + _, _ = fmt.Fprintf(w, " Name: %s\n", connector.Name) + } + if connector.ID != "" { + _, _ = fmt.Fprintf(w, " ID: %s\n", connector.ID) + } + if len(connector.Scopes) > 0 { + _, _ = fmt.Fprintf(w, " Scopes: %d configured\n", len(connector.Scopes)) + } + _, _ = fmt.Fprintln(w) + } +} + +func providerDisplayName(provider string) string { + switch provider { + case "google": + return "Google" + case "microsoft": + return "Microsoft" + case "imap": + return "IMAP" + case "icloud": + return "iCloud" + case "ews": + return "EWS" + case "inbox": + return "Inbox" + case "virtual-calendar": + return "Virtual Calendar" + default: + return titleProviderName(provider) + } +} + +func titleProviderName(provider string) string { + normalized := strings.ReplaceAll(strings.ReplaceAll(provider, "-", " "), "_", " ") + parts := strings.Fields(normalized) + for i, part := range parts { + if part == "" { + continue + } + parts[i] = strings.ToUpper(part[:1]) + part[1:] + } + return strings.Join(parts, " ") +} diff --git a/internal/cli/auth/providers_test.go b/internal/cli/auth/providers_test.go index aa8caf9..320734f 100644 --- a/internal/cli/auth/providers_test.go +++ b/internal/cli/auth/providers_test.go @@ -8,6 +8,7 @@ import ( "github.com/nylas/cli/internal/adapters/config" "github.com/nylas/cli/internal/adapters/keyring" + "github.com/nylas/cli/internal/domain" "github.com/nylas/cli/internal/ports" ) @@ -82,3 +83,98 @@ func TestProvidersCmd(t *testing.T) { }) } } + +func TestRenderProviders(t *testing.T) { + tests := []struct { + name string + connectors []domain.Connector + wantContain []string + wantAbsent []string + }{ + { + name: "omits empty connector fields", + connectors: []domain.Connector{ + { + Provider: "google", + Settings: &domain.ConnectorSettings{ClientID: "client-id"}, + }, + }, + wantContain: []string{ + "Available Authentication Providers:", + " Google", + " Provider: google", + }, + wantAbsent: []string{ + "Name: ", + "ID: ", + }, + }, + { + name: "prints populated connector metadata", + connectors: []domain.Connector{ + { + ID: "conn-imap-1", + Name: "Custom IMAP", + Provider: "imap", + Scopes: []string{"mail.read_only", "mail.send"}, + }, + }, + wantContain: []string{ + " Custom IMAP", + " Provider: imap", + " ID: conn-imap-1", + " Scopes: 2 configured", + }, + }, + { + name: "shows empty state", + connectors: nil, + wantContain: []string{ + "No providers configured.", + "nylas admin connectors create", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + renderProviders(&buf, tt.connectors) + + output := buf.String() + for _, want := range tt.wantContain { + if !strings.Contains(output, want) { + t.Fatalf("renderProviders() output = %q, want to contain %q", output, want) + } + } + for _, unwanted := range tt.wantAbsent { + if strings.Contains(output, unwanted) { + t.Fatalf("renderProviders() output = %q, should not contain %q", output, unwanted) + } + } + }) + } +} + +func TestProviderDisplayName(t *testing.T) { + tests := []struct { + provider string + want string + }{ + {provider: "google", want: "Google"}, + {provider: "microsoft", want: "Microsoft"}, + {provider: "imap", want: "IMAP"}, + {provider: "icloud", want: "iCloud"}, + {provider: "ews", want: "EWS"}, + {provider: "virtual-calendar", want: "Virtual Calendar"}, + {provider: "custom-provider", want: "Custom Provider"}, + } + + for _, tt := range tests { + t.Run(tt.provider, func(t *testing.T) { + if got := providerDisplayName(tt.provider); got != tt.want { + t.Fatalf("providerDisplayName(%q) = %q, want %q", tt.provider, got, tt.want) + } + }) + } +} diff --git a/internal/cli/auth/remove.go b/internal/cli/auth/remove.go index 86ab8f6..b02ef59 100644 --- a/internal/cli/auth/remove.go +++ b/internal/cli/auth/remove.go @@ -26,13 +26,18 @@ on the Nylas server.`, return err } + authSvc, _, err := createAuthService() + if err != nil { + return err + } + // Check if grant exists locally if _, err := grantStore.GetGrant(grantID); err != nil { return err } // Remove from local store only - if err := grantStore.DeleteGrant(grantID); err != nil { + if err := authSvc.RemoveLocalGrant(grantID); err != nil { return err } diff --git a/internal/cli/common/client.go b/internal/cli/common/client.go index 2feffa5..9334415 100644 --- a/internal/cli/common/client.go +++ b/internal/cli/common/client.go @@ -2,6 +2,7 @@ package common import ( "context" + "errors" "fmt" "os" "strings" @@ -49,14 +50,25 @@ func GetNylasClient() (ports.NylasClient, error) { // If API key not in env, try keyring/file store if apiKey == "" { - secretStore, err := keyring.NewSecretStore(config.DefaultConfigDir()) - if err == nil { - apiKey, _ = secretStore.Get(ports.KeyAPIKey) - if clientID == "" { - clientID, _ = secretStore.Get(ports.KeyClientID) + secretStore, err := openSecretStore() + if err != nil { + return nil, err + } + + apiKey, err = getStoredSecret(secretStore, ports.KeyAPIKey) + if err != nil { + return nil, err + } + if clientID == "" { + clientID, err = getStoredSecret(secretStore, ports.KeyClientID) + if err != nil { + return nil, err } - if clientSecret == "" { - clientSecret, _ = secretStore.Get(ports.KeyClientSecret) + } + if clientSecret == "" { + clientSecret, err = getStoredSecret(secretStore, ports.KeyClientSecret) + if err != nil { + return nil, err } } } @@ -118,9 +130,13 @@ func GetAPIKey() (string, error) { // If not in env, try keyring/file store if apiKey == "" { - secretStore, err := keyring.NewSecretStore(config.DefaultConfigDir()) - if err == nil { - apiKey, _ = secretStore.Get(ports.KeyAPIKey) + secretStore, err := openSecretStore() + if err != nil { + return "", err + } + apiKey, err = getStoredSecret(secretStore, ports.KeyAPIKey) + if err != nil { + return "", err } } @@ -140,69 +156,66 @@ func GetAPIKey() (string, error) { // It checks in this order: // 1. Command line argument (if provided) - supports email lookup if arg contains "@" // 2. Environment variable (NYLAS_GRANT_ID) -// 3. Config file (default_grant) -// 4. Stored default grant (from keyring/file) +// 3. Stored default grant (from keyring/file) +// 4. Config file (default_grant) func GetGrantID(args []string) (string, error) { - secretStore, err := keyring.NewSecretStore(config.DefaultConfigDir()) - if err != nil { - // Fall back to env var and config only if keyring unavailable - if grantID := os.Getenv("NYLAS_GRANT_ID"); grantID != "" { - return grantID, nil - } + // If provided as argument + if len(args) > 0 && args[0] != "" { + identifier := args[0] - // Try config file - configStore := config.NewDefaultFileStore() - cfg, err := configStore.Load() - if err == nil && cfg.DefaultGrant != "" { - return cfg.DefaultGrant, nil + // Direct grant IDs should not depend on local secret-store health. + if !containsAt(identifier) { + return identifier, nil } + } - return "", fmt.Errorf("couldn't access secret store and NYLAS_GRANT_ID not set: %w", err) + // Check environment variable + if grantID := os.Getenv("NYLAS_GRANT_ID"); grantID != "" { + return grantID, nil + } + + secretStore, err := openSecretStore() + if err != nil { + return "", err } grantStore := keyring.NewGrantStore(secretStore) - // If provided as argument + // Email arguments require a local grant lookup. if len(args) > 0 && args[0] != "" { identifier := args[0] - - // If it looks like an email, try to find by email - if containsAt(identifier) { - grant, err := grantStore.GetGrantByEmail(identifier) - if err != nil { + grant, err := grantStore.GetGrantByEmail(identifier) + if err != nil { + if errors.Is(err, domain.ErrGrantNotFound) { return "", fmt.Errorf("no grant found for email: %s", identifier) } - return grant.ID, nil + return "", wrapSecretStoreError(err) } - - // Otherwise treat as grant ID - return identifier, nil + return grant.ID, nil } - // Check environment variable - if grantID := os.Getenv("NYLAS_GRANT_ID"); grantID != "" { + // Try to get default grant from keyring first - it is the authoritative source. + grantID, err := grantStore.GetDefaultGrant() + switch { + case err == nil: return grantID, nil + case !errors.Is(err, domain.ErrNoDefaultGrant): + return "", wrapSecretStoreError(err) } - // Check config file + // Fall back to config for backward compatibility with older setups. configStore := config.NewDefaultFileStore() - cfg, err := configStore.Load() - if err == nil && cfg.DefaultGrant != "" { + cfg, cfgErr := configStore.Load() + if cfgErr == nil && cfg.DefaultGrant != "" { return cfg.DefaultGrant, nil } - // Try to get default grant from keyring - grantID, err := grantStore.GetDefaultGrant() - if err != nil { - return "", NewUserErrorWithSuggestions( - "No grant ID provided", - "Check available grants with: nylas auth list", - "Set default grant with: nylas config set default_grant ", - "Use environment variable: export NYLAS_GRANT_ID=", - "Or specify as argument: nylas [command] ", - ) - } - - return grantID, nil + return "", NewUserErrorWithSuggestions( + "No grant ID provided", + "Check available grants with: nylas auth list", + "Set default grant with: nylas config set default_grant ", + "Use environment variable: export NYLAS_GRANT_ID=", + "Or specify as argument: nylas [command] ", + ) } // containsAt checks if a string contains "@" (for email detection). @@ -210,6 +223,33 @@ func containsAt(s string) bool { return strings.ContainsRune(s, '@') } +func openSecretStore() (ports.SecretStore, error) { + secretStore, err := keyring.NewSecretStore(config.DefaultConfigDir()) + if err != nil { + return nil, wrapSecretStoreError(err) + } + return secretStore, nil +} + +func getStoredSecret(secretStore ports.SecretStore, key string) (string, error) { + value, err := secretStore.Get(key) + switch { + case err == nil: + return value, nil + case errors.Is(err, domain.ErrSecretNotFound): + return "", nil + default: + return "", wrapSecretStoreError(err) + } +} + +func wrapSecretStoreError(err error) error { + if err == nil || errors.Is(err, domain.ErrSecretStoreFailed) { + return err + } + return fmt.Errorf("%w: %v", domain.ErrSecretStoreFailed, err) +} + // WithClient is a generic helper that handles client setup, context creation, and grant ID resolution. // This reduces boilerplate in commands by handling all the common setup in one place. // diff --git a/internal/cli/common/client_test.go b/internal/cli/common/client_test.go index eaa0817..ffbc92d 100644 --- a/internal/cli/common/client_test.go +++ b/internal/cli/common/client_test.go @@ -4,8 +4,13 @@ package common import ( "os" + "path/filepath" "testing" + "github.com/nylas/cli/internal/adapters/config" + "github.com/nylas/cli/internal/adapters/keyring" + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -149,59 +154,29 @@ func TestGetAPIKey_NoAPIKey(t *testing.T) { } func TestGetGrantID_WithArgument(t *testing.T) { - // Save original env var - origGrantID := os.Getenv("NYLAS_GRANT_ID") - origDisableKeyring := os.Getenv("NYLAS_DISABLE_KEYRING") - - // Restore after test - defer func() { - setEnvOrUnset("NYLAS_GRANT_ID", origGrantID) - setEnvOrUnset("NYLAS_DISABLE_KEYRING", origDisableKeyring) - }() - - _ = os.Setenv("NYLAS_DISABLE_KEYRING", "true") - _ = os.Unsetenv("NYLAS_GRANT_ID") - - // Test with direct grant ID argument (not email) - args := []string{"grant-id-12345"} + configDir := seedLockedFileStore(t, func(store *keyring.EncryptedFileStore) { + require.NoError(t, store.Set("placeholder", "value")) + }) + require.DirExists(t, configDir) + t.Setenv("NYLAS_GRANT_ID", "") - grantID, err := GetGrantID(args) + grantID, err := GetGrantID([]string{"grant-id-12345"}) - // This may fail if keyring is not accessible, which is expected in test env - if err != nil { - // If keyring not accessible, check the error message - assert.Contains(t, err.Error(), "secret store") - } else { - assert.Equal(t, "grant-id-12345", grantID) - } + require.NoError(t, err) + assert.Equal(t, "grant-id-12345", grantID) } func TestGetGrantID_WithEnvVar(t *testing.T) { - // Save original env var - origGrantID := os.Getenv("NYLAS_GRANT_ID") - origDisableKeyring := os.Getenv("NYLAS_DISABLE_KEYRING") - - // Restore after test - defer func() { - setEnvOrUnset("NYLAS_GRANT_ID", origGrantID) - setEnvOrUnset("NYLAS_DISABLE_KEYRING", origDisableKeyring) - }() - testGrantID := "env-grant-id-67890" - _ = os.Setenv("NYLAS_GRANT_ID", testGrantID) - _ = os.Setenv("NYLAS_DISABLE_KEYRING", "true") + seedLockedFileStore(t, func(store *keyring.EncryptedFileStore) { + require.NoError(t, store.Set("placeholder", "value")) + }) + t.Setenv("NYLAS_GRANT_ID", testGrantID) - // Test with empty args - should fall back to env var grantID, err := GetGrantID([]string{}) - // This may fail if keyring is not accessible - if err != nil { - // If keyring fails but we have env var, we should still get the grant ID - // The function tries keyring first, so we need to check behavior - t.Logf("Error (expected in test env): %v", err) - } else { - assert.Equal(t, testGrantID, grantID) - } + require.NoError(t, err) + assert.Equal(t, testGrantID, grantID) } func TestGetGrantID_EmptyArgs(t *testing.T) { @@ -232,29 +207,79 @@ func TestGetGrantID_EmptyArgs(t *testing.T) { } func TestGetGrantID_EmptyStringArg(t *testing.T) { - // Save original env vars + testGrantID := "env-grant-fallback" + seedLockedFileStore(t, func(store *keyring.EncryptedFileStore) { + require.NoError(t, store.Set("placeholder", "value")) + }) + t.Setenv("NYLAS_GRANT_ID", testGrantID) + + grantID, err := GetGrantID([]string{""}) + + require.NoError(t, err) + assert.Equal(t, testGrantID, grantID) +} + +func TestGetGrantID_PrefersStoredDefaultOverStaleConfig(t *testing.T) { origGrantID := os.Getenv("NYLAS_GRANT_ID") origDisableKeyring := os.Getenv("NYLAS_DISABLE_KEYRING") + origFileStorePassphrase := os.Getenv("NYLAS_FILE_STORE_PASSPHRASE") + origXDGConfigHome := os.Getenv("XDG_CONFIG_HOME") + origHome := os.Getenv("HOME") - // Restore after test defer func() { setEnvOrUnset("NYLAS_GRANT_ID", origGrantID) setEnvOrUnset("NYLAS_DISABLE_KEYRING", origDisableKeyring) + setEnvOrUnset("NYLAS_FILE_STORE_PASSPHRASE", origFileStorePassphrase) + setEnvOrUnset("XDG_CONFIG_HOME", origXDGConfigHome) + setEnvOrUnset("HOME", origHome) }() - testGrantID := "env-grant-fallback" - _ = os.Setenv("NYLAS_GRANT_ID", testGrantID) + tempDir := t.TempDir() + configHome := filepath.Join(tempDir, "xdg") + _ = os.Setenv("XDG_CONFIG_HOME", configHome) + _ = os.Setenv("HOME", tempDir) _ = os.Setenv("NYLAS_DISABLE_KEYRING", "true") + _ = os.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "test-file-store-passphrase") + _ = os.Unsetenv("NYLAS_GRANT_ID") - // Test with empty string arg - should fall back to env var - grantID, err := GetGrantID([]string{""}) + configStore := config.NewFileStore(filepath.Join(configHome, "nylas", "config.yaml")) + require.NoError(t, configStore.Save(&domain.Config{ + Region: "us", + DefaultGrant: "stale-config-grant", + Grants: []domain.GrantInfo{{ID: "stale-config-grant", Email: "stale@example.com"}}, + })) - // May fail due to keyring access - if err != nil { - t.Logf("Error (expected in test env): %v", err) - } else { - assert.Equal(t, testGrantID, grantID) - } + secretStore, err := keyring.NewEncryptedFileStore(filepath.Join(configHome, "nylas")) + require.NoError(t, err) + grantStore := keyring.NewGrantStore(secretStore) + require.NoError(t, grantStore.SaveGrant(domain.GrantInfo{ID: "stored-default", Email: "active@example.com"})) + require.NoError(t, grantStore.SetDefaultGrant("stored-default")) + + grantID, err := GetGrantID(nil) + require.NoError(t, err) + assert.Equal(t, "stored-default", grantID) +} + +func TestGetGrantID_DoesNotFallbackToConfigWhenStoreLocked(t *testing.T) { + configDir := seedLockedFileStore(t, func(store *keyring.EncryptedFileStore) { + grantStore := keyring.NewGrantStore(store) + require.NoError(t, grantStore.SaveGrant(domain.GrantInfo{ID: "stored-default", Email: "active@example.com"})) + require.NoError(t, grantStore.SetDefaultGrant("stored-default")) + }) + + configStore := config.NewFileStore(filepath.Join(configDir, "config.yaml")) + require.NoError(t, configStore.Save(&domain.Config{ + Region: "us", + DefaultGrant: "stale-config-grant", + Grants: []domain.GrantInfo{{ID: "stale-config-grant", Email: "stale@example.com"}}, + })) + + grantID, err := GetGrantID(nil) + + require.Error(t, err) + assert.Empty(t, grantID) + assert.ErrorIs(t, err, domain.ErrSecretStoreFailed) + assert.Contains(t, err.Error(), "NYLAS_FILE_STORE_PASSPHRASE") } // setEnvOrUnset sets an environment variable if value is non-empty, otherwise unsets it @@ -266,6 +291,29 @@ func setEnvOrUnset(key, value string) { } } +func seedLockedFileStore(t *testing.T, seed func(store *keyring.EncryptedFileStore)) string { + t.Helper() + + tempDir := t.TempDir() + configHome := filepath.Join(tempDir, "xdg") + configDir := filepath.Join(configHome, "nylas") + t.Setenv("XDG_CONFIG_HOME", configHome) + t.Setenv("HOME", tempDir) + t.Setenv("NYLAS_DISABLE_KEYRING", "true") + t.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "test-file-store-passphrase") + t.Setenv("NYLAS_API_KEY", "") + t.Setenv("NYLAS_GRANT_ID", "") + + store, err := keyring.NewEncryptedFileStore(configDir) + require.NoError(t, err) + if seed != nil { + seed(store) + } + + t.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "") + return configDir +} + func TestGetNylasClient_EnvVarPriority(t *testing.T) { // This test verifies that environment variables take priority over keyring // Save original env vars @@ -309,6 +357,33 @@ func TestGetAPIKey_EnvVarPriority(t *testing.T) { assert.Equal(t, "priority-test-key", apiKey) } +func TestGetAPIKey_ReportsLockedFileStore(t *testing.T) { + seedLockedFileStore(t, func(store *keyring.EncryptedFileStore) { + require.NoError(t, store.Set(ports.KeyAPIKey, "stored-api-key")) + }) + + apiKey, err := GetAPIKey() + + require.Error(t, err) + assert.Empty(t, apiKey) + assert.ErrorIs(t, err, domain.ErrSecretStoreFailed) + assert.Contains(t, err.Error(), "NYLAS_FILE_STORE_PASSPHRASE") +} + +func TestGetNylasClient_ReportsLockedFileStore(t *testing.T) { + seedLockedFileStore(t, func(store *keyring.EncryptedFileStore) { + require.NoError(t, store.Set(ports.KeyAPIKey, "stored-api-key")) + require.NoError(t, store.Set(ports.KeyClientID, "stored-client-id")) + }) + + client, err := GetNylasClient() + + require.Error(t, err) + assert.Nil(t, client) + assert.ErrorIs(t, err, domain.ErrSecretStoreFailed) + assert.Contains(t, err.Error(), "NYLAS_FILE_STORE_PASSPHRASE") +} + func TestContainsAt_UnicodeSupport(t *testing.T) { tests := []struct { name string diff --git a/internal/cli/common/errors.go b/internal/cli/common/errors.go index 37f946a..f226150 100644 --- a/internal/cli/common/errors.go +++ b/internal/cli/common/errors.go @@ -92,6 +92,17 @@ func WrapError(err error) *CLIError { Code: ErrCodeNotConfigured, } + case errors.Is(err, domain.ErrSecretStoreFailed) && strings.Contains(err.Error(), "NYLAS_FILE_STORE_PASSPHRASE"): + return &CLIError{ + Err: err, + Message: "Failed to access encrypted file secret store", + Suggestions: []string{ + "Set NYLAS_FILE_STORE_PASSPHRASE before using the file-based secret store", + "Unset NYLAS_DISABLE_KEYRING to use the system keyring instead", + }, + Code: ErrCodePermissionDenied, + } + case errors.Is(err, domain.ErrSecretStoreFailed): return &CLIError{ Err: err, diff --git a/internal/cli/common/errors_test.go b/internal/cli/common/errors_test.go index f52f6ca..bee6365 100644 --- a/internal/cli/common/errors_test.go +++ b/internal/cli/common/errors_test.go @@ -4,6 +4,7 @@ package common import ( "errors" + "fmt" "strings" "testing" @@ -110,6 +111,22 @@ func TestWrapError_DomainErrors_Extended(t *testing.T) { } } +func TestWrapError_SecretStorePassphraseRequirement(t *testing.T) { + err := fmt.Errorf("%w: %s must be set to unlock the encrypted file store", domain.ErrSecretStoreFailed, "NYLAS_FILE_STORE_PASSPHRASE") + + result := WrapError(err) + + require.NotNil(t, result) + assert.Equal(t, "Failed to access encrypted file secret store", result.Message) + assert.Equal(t, ErrCodePermissionDenied, result.Code) + assert.Empty(t, result.Suggestion) + assert.Equal(t, []string{ + "Set NYLAS_FILE_STORE_PASSPHRASE before using the file-based secret store", + "Unset NYLAS_DISABLE_KEYRING to use the system keyring instead", + }, result.Suggestions) + assert.True(t, errors.Is(result, domain.ErrSecretStoreFailed)) +} + // TestWrapError_HTTPStatusPatterns tests HTTP status code patterns. func TestWrapError_HTTPStatusPatterns(t *testing.T) { tests := []struct { @@ -242,6 +259,23 @@ func TestFormatError_WithCodeAndSuggestion(t *testing.T) { assert.Contains(t, result, "• Try this fix") } +func TestFormatError_WithMultipleSuggestions(t *testing.T) { + cliErr := &CLIError{ + Message: "Secret store locked", + Code: ErrCodePermissionDenied, + Suggestions: []string{ + "Set NYLAS_FILE_STORE_PASSPHRASE", + "Unset NYLAS_DISABLE_KEYRING", + }, + } + + result := FormatError(cliErr) + + assert.Contains(t, result, "Suggestions:") + assert.Contains(t, result, "• Set NYLAS_FILE_STORE_PASSPHRASE") + assert.Contains(t, result, "• Unset NYLAS_DISABLE_KEYRING") +} + // TestErrorCodeConstants tests that all error codes are unique. func TestErrorCodeConstants(t *testing.T) { codes := map[string]bool{ diff --git a/internal/cli/doctor.go b/internal/cli/doctor.go index d38f4c7..c5b921b 100644 --- a/internal/cli/doctor.go +++ b/internal/cli/doctor.go @@ -226,9 +226,11 @@ func checkSecretStore() CheckResult { // Check if keyring is disabled via environment keyringDisabled := os.Getenv("NYLAS_DISABLE_KEYRING") == "true" - // First check if system keyring is available - kr := keyring.NewSystemKeyring() - keyringAvailable := kr.IsAvailable() + keyringAvailable := false + if !keyringDisabled { + kr := keyring.NewSystemKeyring() + keyringAvailable = kr.IsAvailable() + } secretStore, err := keyring.NewSecretStore(config.DefaultConfigDir()) if err != nil { @@ -257,7 +259,7 @@ func checkSecretStore() CheckResult { Name: "Secret Store", Status: CheckStatusWarning, Message: storeName, - Detail: "NYLAS_DISABLE_KEYRING is set. Unset to use system keyring.", + Detail: "NYLAS_DISABLE_KEYRING is set. Set NYLAS_FILE_STORE_PASSPHRASE for the fallback store, or unset NYLAS_DISABLE_KEYRING to use the system keyring.", } } @@ -266,7 +268,7 @@ func checkSecretStore() CheckResult { Name: "Secret Store", Status: CheckStatusWarning, Message: storeName, - Detail: "System keyring unavailable. Using encrypted file fallback.", + Detail: "System keyring unavailable. The encrypted file fallback requires NYLAS_FILE_STORE_PASSPHRASE.", } } diff --git a/internal/cli/doctor_test.go b/internal/cli/doctor_test.go new file mode 100644 index 0000000..579dca9 --- /dev/null +++ b/internal/cli/doctor_test.go @@ -0,0 +1,31 @@ +package cli + +import ( + "path/filepath" + "strings" + "testing" +) + +func TestCheckSecretStore_WarnsWhenFileStoreIsForced(t *testing.T) { + tempDir := t.TempDir() + + t.Setenv("XDG_CONFIG_HOME", filepath.Join(tempDir, "xdg")) + t.Setenv("HOME", tempDir) + t.Setenv("NYLAS_DISABLE_KEYRING", "true") + t.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "doctor-test-passphrase") + + result := checkSecretStore() + + if result.Status != CheckStatusWarning { + t.Fatalf("Status = %v, want %v", result.Status, CheckStatusWarning) + } + if result.Message != "encrypted file" { + t.Fatalf("Message = %q, want %q", result.Message, "encrypted file") + } + if !strings.Contains(result.Detail, "NYLAS_FILE_STORE_PASSPHRASE") { + t.Fatalf("Detail %q does not mention NYLAS_FILE_STORE_PASSPHRASE", result.Detail) + } + if !strings.Contains(result.Detail, "unset NYLAS_DISABLE_KEYRING") { + t.Fatalf("Detail %q does not mention unsetting NYLAS_DISABLE_KEYRING", result.Detail) + } +} diff --git a/internal/cli/integration/local_regressions_test.go b/internal/cli/integration/local_regressions_test.go new file mode 100644 index 0000000..951700c --- /dev/null +++ b/internal/cli/integration/local_regressions_test.go @@ -0,0 +1,518 @@ +//go:build integration + +package integration + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "net" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/nylas/cli/internal/adapters/config" + "github.com/nylas/cli/internal/adapters/keyring" + "github.com/nylas/cli/internal/cli/common" + "github.com/nylas/cli/internal/domain" + "github.com/nylas/cli/internal/ports" +) + +func TestIntegration_GetGrantID_PrefersStoredDefaultOverConfig(t *testing.T) { + origGrantID := os.Getenv("NYLAS_GRANT_ID") + origDisableKeyring := os.Getenv("NYLAS_DISABLE_KEYRING") + origFileStorePassphrase := os.Getenv("NYLAS_FILE_STORE_PASSPHRASE") + origXDGConfigHome := os.Getenv("XDG_CONFIG_HOME") + origHome := os.Getenv("HOME") + + defer func() { + setEnvOrUnset("NYLAS_GRANT_ID", origGrantID) + setEnvOrUnset("NYLAS_DISABLE_KEYRING", origDisableKeyring) + setEnvOrUnset("NYLAS_FILE_STORE_PASSPHRASE", origFileStorePassphrase) + setEnvOrUnset("XDG_CONFIG_HOME", origXDGConfigHome) + setEnvOrUnset("HOME", origHome) + }() + + tempDir := t.TempDir() + configHome := filepath.Join(tempDir, "xdg") + _ = os.Setenv("XDG_CONFIG_HOME", configHome) + _ = os.Setenv("HOME", tempDir) + _ = os.Setenv("NYLAS_DISABLE_KEYRING", "true") + _ = os.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "integration-test-file-store-passphrase") + _ = os.Setenv("NYLAS_GRANT_ID", "") + + configPath := filepath.Join(configHome, "nylas", "config.yaml") + configStore := config.NewFileStore(configPath) + if err := configStore.Save(&domain.Config{ + Region: "us", + DefaultGrant: "stale-config-grant", + Grants: []domain.GrantInfo{{ID: "stale-config-grant", Email: "stale@example.com"}}, + }); err != nil { + t.Fatalf("failed to save config: %v", err) + } + + secretStore, err := keyring.NewEncryptedFileStore(filepath.Dir(configPath)) + if err != nil { + t.Fatalf("failed to create secret store: %v", err) + } + grantStore := keyring.NewGrantStore(secretStore) + if err := grantStore.SaveGrant(domain.GrantInfo{ID: "stored-default", Email: "active@example.com"}); err != nil { + t.Fatalf("failed to save grant: %v", err) + } + if err := grantStore.SetDefaultGrant("stored-default"); err != nil { + t.Fatalf("failed to set default grant: %v", err) + } + + grantID, err := common.GetGrantID(nil) + if err != nil { + t.Fatalf("GetGrantID failed: %v", err) + } + if grantID != "stored-default" { + t.Fatalf("GetGrantID returned %q, want %q", grantID, "stored-default") + } +} + +func TestIntegration_GetGrantID_DoesNotFallbackToConfigWhenStoreLocked(t *testing.T) { + origGrantID := os.Getenv("NYLAS_GRANT_ID") + origDisableKeyring := os.Getenv("NYLAS_DISABLE_KEYRING") + origFileStorePassphrase := os.Getenv("NYLAS_FILE_STORE_PASSPHRASE") + origXDGConfigHome := os.Getenv("XDG_CONFIG_HOME") + origHome := os.Getenv("HOME") + + defer func() { + setEnvOrUnset("NYLAS_GRANT_ID", origGrantID) + setEnvOrUnset("NYLAS_DISABLE_KEYRING", origDisableKeyring) + setEnvOrUnset("NYLAS_FILE_STORE_PASSPHRASE", origFileStorePassphrase) + setEnvOrUnset("XDG_CONFIG_HOME", origXDGConfigHome) + setEnvOrUnset("HOME", origHome) + }() + + tempDir := t.TempDir() + configHome := filepath.Join(tempDir, "xdg") + configDir := filepath.Join(configHome, "nylas") + _ = os.Setenv("XDG_CONFIG_HOME", configHome) + _ = os.Setenv("HOME", tempDir) + _ = os.Setenv("NYLAS_DISABLE_KEYRING", "true") + _ = os.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "integration-test-file-store-passphrase") + _ = os.Setenv("NYLAS_GRANT_ID", "") + + configStore := config.NewFileStore(filepath.Join(configDir, "config.yaml")) + if err := configStore.Save(&domain.Config{ + Region: "us", + DefaultGrant: "stale-config-grant", + Grants: []domain.GrantInfo{{ID: "stale-config-grant", Email: "stale@example.com"}}, + }); err != nil { + t.Fatalf("failed to save config: %v", err) + } + + secretStore, err := keyring.NewEncryptedFileStore(configDir) + if err != nil { + t.Fatalf("failed to create secret store: %v", err) + } + grantStore := keyring.NewGrantStore(secretStore) + if err := grantStore.SaveGrant(domain.GrantInfo{ID: "stored-default", Email: "active@example.com"}); err != nil { + t.Fatalf("failed to save grant: %v", err) + } + if err := grantStore.SetDefaultGrant("stored-default"); err != nil { + t.Fatalf("failed to set default grant: %v", err) + } + + _ = os.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "") + + grantID, err := common.GetGrantID(nil) + if err == nil { + t.Fatalf("expected GetGrantID to fail, got %q", grantID) + } + if grantID != "" { + t.Fatalf("grantID = %q, want empty string", grantID) + } + if !strings.Contains(err.Error(), "NYLAS_FILE_STORE_PASSPHRASE") { + t.Fatalf("error %q does not mention NYLAS_FILE_STORE_PASSPHRASE", err.Error()) + } +} + +func TestCLI_AuthRemove_UpdatesDefaultGrantAndConfig(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + origFileStorePassphrase := os.Getenv("NYLAS_FILE_STORE_PASSPHRASE") + defer setEnvOrUnset("NYLAS_FILE_STORE_PASSPHRASE", origFileStorePassphrase) + _ = os.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "integration-test-file-store-passphrase") + + tempDir := t.TempDir() + configHome := filepath.Join(tempDir, "xdg") + configPath := filepath.Join(configHome, "nylas", "config.yaml") + configStore := config.NewFileStore(configPath) + if err := configStore.Save(&domain.Config{ + Region: "us", + DefaultGrant: "grant-1", + Grants: []domain.GrantInfo{ + {ID: "grant-1", Email: "user1@example.com"}, + {ID: "grant-2", Email: "user2@example.com"}, + }, + }); err != nil { + t.Fatalf("failed to save config: %v", err) + } + + secretStore, err := keyring.NewEncryptedFileStore(filepath.Dir(configPath)) + if err != nil { + t.Fatalf("failed to create secret store: %v", err) + } + grantStore := keyring.NewGrantStore(secretStore) + if err := grantStore.SaveGrant(domain.GrantInfo{ID: "grant-1", Email: "user1@example.com"}); err != nil { + t.Fatalf("failed to save first grant: %v", err) + } + if err := grantStore.SaveGrant(domain.GrantInfo{ID: "grant-2", Email: "user2@example.com"}); err != nil { + t.Fatalf("failed to save second grant: %v", err) + } + if err := grantStore.SetDefaultGrant("grant-1"); err != nil { + t.Fatalf("failed to set default grant: %v", err) + } + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, map[string]string{ + "XDG_CONFIG_HOME": configHome, + "HOME": tempDir, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }, "auth", "remove", "grant-1") + if err != nil { + t.Fatalf("auth remove failed: %v\nstderr: %s", err, stderr) + } + if stdout == "" { + t.Fatal("expected auth remove output") + } + + defaultGrant, err := grantStore.GetDefaultGrant() + if err != nil { + t.Fatalf("failed to read default grant: %v", err) + } + if defaultGrant != "grant-2" { + t.Fatalf("default grant = %q, want %q", defaultGrant, "grant-2") + } + + grants, err := grantStore.ListGrants() + if err != nil { + t.Fatalf("failed to list grants: %v", err) + } + if len(grants) != 1 || grants[0].ID != "grant-2" { + t.Fatalf("unexpected grants after remove: %+v", grants) + } + + cfg, err := configStore.Load() + if err != nil { + t.Fatalf("failed to reload config: %v", err) + } + if cfg.DefaultGrant != "grant-2" { + t.Fatalf("config default grant = %q, want %q", cfg.DefaultGrant, "grant-2") + } + if len(cfg.Grants) != 1 || cfg.Grants[0].ID != "grant-2" { + t.Fatalf("unexpected config grants after remove: %+v", cfg.Grants) + } +} + +func TestCLI_AuthList_RequiresFileStorePassphrase(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + origFileStorePassphrase := os.Getenv("NYLAS_FILE_STORE_PASSPHRASE") + defer setEnvOrUnset("NYLAS_FILE_STORE_PASSPHRASE", origFileStorePassphrase) + _ = os.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "integration-test-file-store-passphrase") + + tempDir := t.TempDir() + configHome := filepath.Join(tempDir, "xdg") + configDir := filepath.Join(configHome, "nylas") + + secretStore, err := keyring.NewEncryptedFileStore(configDir) + if err != nil { + t.Fatalf("failed to create secret store: %v", err) + } + grantStore := keyring.NewGrantStore(secretStore) + if err := grantStore.SaveGrant(domain.GrantInfo{ + ID: "grant-locked", + Email: "locked@example.com", + Provider: domain.ProviderGoogle, + }); err != nil { + t.Fatalf("failed to save grant: %v", err) + } + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, map[string]string{ + "XDG_CONFIG_HOME": configHome, + "HOME": tempDir, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + "NYLAS_FILE_STORE_PASSPHRASE": "", + }, "auth", "list") + if err == nil { + t.Fatalf("expected auth list to fail without passphrase\nstdout: %s\nstderr: %s", stdout, stderr) + } + if !strings.Contains(stderr, "NYLAS_FILE_STORE_PASSPHRASE") { + t.Fatalf("stderr %q does not mention NYLAS_FILE_STORE_PASSPHRASE", stderr) + } +} + +func TestCLI_AuthProviders_RequiresFileStorePassphrase(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + origFileStorePassphrase := os.Getenv("NYLAS_FILE_STORE_PASSPHRASE") + defer setEnvOrUnset("NYLAS_FILE_STORE_PASSPHRASE", origFileStorePassphrase) + _ = os.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "integration-test-file-store-passphrase") + + tempDir := t.TempDir() + configHome := filepath.Join(tempDir, "xdg") + configDir := filepath.Join(configHome, "nylas") + + secretStore, err := keyring.NewEncryptedFileStore(configDir) + if err != nil { + t.Fatalf("failed to create secret store: %v", err) + } + if err := secretStore.Set(ports.KeyAPIKey, "stored-api-key"); err != nil { + t.Fatalf("failed to save api key: %v", err) + } + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, map[string]string{ + "XDG_CONFIG_HOME": configHome, + "HOME": tempDir, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + "NYLAS_FILE_STORE_PASSPHRASE": "", + }, "auth", "providers") + if err == nil { + t.Fatalf("expected auth providers to fail without passphrase\nstdout: %s\nstderr: %s", stdout, stderr) + } + if !strings.Contains(stderr, "NYLAS_FILE_STORE_PASSPHRASE") { + t.Fatalf("stderr %q does not mention NYLAS_FILE_STORE_PASSPHRASE", stderr) + } +} + +func TestCLI_AuthProviders_HidesEmptyConnectorFields(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v3/connectors" { + http.NotFound(w, r) + return + } + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"data":[ + {"name":"","provider":"google","settings":{"client_id":"google-client-id"}}, + {"id":"conn-imap-1","name":"Custom IMAP","provider":"imap","scopes":["mail.read_only","mail.send"]} + ]}`)) + })) + defer server.Close() + + tempDir := t.TempDir() + configHome := filepath.Join(tempDir, "xdg") + configPath := filepath.Join(configHome, "nylas", "config.yaml") + configStore := config.NewFileStore(configPath) + if err := configStore.Save(&domain.Config{ + Region: "us", + API: &domain.APIConfig{BaseURL: server.URL}, + }); err != nil { + t.Fatalf("failed to save config: %v", err) + } + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, map[string]string{ + "XDG_CONFIG_HOME": configHome, + "HOME": tempDir, + "NYLAS_API_KEY": "test-api-key", + "NYLAS_CLIENT_ID": "", + "NYLAS_CLIENT_SECRET": "", + "NYLAS_GRANT_ID": "", + "NYLAS_DISABLE_KEYRING": "true", + "NYLAS_FILE_STORE_PASSPHRASE": "integration-test-file-store-passphrase", + }, "auth", "providers") + if err != nil { + t.Fatalf("auth providers failed: %v\nstdout: %s\nstderr: %s", err, stdout, stderr) + } + + for _, want := range []string{ + "Available Authentication Providers:", + " Google", + " Provider: google", + " Custom IMAP", + " ID: conn-imap-1", + " Scopes: 2 configured", + } { + if !strings.Contains(stdout, want) { + t.Fatalf("stdout %q does not contain %q", stdout, want) + } + } + + for _, unwanted := range []string{ + "Name: \n", + "ID: \n", + } { + if strings.Contains(stdout, unwanted) { + t.Fatalf("stdout %q unexpectedly contains blank field %q", stdout, unwanted) + } + } +} + +func TestCLI_MCPServe_RequiresFileStorePassphrase(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + origFileStorePassphrase := os.Getenv("NYLAS_FILE_STORE_PASSPHRASE") + defer setEnvOrUnset("NYLAS_FILE_STORE_PASSPHRASE", origFileStorePassphrase) + _ = os.Setenv("NYLAS_FILE_STORE_PASSPHRASE", "integration-test-file-store-passphrase") + + tempDir := t.TempDir() + configHome := filepath.Join(tempDir, "xdg") + configDir := filepath.Join(configHome, "nylas") + + secretStore, err := keyring.NewEncryptedFileStore(configDir) + if err != nil { + t.Fatalf("failed to create secret store: %v", err) + } + if err := secretStore.Set(ports.KeyAPIKey, "stored-api-key"); err != nil { + t.Fatalf("failed to save api key: %v", err) + } + + stdout, stderr, err := runCLIWithOverrides(30*time.Second, map[string]string{ + "XDG_CONFIG_HOME": configHome, + "HOME": tempDir, + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + "NYLAS_FILE_STORE_PASSPHRASE": "", + }, "mcp", "serve") + if err == nil { + t.Fatalf("expected mcp serve to fail without passphrase\nstdout: %s\nstderr: %s", stdout, stderr) + } + if !strings.Contains(stderr, "NYLAS_FILE_STORE_PASSPHRASE") { + t.Fatalf("stderr %q does not mention NYLAS_FILE_STORE_PASSPHRASE", stderr) + } +} + +func TestCLI_WebhookServer_RejectsUnsignedRequests(t *testing.T) { + if testBinary == "" { + t.Skip("CLI binary not found") + } + + port := freeTCPPort(t) + secret := "test-webhook-secret" + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, testBinary, "webhook", "server", "--quiet", "--port", strconv.Itoa(port), "--secret", secret) + cmd.Env = cliTestEnv(map[string]string{ + "NYLAS_API_KEY": "", + "NYLAS_GRANT_ID": "", + }) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Start(); err != nil { + t.Fatalf("failed to start webhook server: %v", err) + } + t.Cleanup(func() { + if cmd.Process != nil { + _ = cmd.Process.Signal(os.Interrupt) + } + _ = cmd.Wait() + }) + + baseURL := "http://127.0.0.1:" + strconv.Itoa(port) + waitForServer(t, baseURL+"/health") + + payload := []byte(`{"type":"message.created","id":"event-123"}`) + + resp, err := http.Post(baseURL+"/webhook", "application/json", bytes.NewReader(payload)) + if err != nil { + t.Fatalf("missing-signature request failed: %v", err) + } + _ = resp.Body.Close() + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("missing signature status = %d, want %d", resp.StatusCode, http.StatusUnauthorized) + } + + req, err := http.NewRequest(http.MethodPost, baseURL+"/webhook", bytes.NewReader(payload)) + if err != nil { + t.Fatalf("failed to create invalid-signature request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Nylas-Signature", "invalid-signature") + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("invalid-signature request failed: %v", err) + } + _ = resp.Body.Close() + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("invalid signature status = %d, want %d", resp.StatusCode, http.StatusForbidden) + } + + req, err = http.NewRequest(http.MethodPost, baseURL+"/webhook", bytes.NewReader(payload)) + if err != nil { + t.Fatalf("failed to create valid-signature request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Nylas-Signature", signTestWebhookPayload(secret, payload)) + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("valid-signature request failed: %v", err) + } + _ = resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("valid signature status = %d, want %d\nstdout: %s\nstderr: %s", resp.StatusCode, http.StatusOK, stdout.String(), stderr.String()) + } +} + +func freeTCPPort(t *testing.T) int { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to allocate test port: %v", err) + } + defer func() { _ = listener.Close() }() + + return listener.Addr().(*net.TCPAddr).Port +} + +func waitForServer(t *testing.T, url string) { + t.Helper() + + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + resp, err := http.Get(url) + if err == nil { + _ = resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return + } + } + time.Sleep(100 * time.Millisecond) + } + + t.Fatalf("server %s did not become healthy in time", url) +} + +func signTestWebhookPayload(secret string, payload []byte) string { + mac := hmac.New(sha256.New, []byte(secret)) + _, _ = mac.Write(payload) + return hex.EncodeToString(mac.Sum(nil)) +} + +func setEnvOrUnset(key, value string) { + if value != "" { + _ = os.Setenv(key, value) + } else { + _ = os.Unsetenv(key) + } +} diff --git a/internal/cli/integration/test.go b/internal/cli/integration/test.go index d2494ce..75fd329 100644 --- a/internal/cli/integration/test.go +++ b/internal/cli/integration/test.go @@ -16,6 +16,7 @@ // - NYLAS_TEST_RATE_LIMIT_RPS: API rate limit (requests/sec, default: 2.0) // - NYLAS_TEST_RATE_LIMIT_BURST: API rate limit burst capacity (default: 5) // - NYLAS_INBOUND_GRANT_ID: Grant ID for inbound inbox tests (skips inbound tests if not set) +// - NYLAS_FILE_STORE_PASSPHRASE: Passphrase for the encrypted file secret-store fallback // // Parallel Testing: // @@ -133,7 +134,11 @@ func init() { } for _, c := range candidates { if _, err := os.Stat(c); err == nil { - testBinary = c + if abs, err := filepath.Abs(c); err == nil { + testBinary = abs + } else { + testBinary = c + } break } } @@ -204,32 +209,21 @@ func runCLIWithTimeout(timeout time.Duration, args ...string) (string, string, e cmd.Stdout = &stdout cmd.Stderr = &stderr - // Build environment with all necessary variables - env := []string{ - "NYLAS_API_KEY=" + testAPIKey, - "NYLAS_GRANT_ID=" + testGrantID, - "NYLAS_DISABLE_KEYRING=true", // Disable keyring during tests to avoid macOS prompts - } + cmd.Env = cliTestEnv(nil) - // Pass through AI provider credentials if set - if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { - env = append(env, "ANTHROPIC_API_KEY="+apiKey) - } - if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { - env = append(env, "OPENAI_API_KEY="+apiKey) - } - if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { - env = append(env, "GROQ_API_KEY="+apiKey) - } - if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" { - env = append(env, "OPENROUTER_API_KEY="+apiKey) - } - if ollamaHost := os.Getenv("OLLAMA_HOST"); ollamaHost != "" { - env = append(env, "OLLAMA_HOST="+ollamaHost) - } + err := cmd.Run() + return stdout.String(), stderr.String(), err +} - // Set environment for the CLI - cmd.Env = append(os.Environ(), env...) +func runCLIWithOverrides(timeout time.Duration, envOverrides map[string]string, args ...string) (string, string, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, testBinary, args...) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + cmd.Env = cliTestEnv(envOverrides) err := cmd.Run() return stdout.String(), stderr.String(), err @@ -255,33 +249,63 @@ func runCLIWithInput(input string, args ...string) (string, string, error) { cmd.Stdin = strings.NewReader(input) // Build environment with all necessary variables - env := []string{ - "NYLAS_API_KEY=" + testAPIKey, - "NYLAS_GRANT_ID=" + testGrantID, - "NYLAS_DISABLE_KEYRING=true", // Disable keyring during tests to avoid macOS prompts - } + cmd.Env = cliTestEnv(nil) - // Pass through AI provider credentials if set - if apiKey := os.Getenv("ANTHROPIC_API_KEY"); apiKey != "" { - env = append(env, "ANTHROPIC_API_KEY="+apiKey) - } - if apiKey := os.Getenv("OPENAI_API_KEY"); apiKey != "" { - env = append(env, "OPENAI_API_KEY="+apiKey) + err := cmd.Run() + return stdout.String(), stderr.String(), err +} + +func cliTestEnv(overrides map[string]string) []string { + env := os.Environ() + for key := range overrides { + env = removeEnvKey(env, key) } - if apiKey := os.Getenv("GROQ_API_KEY"); apiKey != "" { - env = append(env, "GROQ_API_KEY="+apiKey) + + defaults := map[string]string{ + "NYLAS_API_KEY": testAPIKey, + "NYLAS_GRANT_ID": testGrantID, + "NYLAS_DISABLE_KEYRING": "true", + "NYLAS_FILE_STORE_PASSPHRASE": "integration-test-file-store-passphrase", } - if apiKey := os.Getenv("OPENROUTER_API_KEY"); apiKey != "" { - env = append(env, "OPENROUTER_API_KEY="+apiKey) + + for key, value := range defaults { + if _, overridden := overrides[key]; overridden { + continue + } + env = append(env, key+"="+value) + } + + for _, key := range []string{ + "ANTHROPIC_API_KEY", + "OPENAI_API_KEY", + "GROQ_API_KEY", + "OPENROUTER_API_KEY", + "OLLAMA_HOST", + } { + if _, overridden := overrides[key]; overridden { + continue + } + if value := os.Getenv(key); value != "" { + env = append(env, key+"="+value) + } } - if ollamaHost := os.Getenv("OLLAMA_HOST"); ollamaHost != "" { - env = append(env, "OLLAMA_HOST="+ollamaHost) + + for key, value := range overrides { + env = append(env, key+"="+value) } - cmd.Env = append(os.Environ(), env...) + return env +} - err := cmd.Run() - return stdout.String(), stderr.String(), err +func removeEnvKey(env []string, key string) []string { + prefix := key + "=" + filtered := env[:0] + for _, entry := range env { + if !strings.HasPrefix(entry, prefix) { + filtered = append(filtered, entry) + } + } + return filtered } // runCLIWithInputAndRateLimit executes a CLI command with stdin input and rate limiting. diff --git a/internal/ports/auth.go b/internal/ports/auth.go index 517099b..98dae4d 100644 --- a/internal/ports/auth.go +++ b/internal/ports/auth.go @@ -9,10 +9,10 @@ import ( // AuthClient defines the interface for authentication and grant operations. type AuthClient interface { // BuildAuthURL builds an OAuth authorization URL for a provider. - BuildAuthURL(provider domain.Provider, redirectURI string) string + BuildAuthURL(provider domain.Provider, redirectURI, state, codeChallenge string) string // ExchangeCode exchanges an authorization code for a grant. - ExchangeCode(ctx context.Context, code, redirectURI string) (*domain.Grant, error) + ExchangeCode(ctx context.Context, code, redirectURI, codeVerifier string) (*domain.Grant, error) // ListGrants returns all grants for the authenticated application. ListGrants(ctx context.Context) ([]domain.Grant, error) diff --git a/internal/ports/nylas.go b/internal/ports/nylas.go index 5b2b67f..ed26fae 100644 --- a/internal/ports/nylas.go +++ b/internal/ports/nylas.go @@ -39,7 +39,7 @@ type OAuthServer interface { Stop() error // WaitForCallback waits for the OAuth callback and returns the auth code. - WaitForCallback(ctx context.Context) (string, error) + WaitForCallback(ctx context.Context, expectedState string) (string, error) // GetRedirectURI returns the redirect URI for OAuth. GetRedirectURI() string