diff --git a/.github/workflows/release_build_infisical_cli.yml b/.github/workflows/release_build_infisical_cli.yml index ad8f736a..140bd0a1 100644 --- a/.github/workflows/release_build_infisical_cli.yml +++ b/.github/workflows/release_build_infisical_cli.yml @@ -2,6 +2,11 @@ name: Build and release CLI on: workflow_dispatch: + inputs: + is_urgent: + description: "Mark this release as urgent (bypasses 48h grace period for update notifications)" + type: boolean + default: false push: # run only against tags @@ -123,6 +128,14 @@ jobs: FURY_TOKEN: ${{ secrets.FURYPUSHTOKEN }} AUR_KEY: ${{ secrets.AUR_KEY }} GORELEASER_KEY: ${{ secrets.GORELEASER_KEY }} + - name: Mark release as urgent + if: ${{ inputs.is_urgent == true }} + env: + GH_TOKEN: ${{ secrets.GO_RELEASER_GITHUB_TOKEN }} + run: | + CURRENT_BODY=$(gh release view "${{ github.ref_name }}" --json body -q .body) + gh release edit "${{ github.ref_name }}" --notes "${CURRENT_BODY} + " - uses: actions/setup-python@v4 with: python-version: "3.12" diff --git a/e2e/go.mod b/e2e/go.mod index 4c13b036..c03a52c1 100644 --- a/e2e/go.mod +++ b/e2e/go.mod @@ -125,7 +125,7 @@ require ( github.com/godbus/dbus/v5 v5.1.0 // indirect github.com/gofrs/flock v0.12.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/golang/glog v1.2.5 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect diff --git a/e2e/go.sum b/e2e/go.sum index 1c7eac9e..f7ebe67d 100644 --- a/e2e/go.sum +++ b/e2e/go.sum @@ -379,8 +379,8 @@ github.com/gogo/protobuf v1.0.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7a github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= -github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.5 h1:DrW6hGnjIhtvhOIiAKT6Psh/Kd/ldepEa81DKeiRJ5I= diff --git a/go.mod b/go.mod index 78362488..107684af 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/fatih/semgroup v1.2.0 github.com/gitleaks/go-gitdiff v0.9.1 github.com/go-mysql-org/go-mysql v1.13.0 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/google/uuid v1.6.0 github.com/h2non/filetype v1.1.3 github.com/infisical/go-sdk v0.6.8 diff --git a/go.sum b/go.sum index 783e9070..d5a78274 100644 --- a/go.sum +++ b/go.sum @@ -226,6 +226,8 @@ github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= +github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.0 h1:uCdmnmatrKCgMBlM4rMuJZWOkPDqdbZPnrMXDY4gI68= github.com/golang/glog v1.2.0/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= diff --git a/packages/cmd/root.go b/packages/cmd/root.go index 193b7e9b..b8c33560 100644 --- a/packages/cmd/root.go +++ b/packages/cmd/root.go @@ -64,6 +64,7 @@ func RootCmdStdoutWriter() io.Writer { // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). It only needs to happen once to the RootCmd. func Execute() { + defer util.WaitForUpdateCheck() err := RootCmd.Execute() if err != nil { os.Exit(1) diff --git a/packages/util/check-for-update.go b/packages/util/check-for-update.go index 3b7044e9..d84600eb 100644 --- a/packages/util/check-for-update.go +++ b/packages/util/check-for-update.go @@ -5,41 +5,83 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "os" "os/exec" + "path/filepath" "runtime" "strings" + "sync" "time" "github.com/fatih/color" "github.com/rs/zerolog/log" ) -func CheckForUpdate() { - CheckForUpdateWithWriter(os.Stderr) +var githubHTTPClient = &http.Client{Timeout: 8 * time.Second} + +var updateCheckWg sync.WaitGroup + +const updateCheckCacheTTL = 24 * time.Hour +const urgentUpdateCheckCacheTTL = 5 * time.Minute + +type UpdateCheckCache struct { + LastCheckTime time.Time `json:"lastCheckTime"` + LatestVersion string `json:"latestVersion"` + LatestVersionPublishedAt time.Time `json:"latestVersionPublishedAt"` + CurrentVersionPublishedAt time.Time `json:"currentVersionPublishedAt"` + IsUrgent bool `json:"isUrgent"` + CurrentVersionAtCheck string `json:"currentVersionAtCheck"` } func CheckForUpdateWithWriter(w io.Writer) { if checkEnv := os.Getenv("INFISICAL_DISABLE_UPDATE_CHECK"); checkEnv != "" { return } - latestVersion, _, isUrgent, err := getLatestTag("Infisical", "cli") - if err != nil { - log.Debug().Err(err) - // do nothing and continue - return + + cache := readUpdateCheckCache() + + displayCachedUpdateNotice(w, cache) + + if !isCacheFresh(cache) { + updateCheckWg.Add(1) + go func() { + defer updateCheckWg.Done() + performUpdateCheckInBackground() + }() } +} - if latestVersion == CLI_VERSION { - return +// WaitForUpdateCheck blocks until the background update check goroutine completes. +// Call this before program exit to ensure the cache gets written. +func WaitForUpdateCheck() { + updateCheckWg.Wait() +} + +// isCacheFresh returns true if the cache is fresh enough to skip a network check. +func isCacheFresh(cache *UpdateCheckCache) bool { + if cache == nil || cache.LatestVersion == "" || cache.CurrentVersionAtCheck != CLI_VERSION { + return false } + ttl := updateCheckCacheTTL + if cache.IsUrgent { + ttl = urgentUpdateCheckCacheTTL + } + return time.Since(cache.LastCheckTime) < ttl +} - // Only prompt if the user's current version is at least 48 hours old, unless urgent. - // This avoids nagging users who recently updated. - currentVersionPublishedAt, err := getReleasePublishedAt("Infisical", "cli", CLI_VERSION) - if err == nil && !isUrgent && time.Since(currentVersionPublishedAt).Hours() < 48 { +// displayCachedUpdateNotice prints an update notification from cached data. +func displayCachedUpdateNotice(w io.Writer, cache *UpdateCheckCache) { + if cache == nil || cache.LatestVersion == "" || cache.LatestVersion == CLI_VERSION { + return + } + // Don't show stale notifications after the user has upgraded. + if cache.CurrentVersionAtCheck != CLI_VERSION { + return + } + // Unless urgent, skip notification if the current version is less than 48h old. + if !cache.IsUrgent && !cache.CurrentVersionPublishedAt.IsZero() && + time.Since(cache.CurrentVersionPublishedAt).Hours() < 48 { return } @@ -51,19 +93,122 @@ func CheckForUpdateWithWriter(w io.Writer) { yellow("A new release of infisical is available:"), blue(CLI_VERSION), black("->"), - blue(latestVersion), + blue(cache.LatestVersion), ) fmt.Fprintln(w, msg) updateInstructions := GetUpdateInstructions() - if updateInstructions != "" { - msg = fmt.Sprintf("\n%s\n", GetUpdateInstructions()) + msg = fmt.Sprintf("\n%s\n", updateInstructions) fmt.Fprintln(w, msg) } } +// performUpdateCheckInBackground fetches update info from GitHub and writes to cache. +// It is designed to be called as a fire-and-forget goroutine. +func performUpdateCheckInBackground() { + latestVersion, latestPublishedAt, isUrgent, err := getLatestTag("Infisical", "cli") + if err != nil { + log.Debug().Err(err).Msg("background update check: failed to get latest tag") + return + } + + cache := &UpdateCheckCache{ + LastCheckTime: time.Now(), + LatestVersion: latestVersion, + LatestVersionPublishedAt: latestPublishedAt, + IsUrgent: isUrgent, + CurrentVersionAtCheck: CLI_VERSION, + } + + // If versions differ, fetch the publish date for the current version (for 48h grace). + if latestVersion != CLI_VERSION { + currentPublishedAt, err := getReleasePublishedAt("Infisical", "cli", CLI_VERSION) + if err != nil { + log.Debug().Err(err).Msg("background update check: failed to get current version publish date") + // Non-fatal — we just won't have the 48h grace period data. + } else { + cache.CurrentVersionPublishedAt = currentPublishedAt + } + } + + if err := writeUpdateCheckCache(cache); err != nil { + log.Debug().Err(err).Msg("background update check: failed to write cache") + } +} + +// getUpdateCheckCachePath returns the path to ~/.infisical/update-check.json. +func getUpdateCheckCachePath() (string, error) { + homeDir, err := GetHomeDir() + if err != nil { + return "", err + } + return filepath.Join(homeDir, CONFIG_FOLDER_NAME, UPDATE_CHECK_CACHE_FILE_NAME), nil +} + +// readUpdateCheckCache reads and unmarshals the cache file. Returns nil on any error (cache miss). +func readUpdateCheckCache() *UpdateCheckCache { + path, err := getUpdateCheckCachePath() + if err != nil { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + return nil + } + + var cache UpdateCheckCache + if err := json.Unmarshal(data, &cache); err != nil { + return nil + } + + return &cache +} + +// writeUpdateCheckCache atomically writes the cache file using a temp file + rename. +func writeUpdateCheckCache(cache *UpdateCheckCache) error { + path, err := getUpdateCheckCachePath() + if err != nil { + return err + } + + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("failed to create cache directory: %w", err) + } + + data, err := json.Marshal(cache) + if err != nil { + return fmt.Errorf("failed to marshal cache: %w", err) + } + + tmpFile, err := os.CreateTemp(dir, "update-check-*.json.tmp") + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + tmpPath := tmpFile.Name() + + if _, err := tmpFile.Write(data); err != nil { + tmpFile.Close() + os.Remove(tmpPath) + return fmt.Errorf("failed to write temp file: %w", err) + } + + if err := tmpFile.Close(); err != nil { + os.Remove(tmpPath) + return fmt.Errorf("failed to close temp file: %w", err) + } + + if err := os.Rename(tmpPath, path); err != nil { + os.Remove(tmpPath) + return fmt.Errorf("failed to rename temp file: %w", err) + } + + return nil +} + func DisplayAptInstallationChangeBanner(isSilent bool) { DisplayAptInstallationChangeBannerWithWriter(isSilent, os.Stderr) } @@ -89,7 +234,7 @@ func DisplayAptInstallationChangeBannerWithWriter(isSilent bool, w io.Writer) { func getLatestTag(repoOwner string, repoName string) (string, time.Time, bool, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", repoOwner, repoName) - resp, err := http.Get(url) + resp, err := githubHTTPClient.Get(url) if err != nil { return "", time.Time{}, false, err } @@ -132,7 +277,7 @@ func getLatestTag(repoOwner string, repoName string) (string, time.Time, bool, e func getReleasePublishedAt(repoOwner string, repoName string, version string) (time.Time, error) { tag := "v" + version url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/tags/%s", repoOwner, repoName, tag) - resp, err := http.Get(url) + resp, err := githubHTTPClient.Get(url) if err != nil { return time.Time{}, err } @@ -218,7 +363,7 @@ func IsRunningInDocker() bool { return true } - cgroup, err := ioutil.ReadFile("/proc/self/cgroup") + cgroup, err := os.ReadFile("/proc/self/cgroup") if err != nil { return false } diff --git a/packages/util/check-for-update_test.go b/packages/util/check-for-update_test.go new file mode 100644 index 00000000..99ad1204 --- /dev/null +++ b/packages/util/check-for-update_test.go @@ -0,0 +1,299 @@ +package util + +import ( + "bytes" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/fatih/color" +) + +func init() { + // Disable color output in tests so we can assert on plain text. + color.NoColor = true +} + +func TestIsCacheFresh(t *testing.T) { + tests := []struct { + name string + cache *UpdateCheckCache + expected bool + }{ + { + name: "nil cache needs updating", + cache: nil, + expected: false, + }, + { + name: "blank LatestVersion needs updating", + cache: &UpdateCheckCache{ + LastCheckTime: time.Now(), + LatestVersion: "", + CurrentVersionAtCheck: CLI_VERSION, + }, + expected: false, + }, + { + name: "mismatched CurrentVersionAtCheck needs updating", + cache: &UpdateCheckCache{ + LastCheckTime: time.Now(), + LatestVersion: "1.0.0", + CurrentVersionAtCheck: "old-version", + }, + expected: false, + }, + { + name: "urgent with expired short TTL needs updating", + cache: &UpdateCheckCache{ + LastCheckTime: time.Now().Add(-10 * time.Minute), + LatestVersion: "2.0.0", + CurrentVersionAtCheck: CLI_VERSION, + IsUrgent: true, + }, + expected: false, + }, + { + name: "urgent with recent check is still fresh", + cache: &UpdateCheckCache{ + LastCheckTime: time.Now().Add(-2 * time.Minute), + LatestVersion: "2.0.0", + CurrentVersionAtCheck: CLI_VERSION, + IsUrgent: true, + }, + expected: true, + }, + { + name: "expired TTL (>24h) needs updating", + cache: &UpdateCheckCache{ + LastCheckTime: time.Now().Add(-25 * time.Hour), + LatestVersion: "2.0.0", + CurrentVersionAtCheck: CLI_VERSION, + }, + expected: false, + }, + { + name: "fresh cache does not need updating", + cache: &UpdateCheckCache{ + LastCheckTime: time.Now().Add(-1 * time.Hour), + LatestVersion: "2.0.0", + CurrentVersionAtCheck: CLI_VERSION, + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isCacheFresh(tt.cache); got != tt.expected { + t.Errorf("isCacheFresh() = %v, expected %v", got, tt.expected) + } + }) + } +} + +func TestDisplayCachedUpdateNotice(t *testing.T) { + tests := []struct { + name string + cache *UpdateCheckCache + expectOutput bool + expectContains []string + }{ + { + name: "nil cache produces no output", + cache: nil, + expectOutput: false, + }, + { + name: "blank LatestVersion produces no output", + cache: &UpdateCheckCache{ + LatestVersion: "", + CurrentVersionAtCheck: CLI_VERSION, + }, + expectOutput: false, + }, + { + name: "same version produces no output", + cache: &UpdateCheckCache{ + LatestVersion: CLI_VERSION, + CurrentVersionAtCheck: CLI_VERSION, + }, + expectOutput: false, + }, + { + name: "stale cache after upgrade produces no output", + cache: &UpdateCheckCache{ + LatestVersion: "2.0.0", + CurrentVersionAtCheck: "old-version-that-doesnt-match", + }, + expectOutput: false, + }, + { + name: "current version <48h old produces no output", + cache: &UpdateCheckCache{ + LatestVersion: "2.0.0", + CurrentVersionAtCheck: CLI_VERSION, + CurrentVersionPublishedAt: time.Now().Add(-24 * time.Hour), + }, + expectOutput: false, + }, + { + name: "urgent ignores 48h grace period", + cache: &UpdateCheckCache{ + LatestVersion: "2.0.0", + CurrentVersionAtCheck: CLI_VERSION, + CurrentVersionPublishedAt: time.Now().Add(-24 * time.Hour), + IsUrgent: true, + }, + expectOutput: true, + expectContains: []string{"2.0.0"}, + }, + { + name: "shows banner when current version >48h old", + cache: &UpdateCheckCache{ + LatestVersion: "2.0.0", + CurrentVersionAtCheck: CLI_VERSION, + CurrentVersionPublishedAt: time.Now().Add(-72 * time.Hour), + }, + expectOutput: true, + expectContains: []string{"A new release of infisical is available", CLI_VERSION, "2.0.0"}, + }, + { + name: "zero publish date shows banner (cannot enforce 48h grace)", + cache: &UpdateCheckCache{ + LatestVersion: "2.0.0", + CurrentVersionAtCheck: CLI_VERSION, + }, + expectOutput: true, + expectContains: []string{"2.0.0"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + displayCachedUpdateNotice(&buf, tt.cache) + + if tt.expectOutput && buf.Len() == 0 { + t.Error("expected output but got none") + } + if !tt.expectOutput && buf.Len() != 0 { + t.Errorf("expected no output, got: %s", buf.String()) + } + for _, s := range tt.expectContains { + if !strings.Contains(buf.String(), s) { + t.Errorf("expected output to contain %q, got: %s", s, buf.String()) + } + } + }) + } +} + +func TestWriteAndReadUpdateCheckCache(t *testing.T) { + tmpDir := t.TempDir() + cachePath := filepath.Join(tmpDir, UPDATE_CHECK_CACHE_FILE_NAME) + + original := &UpdateCheckCache{ + LastCheckTime: time.Now().Truncate(time.Second), + LatestVersion: "2.0.0", + LatestVersionPublishedAt: time.Now().Add(-1 * time.Hour).Truncate(time.Second), + CurrentVersionPublishedAt: time.Now().Add(-48 * time.Hour).Truncate(time.Second), + IsUrgent: false, + CurrentVersionAtCheck: "1.0.0", + } + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("failed to marshal cache: %v", err) + } + if err := os.WriteFile(cachePath, data, 0600); err != nil { + t.Fatalf("failed to write cache file: %v", err) + } + + readData, err := os.ReadFile(cachePath) + if err != nil { + t.Fatalf("failed to read cache file: %v", err) + } + + var loaded UpdateCheckCache + if err := json.Unmarshal(readData, &loaded); err != nil { + t.Fatalf("failed to unmarshal cache: %v", err) + } + + if loaded.LatestVersion != original.LatestVersion { + t.Errorf("LatestVersion: got %s, want %s", loaded.LatestVersion, original.LatestVersion) + } + if loaded.CurrentVersionAtCheck != original.CurrentVersionAtCheck { + t.Errorf("CurrentVersionAtCheck: got %s, want %s", loaded.CurrentVersionAtCheck, original.CurrentVersionAtCheck) + } + if loaded.IsUrgent != original.IsUrgent { + t.Errorf("IsUrgent: got %v, want %v", loaded.IsUrgent, original.IsUrgent) + } + if !loaded.LastCheckTime.Equal(original.LastCheckTime) { + t.Errorf("LastCheckTime: got %v, want %v", loaded.LastCheckTime, original.LastCheckTime) + } +} + +func TestReadUpdateCheckCache_CorruptJSON(t *testing.T) { + var cache UpdateCheckCache + err := json.Unmarshal([]byte(`{not valid json!!!`), &cache) + if err == nil { + t.Error("expected error for corrupt JSON") + } +} + +func TestWriteUpdateCheckCache_AtomicWrite(t *testing.T) { + tmpDir := t.TempDir() + cachePath := filepath.Join(tmpDir, "update-check.json") + + cache := &UpdateCheckCache{ + LastCheckTime: time.Now().Truncate(time.Second), + LatestVersion: "3.0.0", + CurrentVersionAtCheck: "2.0.0", + } + + data, err := json.Marshal(cache) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + // Simulate the atomic write pattern. + tmpFile, err := os.CreateTemp(tmpDir, "update-check-*.json.tmp") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + + if _, err := tmpFile.Write(data); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + tmpFile.Close() + + if err := os.Rename(tmpFile.Name(), cachePath); err != nil { + t.Fatalf("failed to rename: %v", err) + } + + readData, err := os.ReadFile(cachePath) + if err != nil { + t.Fatalf("failed to read final file: %v", err) + } + + var loaded UpdateCheckCache + if err := json.Unmarshal(readData, &loaded); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + + if loaded.LatestVersion != "3.0.0" { + t.Errorf("got %s, want 3.0.0", loaded.LatestVersion) + } + + info, err := os.Stat(cachePath) + if err != nil { + t.Fatalf("failed to stat: %v", err) + } + if perm := info.Mode().Perm(); perm != 0600 { + t.Errorf("got permissions %o, want 0600", perm) + } +} diff --git a/packages/util/constants.go b/packages/util/constants.go index 57a1cddb..0a19dc59 100644 --- a/packages/util/constants.go +++ b/packages/util/constants.go @@ -65,6 +65,8 @@ const ( KUBERNETES_SERVICE_PORT_HTTPS_ENV_NAME = "KUBERNETES_SERVICE_PORT_HTTPS" KUBERNETES_SERVICE_ACCOUNT_CA_CERT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token" + + UPDATE_CHECK_CACHE_FILE_NAME = "update-check.json" ) var ( diff --git a/packages/util/credentials.go b/packages/util/credentials.go index cd73e47c..68a17cdf 100644 --- a/packages/util/credentials.go +++ b/packages/util/credentials.go @@ -5,10 +5,11 @@ import ( "errors" "fmt" "strings" + "time" - "github.com/Infisical/infisical-merge/packages/api" "github.com/Infisical/infisical-merge/packages/config" "github.com/Infisical/infisical-merge/packages/models" + jwt "github.com/golang-jwt/jwt/v5" "github.com/zalando/go-keyring" ) @@ -83,17 +84,8 @@ func GetCurrentLoggedInUserDetails(setConfigVariables bool) (LoggedInUserDetails } } - // check to to see if the JWT is still valid - httpClient, err := GetRestyClientWithCustomHeaders() - if err != nil { - return LoggedInUserDetails{}, fmt.Errorf("getCurrentLoggedInUserDetails: unable to get client with custom headers [err=%s]", err) - } - - httpClient. - SetAuthToken(userCreds.JTWToken). - SetHeader("Accept", "application/json") + isAuthenticated := !IsJWTExpired(userCreds.JTWToken) - isAuthenticated := api.CallIsAuthenticated(httpClient) // TODO: add refresh token // if !isAuthenticated { // accessTokenResponse, err := api.CallGetNewAccessTokenWithRefreshToken(httpClient, userCreds.RefreshToken) @@ -103,11 +95,6 @@ func GetCurrentLoggedInUserDetails(setConfigVariables bool) (LoggedInUserDetails // } // } - // err = StoreUserCredsInKeyRing(&userCreds) - // if err != nil { - // log.Debug().Msg("unable to store your user credentials with new access token") - // } - if !isAuthenticated { return LoggedInUserDetails{ IsUserLoggedIn: true, // was logged in @@ -125,3 +112,17 @@ func GetCurrentLoggedInUserDetails(setConfigVariables bool) (LoggedInUserDetails return LoggedInUserDetails{}, nil } } + +func IsJWTExpired(token string) bool { + parser := jwt.NewParser() + claims := &jwt.RegisteredClaims{} + _, _, err := parser.ParseUnverified(token, claims) + if err != nil { + return true + } + if claims.ExpiresAt == nil { + return true + } + // 30-second buffer to avoid race between local check and subsequent API call + return claims.ExpiresAt.Before(time.Now().Add(30 * time.Second)) +} diff --git a/packages/util/credentials_test.go b/packages/util/credentials_test.go new file mode 100644 index 00000000..040e07db --- /dev/null +++ b/packages/util/credentials_test.go @@ -0,0 +1,85 @@ +package util + +import ( + "testing" + "time" + + jwt "github.com/golang-jwt/jwt/v5" +) + +// testSigningKey is a dummy key used only to produce validly-formatted JWTs for tests. +var testSigningKey = []byte("test-secret-key") + +func createToken(t *testing.T, exp *time.Time) string { + t.Helper() + claims := jwt.RegisteredClaims{} + if exp != nil { + claims.ExpiresAt = jwt.NewNumericDate(*exp) + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signed, err := token.SignedString(testSigningKey) + if err != nil { + t.Fatalf("failed to sign test token: %v", err) + } + return signed +} + +func TestIsJWTExpired_ValidFutureToken(t *testing.T) { + exp := time.Now().Add(1 * time.Hour) + token := createToken(t, &exp) + if IsJWTExpired(token) { + t.Error("expected token with future exp to not be expired") + } +} + +func TestIsJWTExpired_ExpiredToken(t *testing.T) { + exp := time.Now().Add(-1 * time.Hour) + token := createToken(t, &exp) + if !IsJWTExpired(token) { + t.Error("expected token with past exp to be expired") + } +} + +func TestIsJWTExpired_WithinBuffer(t *testing.T) { + // 20 seconds from now — within the 30-second buffer + exp := time.Now().Add(20 * time.Second) + token := createToken(t, &exp) + if !IsJWTExpired(token) { + t.Error("expected token expiring within 30s buffer to be treated as expired") + } +} + +func TestIsJWTExpired_JustOutsideBuffer(t *testing.T) { + // 31 seconds from now — outside the 30-second buffer + exp := time.Now().Add(31 * time.Second) + token := createToken(t, &exp) + if IsJWTExpired(token) { + t.Error("expected token expiring in 31s to not be treated as expired") + } +} + +func TestIsJWTExpired_EmptyString(t *testing.T) { + if !IsJWTExpired("") { + t.Error("expected empty string to be treated as expired") + } +} + +func TestIsJWTExpired_MalformedJWT(t *testing.T) { + if !IsJWTExpired("not-a-jwt") { + t.Error("expected malformed JWT to be treated as expired") + } +} + +func TestIsJWTExpired_InvalidBase64Payload(t *testing.T) { + // Three parts but invalid base64 in payload + if !IsJWTExpired("header.!!!invalid-base64!!!.signature") { + t.Error("expected invalid base64 payload to be treated as expired") + } +} + +func TestIsJWTExpired_MissingExpClaim(t *testing.T) { + token := createToken(t, nil) + if !IsJWTExpired(token) { + t.Error("expected token without exp claim to be treated as expired") + } +}