diff --git a/main.go b/main.go index cf2f969..ae0c271 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,10 @@ var version string = "dev" //go:embed skills/dci-cli var skillFS embed.FS +// customerContextFlagValue holds the --customer-context / -D flag value when +// set, used to suppress the Doer hint even when no persistent context file exists. +var customerContextFlagValue string + func dciConfigDir() string { if dir, err := os.UserConfigDir(); err == nil && dir != "" { cfgDir := filepath.Join(dir, "dci") @@ -129,6 +133,9 @@ func main() { } func run() (exitCode int) { + // Reset per-invocation state so repeated calls (e.g. in tests) start clean. + customerContextFlagValue = "" + defer func() { if r := recover(); r != nil { fmt.Fprintf(os.Stderr, "dci encountered an internal error: %v\n", r) @@ -193,11 +200,11 @@ func run() (exitCode int) { if err := cli.Run(); err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) - maybeHintDoerContext(1, configDir) + maybeHintDoerContext(1, cli.GetLastStatus(), configDir) return 1 } code := cli.GetExitCode() - maybeHintDoerContext(code, configDir) + maybeHintDoerContext(code, cli.GetLastStatus(), configDir) return code } @@ -795,15 +802,15 @@ func authSource() string { // maybeHintDoerContext prints a targeted hint when a @doit.com user hits a 403 // without a customer context set — covering both interactive and CI/CD usage. -func maybeHintDoerContext(exitCode int, configDir string) { - status := cli.GetLastStatus() +// status is the HTTP status code from the last request (pass cli.GetLastStatus()). +func maybeHintDoerContext(exitCode int, status int, configDir string) { if exitCode == 0 || (status != 401 && status != 403) { return } if !cachedTokenIsDoer() { return } - if readCustomerContext(configDir) != "" { + if readCustomerContext(configDir) != "" || customerContextFlagValue != "" { return } if term.IsTerminal(int(os.Stderr.Fd())) { @@ -983,6 +990,7 @@ func addOutputFlag() { dciCmd.PersistentFlags().StringP("table-columns", "C", "", "Comma-separated list of columns to include (default: all)") dciCmd.PersistentFlags().IntP("table-width", "W", 0, "Table width in columns (default: auto-detect terminal width)") dciCmd.PersistentFlags().IntP("table-max-col-width", "X", 0, "Maximum width per column when fitting or wrapping (0 = auto)") + dciCmd.PersistentFlags().StringP("customer-context", "D", "", "Override the active customer context for this command (e.g. acme.com)") // Bind table flags into viper so the renderer can pick them up. prev := dciCmd.PersistentPreRunE @@ -1021,6 +1029,24 @@ func addOutputFlag() { bindNonNegativeIntFlag(cmd, "table-width") bindNonNegativeIntFlag(cmd, "table-max-col-width") + // If --customer-context / -D was explicitly passed, override whatever + // applyCustomerContext() injected from the file or env var. + if flag := cmd.Flags().Lookup("customer-context"); flag != nil && flag.Changed { + val := strings.TrimSpace(flag.Value.String()) + if val == "" { + return fmt.Errorf("--customer-context requires a non-empty domain name") + } + existing := viper.GetStringSlice("rsh-query") + filtered := existing[:0] + for _, q := range existing { + if !strings.HasPrefix(q, "customerContext=") { + filtered = append(filtered, q) + } + } + viper.Set("rsh-query", append(filtered, "customerContext="+val)) + customerContextFlagValue = val + } + return nil } } diff --git a/main_test.go b/main_test.go index 0cb62b4..457e57b 100644 --- a/main_test.go +++ b/main_test.go @@ -1316,3 +1316,64 @@ func TestApplyDoerContext(t *testing.T) { }) } } + +func TestCustomerContextFlag(t *testing.T) { + bin := buildBinary(t) + + t.Run("empty --customer-context errors", func(t *testing.T) { + home := t.TempDir() + res := runCLIWithEnv(t, bin, home, []string{"DCI_API_KEY=test-key"}, "list-budgets", "--customer-context", "") + if res.timedOut { + t.Fatalf("command timed out; output:\n%s", res.output) + } + if res.exitCode == 0 { + t.Fatalf("expected non-zero exit; output:\n%s", res.output) + } + if !strings.Contains(res.output, "--customer-context requires a non-empty domain name") { + t.Fatalf("expected error message in output:\n%s", res.output) + } + }) + + t.Run("-D short form appears in help", func(t *testing.T) { + home := t.TempDir() + res := runCLIWithEnv(t, bin, home, []string{"DCI_API_KEY=test-key"}, "list-budgets", "--help") + if res.timedOut { + t.Fatalf("command timed out; output:\n%s", res.output) + } + if !strings.Contains(res.output, "-D, --customer-context") { + t.Fatalf("expected -D/--customer-context flag in help output:\n%s", res.output) + } + }) + + t.Run("Doer hint suppressed when customerContextFlagValue set", func(t *testing.T) { + setupTestCache(t) + cli.Cache.Set(testTokenCacheKey, doerJWT()) + + // Simulate --customer-context flag having been set for this invocation. + customerContextFlagValue = "acme.com" + t.Cleanup(func() { customerContextFlagValue = "" }) + + dir := t.TempDir() + // No persistent context file — conditions that would normally trigger the hint. + + // Capture stderr. + r, w, _ := os.Pipe() + oldStderr := os.Stderr + os.Stderr = w + + // Call with exitCode=1 and status=403 — would print the hint for a Doer + // with no persistent context, unless customerContextFlagValue suppresses it. + maybeHintDoerContext(1, 403, dir) + + w.Close() + os.Stderr = oldStderr + buf := make([]byte, 4096) + n, _ := r.Read(buf) + output := string(buf[:n]) + r.Close() + + if strings.Contains(output, "DoiT employees need a customer context") { + t.Fatalf("expected hint to be suppressed, but got:\n%s", output) + } + }) +}