From 487165625c401c2414e58189d84148e34a3759ce Mon Sep 17 00:00:00 2001 From: Alexis-Maurer Fortin Date: Tue, 5 May 2026 16:32:14 -0400 Subject: [PATCH 1/2] add version check call to poutine --- .poutine.sample.yml | 4 + README.md | 9 + cmd/root.go | 27 +++ models/config.go | 13 +- versioncheck/config.go | 71 +++++++ versioncheck/versioncheck.go | 245 +++++++++++++++++++++++ versioncheck/versioncheck_test.go | 316 ++++++++++++++++++++++++++++++ 7 files changed, 679 insertions(+), 6 deletions(-) create mode 100644 versioncheck/config.go create mode 100644 versioncheck/versioncheck.go create mode 100644 versioncheck/versioncheck_test.go diff --git a/.poutine.sample.yml b/.poutine.sample.yml index a458cc6b..e4096dae 100644 --- a/.poutine.sample.yml +++ b/.poutine.sample.yml @@ -10,6 +10,10 @@ # default: false ignoreForks: true +# Disable the once-per-day check for newer poutine releases +# default: false +# disableVersionCheck: true + # Set rule configuration options rulesConfig: pr_runs_on_self_hosted: diff --git a/README.md b/README.md index c7463939..ff52b6eb 100644 --- a/README.md +++ b/README.md @@ -115,10 +115,19 @@ poutine analyze_org my-org/project --token "$GL_TOKEN" --scm gitlab --scm-base-u --skip Add rules to the skip list for the current run (can be specified multiple times) --verbose Enable debug logging --fail-on-violation Exit with a non-zero code (10) when violations are found +--disable-version-check Disable the once-per-day check for newer poutine releases (env: POUTINE_DISABLE_VERSION_CHECK, config: disableVersionCheck) ``` See [.poutine.sample.yml](.poutine.sample.yml) for an example configuration file. +#### Version check telemetry + +By default, `poutine` reaches out at most once every 24 hours to check whether a newer release is available. The request reports the current poutine version, an anonymous instance identifier persisted in `~/.poutine/config.yaml`, and a count of CLI invocations since the last check. No source, repository, or finding data is sent. To disable, use any of: + +- `--disable-version-check` flag +- `POUTINE_DISABLE_VERSION_CHECK=1` environment variable +- `disableVersionCheck: true` in `.poutine.yml` + ### Custom Rules `poutine` supports custom Rego rules to extend its security scanning capabilities. You can write your own rules and include them at runtime. diff --git a/cmd/root.go b/cmd/root.go index b62ff8b1..2bf185c0 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -21,6 +21,7 @@ import ( "github.com/boostsecurityio/poutine/providers/gitops" "github.com/boostsecurityio/poutine/providers/scm" scm_domain "github.com/boostsecurityio/poutine/providers/scm/domain" + "github.com/boostsecurityio/poutine/versioncheck" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/spf13/viper" @@ -44,6 +45,7 @@ var config *models.Config = models.DefaultConfig() var skipRules []string var allowedRules []string var failOnViolation bool +var disableVersionCheck bool // ErrViolationsFound is returned when violations are detected and --fail-on-violation is set. var ErrViolationsFound = errors.New("poutine: violations found") @@ -88,9 +90,33 @@ By BoostSecurity.io - https://github.com/boostsecurityio/poutine `, return strings.ToUpper(fmt.Sprintf("| %-6s|", i)) } log.Logger = log.Output(output) + + runVersionCheck(cmd) }, } +// runVersionCheck performs the once-per-day update check unless disabled by +// flag, env var, or config. The "version" subcommand is excluded so users can +// inspect the binary without triggering a network call. +func runVersionCheck(cmd *cobra.Command) { + if cmd != nil && cmd.Name() == "version" { + return + } + disabled := disableVersionCheck || (config != nil && config.DisableVersionCheck) + result := versioncheck.Run(Version, disabled) + if result == nil || !result.UpdateAvailable { + return + } + target := result.LatestURL + if target == "" { + target = "https://github.com/boostsecurityio/poutine/releases" + } + log.Warn(). + Str("current_version", Version). + Str("latest_version", result.LatestVersion). + Msgf("A new version of poutine is available: %s — %s", result.LatestVersion, target) +} + // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). It only needs to happen once to the rootCmd. func Execute() { @@ -147,6 +173,7 @@ func init() { RootCmd.PersistentFlags().StringSliceVar(&skipRules, "skip", []string{}, "Adds rules to the configured skip list for the current run (optional)") RootCmd.PersistentFlags().StringSliceVar(&allowedRules, "allowed-rules", []string{}, "Overwrite the configured allowedRules list for the current run (optional)") RootCmd.PersistentFlags().BoolVar(&failOnViolation, "fail-on-violation", false, "Exit with a non-zero code (10) when violations are found") + RootCmd.PersistentFlags().BoolVar(&disableVersionCheck, "disable-version-check", false, "Disable the once-per-day check for newer poutine releases") _ = viper.BindPFlag("quiet", RootCmd.PersistentFlags().Lookup("quiet")) } diff --git a/models/config.go b/models/config.go index ba3d42f1..5a40b558 100644 --- a/models/config.go +++ b/models/config.go @@ -23,12 +23,13 @@ type ConfigInclude struct { } type Config struct { - Skip []ConfigSkip `json:"skip"` - AllowedRules []string `json:"allowed_rules"` - Include []ConfigInclude `json:"include"` - IgnoreForks bool `json:"ignore_forks"` - Quiet bool `json:"quiet,omitempty"` - RulesConfig map[string]map[string]interface{} `json:"rules_config"` + Skip []ConfigSkip `json:"skip"` + AllowedRules []string `json:"allowed_rules"` + Include []ConfigInclude `json:"include"` + IgnoreForks bool `json:"ignore_forks"` + Quiet bool `json:"quiet,omitempty"` + RulesConfig map[string]map[string]interface{} `json:"rules_config"` + DisableVersionCheck bool `json:"disable_version_check,omitempty"` } func DefaultConfig() *Config { diff --git a/versioncheck/config.go b/versioncheck/config.go new file mode 100644 index 00000000..03b6e5e8 --- /dev/null +++ b/versioncheck/config.go @@ -0,0 +1,71 @@ +// Package versioncheck implements a lightweight, opt-out version-check that +// reports anonymous start telemetry to the Boost OSS telemetry endpoint and +// notifies the user when a newer poutine release is available. +package versioncheck + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "gopkg.in/yaml.v3" +) + +// Config is the user-level state persisted between poutine invocations to +// support the once-per-day version check. It is intentionally separate from +// the project-level models.Config that is loaded from .poutine.yml. +type Config struct { + InstanceID string `yaml:"instance_id,omitempty"` + StartCount int `yaml:"start_count,omitempty"` + LastReportedStartCount int `yaml:"last_reported_start_count,omitempty"` + LastVersionCheckAt time.Time `yaml:"last_version_check_timestamp,omitempty"` +} + +// ConfigPath returns the path to the user-level state file. It honors +// POUTINE_CONFIG_DIR for tests and constrained environments and otherwise +// defaults to ~/.poutine/config.yaml. +func ConfigPath() string { + if dir := os.Getenv("POUTINE_CONFIG_DIR"); dir != "" { + return filepath.Join(dir, "config.yaml") + } + home, _ := os.UserHomeDir() + return filepath.Join(home, ".poutine", "config.yaml") +} + +// LoadConfig reads the user-level state file. A missing file is not an error +// and yields a nil Config so callers can treat it as a first run. +func LoadConfig() (*Config, error) { + path := ConfigPath() + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("read version-check state %s: %w", path, err) + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("invalid version-check state: %w", err) + } + return &cfg, nil +} + +// SaveConfig writes the user-level state file, creating the parent directory +// when needed. The file is written with restrictive permissions because it +// holds an anonymous instance identifier. +func SaveConfig(cfg *Config) error { + path := ConfigPath() + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return fmt.Errorf("create version-check state dir: %w", err) + } + data, err := yaml.Marshal(cfg) + if err != nil { + return fmt.Errorf("marshal version-check state: %w", err) + } + if err := os.WriteFile(path, data, 0o600); err != nil { + return fmt.Errorf("write version-check state %s: %w", path, err) + } + return nil +} diff --git a/versioncheck/versioncheck.go b/versioncheck/versioncheck.go new file mode 100644 index 00000000..e612102f --- /dev/null +++ b/versioncheck/versioncheck.go @@ -0,0 +1,245 @@ +package versioncheck + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/google/uuid" +) + +const ( + // DisableEnv toggles the version check off when set to a truthy value. + DisableEnv = "POUTINE_DISABLE_VERSION_CHECK" + // URLEnv overrides the compiled-in endpoint, primarily for staging. + URLEnv = "POUTINE_VERSION_CHECK_URL" +) + +const ( + checkInterval = 24 * time.Hour + checkTimeout = 1200 * time.Millisecond +) + +// VersionCheckURL is the default endpoint. It can be overridden at runtime +// via the POUTINE_VERSION_CHECK_URL environment variable, primarily for +// staging or local testing. +var VersionCheckURL = "https://version-check.cicd.fun/v1/check" + +// Result mirrors the JSON response from the oss-telemetry /v1/check endpoint. +type Result struct { + LatestVersion string `json:"latest_version,omitempty"` + LatestURL string `json:"latest_url,omitempty"` + UpdateAvailable bool `json:"update_available"` +} + +type options struct { + Config *Config + Version string + URL string + Disabled bool + Client *http.Client + Now func() time.Time + SaveConfig func(*Config) error + Env func(string) string + NewID func() string +} + +// Run records a CLI start and, at most once every 24 hours, reports anonymous +// telemetry to the configured endpoint and returns the latest release info. +// The whole operation is bounded by checkTimeout so it never noticeably +// slows down poutine startup. +func Run(version string, disabled bool) *Result { + if disabled || isDisabledByEnv(os.Getenv(DisableEnv)) { + return nil + } + + cfg, _ := LoadConfig() + if cfg == nil { + cfg = &Config{} + } + recordStart(cfg, uuid.NewString) + _ = SaveConfig(cfg) + + ctx, cancel := context.WithTimeout(context.Background(), checkTimeout) + defer cancel() + + result, _ := run(ctx, options{ + Config: cfg, + Version: version, + URL: VersionCheckURL, + Client: &http.Client{Timeout: checkTimeout}, + Now: time.Now, + SaveConfig: SaveConfig, + Env: os.Getenv, + NewID: uuid.NewString, + }) + return result +} + +func run(ctx context.Context, opts options) (*Result, error) { + if opts.Config == nil { + opts.Config = &Config{} + } + if opts.Env == nil { + opts.Env = os.Getenv + } + if opts.Disabled || isDisabledByEnv(opts.Env(DisableEnv)) { + return nil, nil + } + + endpoint := strings.TrimSpace(opts.Env(URLEnv)) + envEndpoint := endpoint != "" + if endpoint == "" { + endpoint = strings.TrimSpace(opts.URL) + } + version := strings.TrimSpace(opts.Version) + if endpoint == "" || version == "" { + return nil, nil + } + if isDevVersion(version) && !envEndpoint { + return nil, nil + } + + now := time.Now() + if opts.Now != nil { + now = opts.Now() + } + if !opts.Config.LastVersionCheckAt.IsZero() && now.Sub(opts.Config.LastVersionCheckAt) < checkInterval { + return nil, nil + } + + u, err := url.Parse(endpoint) + if err != nil { + return nil, fmt.Errorf("parse version-check endpoint: %w", err) + } + instanceID := ensureInstanceID(opts.Config, opts.NewID) + startsSinceLastCheck := startsSinceLastReport(opts.Config) + requestURL := buildRequestURL(u, version, instanceID, opts.Config.StartCount, startsSinceLastCheck) + + client := opts.Client + if client == nil { + client = &http.Client{Timeout: checkTimeout} + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL, http.NoBody) + if err != nil { + return nil, fmt.Errorf("build version-check request: %w", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "poutine/"+version) + + resp, doErr := client.Do(req) + opts.Config.LastVersionCheckAt = now + if opts.SaveConfig == nil { + opts.SaveConfig = SaveConfig + } + var result *Result + var resultErr error + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + opts.Config.LastReportedStartCount = opts.Config.StartCount + result, resultErr = readResult(resp) + } else { + _, _ = io.Copy(io.Discard, resp.Body) + } + } + saveErr := opts.SaveConfig(opts.Config) + if doErr != nil { + return nil, fmt.Errorf("call version-check endpoint: %w", doErr) + } + if resultErr != nil { + return nil, resultErr + } + if saveErr != nil { + return result, saveErr + } + return result, nil +} + +func isDisabledByEnv(value string) bool { + switch strings.ToLower(strings.TrimSpace(value)) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +func isDevVersion(version string) bool { + version = strings.TrimSpace(version) + return version == "" || + version == "dev" || + version == "development" || + version == "unknown" || + strings.Contains(version, "SNAPSHOT") +} + +func recordStart(cfg *Config, newID func() string) { + if cfg == nil { + return + } + ensureInstanceID(cfg, newID) + cfg.StartCount++ +} + +func ensureInstanceID(cfg *Config, newID func() string) string { + id := strings.TrimSpace(cfg.InstanceID) + if id == "" { + if newID == nil { + newID = uuid.NewString + } + id = strings.TrimSpace(newID()) + cfg.InstanceID = id + return id + } + cfg.InstanceID = id + return id +} + +func startsSinceLastReport(cfg *Config) int { + starts := cfg.StartCount - cfg.LastReportedStartCount + if starts < 0 { + return cfg.StartCount + } + return starts +} + +func readResult(resp *http.Response) (*Result, error) { + if resp.StatusCode == http.StatusNoContent { + return nil, nil + } + data, err := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + if err != nil { + return nil, fmt.Errorf("read version-check response: %w", err) + } + if strings.TrimSpace(string(data)) == "" { + return nil, nil + } + var result Result + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("decode version-check response: %w", err) + } + return &result, nil +} + +func buildRequestURL(u *url.URL, version, instanceID string, startCount, startsSinceLastCheck int) string { + q := u.Query() + q.Set("project", "poutine") + q.Set("component", "cli") + q.Set("version", version) + q.Set("instance_id", instanceID) + if startCount > 0 { + q.Set("start_count", strconv.Itoa(startCount)) + q.Set("starts_since_last_check", strconv.Itoa(startsSinceLastCheck)) + } + u.RawQuery = q.Encode() + return u.String() +} diff --git a/versioncheck/versioncheck_test.go b/versioncheck/versioncheck_test.go new file mode 100644 index 00000000..7c79def6 --- /dev/null +++ b/versioncheck/versioncheck_test.go @@ -0,0 +1,316 @@ +package versioncheck + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestRun_SendsVersionAndRecordsTimestamp(t *testing.T) { + now := time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC) + instanceID := "2ed05245-10d7-4d21-a8e8-7c4e8a9851b4" + cfg := &Config{ + InstanceID: instanceID, + StartCount: 7, + } + var gotReq *http.Request + var saved *Config + + result, err := run(context.Background(), options{ + Config: cfg, + Version: "v0.18.0", + URL: "https://updates.example/check?channel=stable", + Client: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + gotReq = req + return &http.Response{ + StatusCode: http.StatusNoContent, + Body: io.NopCloser(strings.NewReader("")), + }, nil + })}, + Now: func() time.Time { return now }, + SaveConfig: func(cfg *Config) error { + c := *cfg + saved = &c + return nil + }, + Env: func(string) string { return "" }, + NewID: func() string { return instanceID }, + }) + + require.NoError(t, err) + assert.Nil(t, result) + require.NotNil(t, gotReq) + assert.Equal(t, "poutine", gotReq.URL.Query().Get("project")) + assert.Equal(t, "cli", gotReq.URL.Query().Get("component")) + assert.Equal(t, "v0.18.0", gotReq.URL.Query().Get("version")) + assert.Equal(t, instanceID, gotReq.URL.Query().Get("instance_id")) + assert.Equal(t, "7", gotReq.URL.Query().Get("start_count")) + assert.Equal(t, "7", gotReq.URL.Query().Get("starts_since_last_check")) + assert.Equal(t, "stable", gotReq.URL.Query().Get("channel")) + assert.Equal(t, "poutine/v0.18.0", gotReq.Header.Get("User-Agent")) + require.NotNil(t, saved) + assert.Equal(t, instanceID, saved.InstanceID) + assert.Equal(t, 7, saved.StartCount) + assert.Equal(t, 7, saved.LastReportedStartCount) + assert.Equal(t, now, saved.LastVersionCheckAt) +} + +func TestRun_ReturnsUpdateResult(t *testing.T) { + now := time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC) + cfg := &Config{ + InstanceID: "2ed05245-10d7-4d21-a8e8-7c4e8a9851b4", + StartCount: 5, + LastReportedStartCount: 4, + } + + result, err := run(context.Background(), options{ + Config: cfg, + Version: "v0.18.0", + URL: "https://updates.example/check", + Client: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + assert.Equal(t, "1", req.URL.Query().Get("starts_since_last_check")) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{ + "latest_version":"v0.18.1", + "latest_url":"https://github.com/boostsecurityio/poutine/releases/tag/v0.18.1", + "update_available":true + }`)), + }, nil + })}, + Now: func() time.Time { return now }, + SaveConfig: func(*Config) error { return nil }, + Env: func(string) string { return "" }, + }) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.UpdateAvailable) + assert.Equal(t, "v0.18.1", result.LatestVersion) + assert.Equal(t, "https://github.com/boostsecurityio/poutine/releases/tag/v0.18.1", result.LatestURL) + assert.Equal(t, 5, cfg.LastReportedStartCount) +} + +func TestRun_DisabledByEnv(t *testing.T) { + called := false + cfg := &Config{} + + result, err := run(context.Background(), options{ + Config: cfg, + Version: "v0.18.0", + URL: "https://updates.example/check", + Client: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + called = true + return nil, errors.New("unexpected request") + })}, + Env: func(key string) string { + if key == DisableEnv { + return "true" + } + return "" + }, + }) + + require.NoError(t, err) + assert.Nil(t, result) + assert.False(t, called) + assert.Empty(t, cfg.InstanceID) + assert.Zero(t, cfg.StartCount) + assert.True(t, cfg.LastVersionCheckAt.IsZero()) +} + +func TestRun_DisabledByOption(t *testing.T) { + called := false + cfg := &Config{} + + result, err := run(context.Background(), options{ + Config: cfg, + Version: "v0.18.0", + URL: "https://updates.example/check", + Disabled: true, + Client: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + called = true + return nil, errors.New("unexpected request") + })}, + Env: func(string) string { return "" }, + }) + + require.NoError(t, err) + assert.Nil(t, result) + assert.False(t, called) + assert.True(t, cfg.LastVersionCheckAt.IsZero()) +} + +func TestRun_RespectsInterval(t *testing.T) { + now := time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC) + called := false + cfg := &Config{LastVersionCheckAt: now.Add(-time.Hour)} + + result, err := run(context.Background(), options{ + Config: cfg, + Version: "v0.18.0", + URL: "https://updates.example/check", + Client: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + called = true + return nil, errors.New("unexpected request") + })}, + Now: func() time.Time { return now }, + Env: func(string) string { return "" }, + }) + + require.NoError(t, err) + assert.Nil(t, result) + assert.False(t, called) + assert.Empty(t, cfg.InstanceID) +} + +func TestRecordStart_GeneratesInstanceIDAndIncrementsCount(t *testing.T) { + cfg := &Config{} + + recordStart(cfg, func() string { + return "97c5d9f0-7a5c-4a61-9f2a-09f4903de44e" + }) + recordStart(cfg, func() string { + t.Fatal("existing instance_id should be reused") + return "" + }) + + assert.Equal(t, "97c5d9f0-7a5c-4a61-9f2a-09f4903de44e", cfg.InstanceID) + assert.Equal(t, 2, cfg.StartCount) +} + +func TestRun_ReusesInstanceIDAndReportsStartCount(t *testing.T) { + now := time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC) + cfg := &Config{ + InstanceID: "97c5d9f0-7a5c-4a61-9f2a-09f4903de44e", + StartCount: 42, + LastReportedStartCount: 40, + LastVersionCheckAt: now.Add(-25 * time.Hour), + } + var gotReq *http.Request + + result, err := run(context.Background(), options{ + Config: cfg, + Version: "v0.18.0", + URL: "https://updates.example/check", + Client: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + gotReq = req + return &http.Response{ + StatusCode: http.StatusNoContent, + Body: io.NopCloser(strings.NewReader("")), + }, nil + })}, + Now: func() time.Time { return now }, + SaveConfig: func(*Config) error { return nil }, + Env: func(string) string { return "" }, + NewID: func() string { + t.Fatal("existing instance_id should be reused") + return "" + }, + }) + + require.NoError(t, err) + assert.Nil(t, result) + require.NotNil(t, gotReq) + assert.Equal(t, "97c5d9f0-7a5c-4a61-9f2a-09f4903de44e", gotReq.URL.Query().Get("instance_id")) + assert.Equal(t, "42", gotReq.URL.Query().Get("start_count")) + assert.Equal(t, "2", gotReq.URL.Query().Get("starts_since_last_check")) + assert.Equal(t, 42, cfg.StartCount) + assert.Equal(t, 42, cfg.LastReportedStartCount) + assert.Equal(t, now, cfg.LastVersionCheckAt) +} + +func TestRun_SkipsDevVersionWithoutExplicitEndpoint(t *testing.T) { + called := false + + result, err := run(context.Background(), options{ + Config: &Config{}, + Version: "development", + URL: "https://updates.example/check", + Client: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + called = true + return nil, errors.New("unexpected request") + })}, + Env: func(string) string { return "" }, + }) + + require.NoError(t, err) + assert.Nil(t, result) + assert.False(t, called) +} + +func TestRun_ExplicitEndpointAllowsDevVersion(t *testing.T) { + called := false + + result, err := run(context.Background(), options{ + Config: &Config{}, + Version: "development", + URL: "https://updates.example/default", + Client: &http.Client{Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + called = true + assert.Equal(t, "https://override.example/check", req.URL.Scheme+"://"+req.URL.Host+req.URL.Path) + return &http.Response{ + StatusCode: http.StatusNoContent, + Body: io.NopCloser(strings.NewReader("")), + }, nil + })}, + SaveConfig: func(*Config) error { return nil }, + Env: func(key string) string { + if key == URLEnv { + return "https://override.example/check" + } + return "" + }, + }) + + require.NoError(t, err) + assert.Nil(t, result) + assert.True(t, called) +} + +func TestLoadConfig_RoundTrip(t *testing.T) { + dir := t.TempDir() + t.Setenv("POUTINE_CONFIG_DIR", dir) + + loaded, err := LoadConfig() + require.NoError(t, err) + assert.Nil(t, loaded) + + cfg := &Config{ + InstanceID: "2ed05245-10d7-4d21-a8e8-7c4e8a9851b4", + StartCount: 3, + LastReportedStartCount: 2, + LastVersionCheckAt: time.Date(2026, 5, 4, 12, 0, 0, 0, time.UTC), + } + require.NoError(t, SaveConfig(cfg)) + + loaded, err = LoadConfig() + require.NoError(t, err) + require.NotNil(t, loaded) + assert.Equal(t, cfg.InstanceID, loaded.InstanceID) + assert.Equal(t, cfg.StartCount, loaded.StartCount) + assert.Equal(t, cfg.LastReportedStartCount, loaded.LastReportedStartCount) + assert.True(t, cfg.LastVersionCheckAt.Equal(loaded.LastVersionCheckAt)) +} + +func TestIsDisabledByEnv(t *testing.T) { + for _, value := range []string{"1", "true", "TRUE", "yes", "on"} { + assert.True(t, isDisabledByEnv(value), value) + } + for _, value := range []string{"", "0", "false", "no", "off", "anything"} { + assert.False(t, isDisabledByEnv(value), value) + } +} From 153ea5c2e2a189025ba527a66355d5c71f55c4e6 Mon Sep 17 00:00:00 2001 From: Alexis-Maurer Fortin Date: Tue, 5 May 2026 16:59:50 -0400 Subject: [PATCH 2/2] improved --- cmd/root.go | 23 ++++++++++++++++++----- versioncheck/versioncheck.go | 10 +++++----- versioncheck/versioncheck_test.go | 23 +++++++++++++++++++++++ 3 files changed, 46 insertions(+), 10 deletions(-) diff --git a/cmd/root.go b/cmd/root.go index 2bf185c0..3f99be07 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -95,15 +95,28 @@ By BoostSecurity.io - https://github.com/boostsecurityio/poutine `, }, } +// versionCheckSkipCommands lists subcommands that must not trigger the +// version check: "mcp-server" speaks JSON-RPC over stdio (no point delaying +// its handshake), and "completion" is invoked by shells for tab-completion +// lookups. Other subcommands (including "version" and "help") still pay the +// once-per-day check, since the 24h cache means at most one network call. +var versionCheckSkipCommands = map[string]struct{}{ + "mcp-server": {}, + "completion": {}, +} + // runVersionCheck performs the once-per-day update check unless disabled by -// flag, env var, or config. The "version" subcommand is excluded so users can -// inspect the binary without triggering a network call. +// flag, env var, or config. Commands listed in versionCheckSkipCommands and +// any subcommand under them are excluded so users can inspect the binary or +// run the MCP server without triggering a network call. func runVersionCheck(cmd *cobra.Command) { - if cmd != nil && cmd.Name() == "version" { - return + for c := cmd; c != nil; c = c.Parent() { + if _, skip := versionCheckSkipCommands[c.Name()]; skip { + return + } } disabled := disableVersionCheck || (config != nil && config.DisableVersionCheck) - result := versioncheck.Run(Version, disabled) + result := versioncheck.Run(cmd.Context(), Version, disabled) if result == nil || !result.UpdateAvailable { return } diff --git a/versioncheck/versioncheck.go b/versioncheck/versioncheck.go index e612102f..a35d4068 100644 --- a/versioncheck/versioncheck.go +++ b/versioncheck/versioncheck.go @@ -53,9 +53,9 @@ type options struct { // Run records a CLI start and, at most once every 24 hours, reports anonymous // telemetry to the configured endpoint and returns the latest release info. -// The whole operation is bounded by checkTimeout so it never noticeably -// slows down poutine startup. -func Run(version string, disabled bool) *Result { +// The HTTP call is bounded by checkTimeout (derived from ctx) so it never +// noticeably slows down poutine startup. +func Run(ctx context.Context, version string, disabled bool) *Result { if disabled || isDisabledByEnv(os.Getenv(DisableEnv)) { return nil } @@ -67,10 +67,10 @@ func Run(version string, disabled bool) *Result { recordStart(cfg, uuid.NewString) _ = SaveConfig(cfg) - ctx, cancel := context.WithTimeout(context.Background(), checkTimeout) + timeoutCtx, cancel := context.WithTimeout(ctx, checkTimeout) defer cancel() - result, _ := run(ctx, options{ + result, _ := run(timeoutCtx, options{ Config: cfg, Version: version, URL: VersionCheckURL, diff --git a/versioncheck/versioncheck_test.go b/versioncheck/versioncheck_test.go index 7c79def6..6e4dcad6 100644 --- a/versioncheck/versioncheck_test.go +++ b/versioncheck/versioncheck_test.go @@ -5,6 +5,8 @@ import ( "errors" "io" "net/http" + "os" + "path/filepath" "strings" "testing" "time" @@ -306,6 +308,27 @@ func TestLoadConfig_RoundTrip(t *testing.T) { assert.True(t, cfg.LastVersionCheckAt.Equal(loaded.LastVersionCheckAt)) } +func TestRun_DisabledShortCircuitsBeforeAnyDiskWrite(t *testing.T) { + dir := t.TempDir() + t.Setenv("POUTINE_CONFIG_DIR", dir) + t.Setenv(DisableEnv, "") + + // Disabled via the option (CLI flag / config path). + result := Run(context.Background(), "v0.18.0", true) + assert.Nil(t, result) + + _, err := os.Stat(filepath.Join(dir, "config.yaml")) + assert.True(t, os.IsNotExist(err), "no state file should be written when disabled") + + // Disabled via env var, even if the option is false. + t.Setenv(DisableEnv, "1") + result = Run(context.Background(), "v0.18.0", false) + assert.Nil(t, result) + + _, err = os.Stat(filepath.Join(dir, "config.yaml")) + assert.True(t, os.IsNotExist(err), "no state file should be written when disabled by env") +} + func TestIsDisabledByEnv(t *testing.T) { for _, value := range []string{"1", "true", "TRUE", "yes", "on"} { assert.True(t, isDisabledByEnv(value), value)