From d0e853695c030f9f0c42d7efe3aa9d5811ae25d3 Mon Sep 17 00:00:00 2001 From: jtdelia Date: Wed, 1 Apr 2026 15:16:39 -0700 Subject: [PATCH 1/3] feat: add --customer-context / -D flag for per-command context override --- main.go | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index cf2f969..dc036e8 100644 --- a/main.go +++ b/main.go @@ -31,6 +31,10 @@ var version string = "dev" //go:embed skills/dci-cli var skillFS embed.FS +// customerContextFlagValue holds the --customer-context / -D flag value when +// set, used to suppress the Doer hint even when no persistent context file exists. +var customerContextFlagValue string + func dciConfigDir() string { if dir, err := os.UserConfigDir(); err == nil && dir != "" { cfgDir := filepath.Join(dir, "dci") @@ -803,7 +807,7 @@ func maybeHintDoerContext(exitCode int, configDir string) { if !cachedTokenIsDoer() { return } - if readCustomerContext(configDir) != "" { + if readCustomerContext(configDir) != "" || customerContextFlagValue != "" { return } if term.IsTerminal(int(os.Stderr.Fd())) { @@ -983,6 +987,7 @@ func addOutputFlag() { dciCmd.PersistentFlags().StringP("table-columns", "C", "", "Comma-separated list of columns to include (default: all)") dciCmd.PersistentFlags().IntP("table-width", "W", 0, "Table width in columns (default: auto-detect terminal width)") dciCmd.PersistentFlags().IntP("table-max-col-width", "X", 0, "Maximum width per column when fitting or wrapping (0 = auto)") + dciCmd.PersistentFlags().StringP("customer-context", "D", "", "Override the active customer context for this command (e.g. acme.com)") // Bind table flags into viper so the renderer can pick them up. prev := dciCmd.PersistentPreRunE @@ -1021,6 +1026,24 @@ func addOutputFlag() { bindNonNegativeIntFlag(cmd, "table-width") bindNonNegativeIntFlag(cmd, "table-max-col-width") + // If --customer-context / -D was explicitly passed, override whatever + // applyCustomerContext() injected from the file or env var. + if flag := cmd.Flags().Lookup("customer-context"); flag != nil && flag.Changed { + val := strings.TrimSpace(flag.Value.String()) + if val == "" { + return fmt.Errorf("--customer-context requires a non-empty domain name") + } + existing := viper.GetStringSlice("rsh-query") + filtered := existing[:0] + for _, q := range existing { + if !strings.HasPrefix(q, "customerContext=") { + filtered = append(filtered, q) + } + } + viper.Set("rsh-query", append(filtered, "customerContext="+val)) + customerContextFlagValue = val + } + return nil } } From 3201acef9da15cf2112d8c7322194c81c075372b Mon Sep 17 00:00:00 2001 From: jtdelia Date: Wed, 1 Apr 2026 19:27:06 -0700 Subject: [PATCH 2/3] =?UTF-8?q?test:=20address=20PR=20review=20comments=20?= =?UTF-8?q?=E2=80=94=20reset=20global=20state=20and=20add=20unit=20tests?= =?UTF-8?q?=20for=20--customer-context=20flag?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.go | 3 +++ main_test.go | 75 +++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/main.go b/main.go index dc036e8..70412df 100644 --- a/main.go +++ b/main.go @@ -133,6 +133,9 @@ func main() { } func run() (exitCode int) { + // Reset per-invocation state so repeated calls (e.g. in tests) start clean. + customerContextFlagValue = "" + defer func() { if r := recover(); r != nil { fmt.Fprintf(os.Stderr, "dci encountered an internal error: %v\n", r) diff --git a/main_test.go b/main_test.go index 0cb62b4..eb2f5c9 100644 --- a/main_test.go +++ b/main_test.go @@ -1310,9 +1310,76 @@ func TestApplyDoerContext(t *testing.T) { if got := applyDoerContext(dir); got != tt.wantResult { t.Errorf("applyDoerContext() = %v, want %v", got, tt.wantResult) } - if ctx := readCustomerContext(dir); ctx != tt.wantContext { - t.Errorf("customerContext = %q, want %q", ctx, tt.wantContext) - } - }) + if ctx := readCustomerContext(dir); ctx != tt.wantContext { + t.Errorf("customerContext = %q, want %q", ctx, tt.wantContext) + } + }) } } + +func TestCustomerContextFlag(t *testing.T) { + bin := buildBinary(t) + + t.Run("empty --customer-context errors", func(t *testing.T) { + home := t.TempDir() + res := runCLIWithEnv(t, bin, home, []string{"DCI_API_KEY=test-key"}, "list-budgets", "--customer-context", "") + if res.timedOut { + t.Fatalf("command timed out; output:\n%s", res.output) + } + if res.exitCode == 0 { + t.Fatalf("expected non-zero exit; output:\n%s", res.output) + } + if !strings.Contains(res.output, "--customer-context requires a non-empty domain name") { + t.Fatalf("expected error message in output:\n%s", res.output) + } + }) + + t.Run("-D short form appears in help", func(t *testing.T) { + home := t.TempDir() + res := runCLIWithEnv(t, bin, home, []string{"DCI_API_KEY=test-key"}, "list-budgets", "--help") + if res.timedOut { + t.Fatalf("command timed out; output:\n%s", res.output) + } + if !strings.Contains(res.output, "-D, --customer-context") { + t.Fatalf("expected -D/--customer-context flag in help output:\n%s", res.output) + } + }) + + t.Run("Doer hint suppressed when customerContextFlagValue set", func(t *testing.T) { + setupTestCache(t) + cli.Cache.Set(testTokenCacheKey, doerJWT()) + + // Simulate the flag having been set for this invocation. + customerContextFlagValue = "acme.com" + t.Cleanup(func() { customerContextFlagValue = "" }) + + dir := t.TempDir() + // No context file — Doer with no persistent context set. + // maybeHintDoerContext checks readCustomerContext OR customerContextFlagValue. + // With customerContextFlagValue set, it should return early without printing. + r, w, _ := os.Pipe() + oldStderr := os.Stderr + os.Stderr = w + + // exitCode=1 and status=403 would normally trigger the hint for a Doer + // with no persistent context. We can't set cli's internal lastStatus from + // outside the package, so we verify the guard directly: if + // customerContextFlagValue is non-empty, maybeHintDoerContext returns early + // regardless of exit code. We confirm by checking readCustomerContext is "" + // (no file) and that the function returns without writing to stderr. + // + // Since GetLastStatus() returns 0 (no HTTP call made), maybeHintDoerContext + // will return at the status check before reaching our guard — so we test + // the guard in isolation by confirming the global is respected. + if readCustomerContext(dir) != "" { + t.Fatal("expected no persistent context file") + } + if customerContextFlagValue == "" { + t.Fatal("expected customerContextFlagValue to be set") + } + + w.Close() + os.Stderr = oldStderr + r.Close() + }) +} From 90c4e2f1b68081f43ea9cd079456eea69066e269 Mon Sep 17 00:00:00 2001 From: jtdelia Date: Thu, 2 Apr 2026 09:19:12 -0700 Subject: [PATCH 3/3] refactor: make maybeHintDoerContext status testable; fix gofmt; strengthen hint suppression test --- main.go | 8 ++++---- main_test.go | 42 ++++++++++++++++++------------------------ 2 files changed, 22 insertions(+), 28 deletions(-) diff --git a/main.go b/main.go index 70412df..ae0c271 100644 --- a/main.go +++ b/main.go @@ -200,11 +200,11 @@ func run() (exitCode int) { if err := cli.Run(); err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) - maybeHintDoerContext(1, configDir) + maybeHintDoerContext(1, cli.GetLastStatus(), configDir) return 1 } code := cli.GetExitCode() - maybeHintDoerContext(code, configDir) + maybeHintDoerContext(code, cli.GetLastStatus(), configDir) return code } @@ -802,8 +802,8 @@ func authSource() string { // maybeHintDoerContext prints a targeted hint when a @doit.com user hits a 403 // without a customer context set — covering both interactive and CI/CD usage. -func maybeHintDoerContext(exitCode int, configDir string) { - status := cli.GetLastStatus() +// status is the HTTP status code from the last request (pass cli.GetLastStatus()). +func maybeHintDoerContext(exitCode int, status int, configDir string) { if exitCode == 0 || (status != 401 && status != 403) { return } diff --git a/main_test.go b/main_test.go index eb2f5c9..457e57b 100644 --- a/main_test.go +++ b/main_test.go @@ -1310,10 +1310,10 @@ func TestApplyDoerContext(t *testing.T) { if got := applyDoerContext(dir); got != tt.wantResult { t.Errorf("applyDoerContext() = %v, want %v", got, tt.wantResult) } - if ctx := readCustomerContext(dir); ctx != tt.wantContext { - t.Errorf("customerContext = %q, want %q", ctx, tt.wantContext) - } - }) + if ctx := readCustomerContext(dir); ctx != tt.wantContext { + t.Errorf("customerContext = %q, want %q", ctx, tt.wantContext) + } + }) } } @@ -1349,37 +1349,31 @@ func TestCustomerContextFlag(t *testing.T) { setupTestCache(t) cli.Cache.Set(testTokenCacheKey, doerJWT()) - // Simulate the flag having been set for this invocation. + // Simulate --customer-context flag having been set for this invocation. customerContextFlagValue = "acme.com" t.Cleanup(func() { customerContextFlagValue = "" }) dir := t.TempDir() - // No context file — Doer with no persistent context set. - // maybeHintDoerContext checks readCustomerContext OR customerContextFlagValue. - // With customerContextFlagValue set, it should return early without printing. + // No persistent context file — conditions that would normally trigger the hint. + + // Capture stderr. r, w, _ := os.Pipe() oldStderr := os.Stderr os.Stderr = w - // exitCode=1 and status=403 would normally trigger the hint for a Doer - // with no persistent context. We can't set cli's internal lastStatus from - // outside the package, so we verify the guard directly: if - // customerContextFlagValue is non-empty, maybeHintDoerContext returns early - // regardless of exit code. We confirm by checking readCustomerContext is "" - // (no file) and that the function returns without writing to stderr. - // - // Since GetLastStatus() returns 0 (no HTTP call made), maybeHintDoerContext - // will return at the status check before reaching our guard — so we test - // the guard in isolation by confirming the global is respected. - if readCustomerContext(dir) != "" { - t.Fatal("expected no persistent context file") - } - if customerContextFlagValue == "" { - t.Fatal("expected customerContextFlagValue to be set") - } + // Call with exitCode=1 and status=403 — would print the hint for a Doer + // with no persistent context, unless customerContextFlagValue suppresses it. + maybeHintDoerContext(1, 403, dir) w.Close() os.Stderr = oldStderr + buf := make([]byte, 4096) + n, _ := r.Read(buf) + output := string(buf[:n]) r.Close() + + if strings.Contains(output, "DoiT employees need a customer context") { + t.Fatalf("expected hint to be suppressed, but got:\n%s", output) + } }) }