diff --git a/config.go b/config.go index 68c2098..d8eda25 100644 --- a/config.go +++ b/config.go @@ -60,15 +60,12 @@ func SetEnabled(home string, enabled bool) error { if b, err := os.ReadFile(p); err == nil { _ = json.Unmarshal(b, &raw) } - flagJSON, err := json.Marshal(EnabledFlag{Enabled: enabled}) - if err != nil { - return err - } + // EnabledFlag is all-primitive — json.Marshal cannot fail. + flagJSON, _ := json.Marshal(EnabledFlag{Enabled: enabled}) raw[configKey] = flagJSON - out, err := json.MarshalIndent(raw, "", " ") - if err != nil { - return err - } + // All values in raw are valid json.RawMessage (either freshly marshaled + // above, or from a successful Unmarshal upstream) — MarshalIndent cannot fail. + out, _ := json.MarshalIndent(raw, "", " ") tmp := p + ".tmp" if err := os.WriteFile(tmp, out, 0o600); err != nil { return err diff --git a/plugin_allowlist.go b/plugin_allowlist.go index 6298d74..4f33c21 100644 --- a/plugin_allowlist.go +++ b/plugin_allowlist.go @@ -254,10 +254,9 @@ func verifyOnDiskResult(configPath string, originalTopKeys map[string]struct{}, // the leaf key. Returns nil + empty string if any intermediate node is // not a JSON object and create is false. func walkObject(obj map[string]any, jsonPath string, create bool) (map[string]any, string) { + // strings.Split always returns at least one element (an empty string + // when jsonPath is ""), so a len(parts) == 0 guard is unreachable. parts := strings.Split(jsonPath, ".") - if len(parts) == 0 { - return nil, "" - } cur := obj for i := 0; i < len(parts)-1; i++ { p := parts[i] @@ -297,10 +296,10 @@ func allowListContains(obj map[string]any, jsonPath, id string) bool { } func ensureAllowListEntry(obj map[string]any, jsonPath, id string) error { + // walkObject with create=true always returns a non-nil parent (it + // materialises any missing intermediate object), so a nil check here + // would be unreachable. parent, leaf := walkObject(obj, jsonPath, true) - if parent == nil { - return fmt.Errorf("walk allow-list path %q: parent missing", jsonPath) - } cur, _ := parent[leaf].([]any) for _, v := range cur { if s, ok := v.(string); ok && s == id { @@ -329,10 +328,10 @@ func entryEnabled(obj map[string]any, jsonPath, id string) bool { } func ensureEntryEnabled(obj map[string]any, jsonPath, id string) error { + // walkObject with create=true always returns a non-nil parent (it + // materialises any missing intermediate object), so a nil check here + // would be unreachable. parent, leaf := walkObject(obj, jsonPath, true) - if parent == nil { - return fmt.Errorf("walk entries path %q: parent missing", jsonPath) - } entries, ok := parent[leaf].(map[string]any) if !ok { entries = map[string]any{} diff --git a/uninstall.go b/uninstall.go index e39bbea..ee3a943 100644 --- a/uninstall.go +++ b/uninstall.go @@ -297,12 +297,10 @@ func removePluginAllowListEntry(p *ManifestPlugin, cfgPath string) Removal { return r } - next, err := json.MarshalIndent(obj, "", " ") - if err != nil { - r.Action = RemovalError - r.Err = err.Error() - return r - } + // obj came from json.Unmarshal of a known-parseable config (we returned + // above otherwise), so every value is one of the standard library's + // JSON types — re-marshaling cannot fail. + next, _ := json.MarshalIndent(obj, "", " ") next = append(next, '\n') if err := writeFileAtomic(cfgPath, next, 0o644); err != nil { r.Action = RemovalError diff --git a/zz_helpers_test.go b/zz_helpers_test.go new file mode 100644 index 0000000..1c3a683 --- /dev/null +++ b/zz_helpers_test.go @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +//go:build !no_skillinject +// +build !no_skillinject + +package skillinject + +import ( + "path/filepath" + "strings" + "testing" +) + +func TestExpandHome_TildeSlash(t *testing.T) { + t.Parallel() + got := expandHome("~/.config/foo", "/home/alice") + want := filepath.Join("/home/alice", ".config/foo") + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestExpandHome_TildeOnly(t *testing.T) { + t.Parallel() + if got := expandHome("~", "/home/alice"); got != "/home/alice" { + t.Errorf("got %q, want /home/alice", got) + } +} + +func TestExpandHome_NoTildeReturnedAsIs(t *testing.T) { + t.Parallel() + if got := expandHome("/abs/path", "/home/x"); got != "/abs/path" { + t.Errorf("got %q", got) + } + if got := expandHome("relative/path", "/home/x"); got != "relative/path" { + t.Errorf("got %q", got) + } +} + +func TestRenderHeartbeat_BadTemplate(t *testing.T) { + t.Parallel() + _, err := renderHeartbeat([]byte(`{{.UnclosedAction`), heartbeatVars{}) + if err == nil { + t.Error("expected parse error") + } +} + +func TestRenderHeartbeat_HappyPath(t *testing.T) { + t.Parallel() + out, err := renderHeartbeat([]byte(`entry={{.EntrypointPath}}`), heartbeatVars{EntrypointPath: "/usr/local/bin/foo"}) + if err != nil { + t.Fatalf("renderHeartbeat: %v", err) + } + if !strings.Contains(out, "/usr/local/bin/foo") { + t.Errorf("output missing entrypoint: %q", out) + } +} diff --git a/zz_more_test.go b/zz_more_test.go new file mode 100644 index 0000000..1f2c36a --- /dev/null +++ b/zz_more_test.go @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +//go:build !no_skillinject +// +build !no_skillinject + +package skillinject + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "testing" + "time" + + "github.com/TeoSlayer/pilotprotocol/pkg/coreapi" +) + +func TestSetEnabled_RoundtripsThroughDisk(t *testing.T) { + t.Parallel() + home := t.TempDir() + if err := SetEnabled(home, false); err != nil { + t.Fatalf("SetEnabled(false): %v", err) + } + if IsEnabled(home) { + t.Error("after SetEnabled(false): IsEnabled = true, want false") + } + if err := SetEnabled(home, true); err != nil { + t.Fatalf("SetEnabled(true): %v", err) + } + if !IsEnabled(home) { + t.Error("after SetEnabled(true): IsEnabled = false, want true") + } +} + +func TestSetEnabled_PreservesOtherKeys(t *testing.T) { + t.Parallel() + home := t.TempDir() + cfgDir := filepath.Join(home, ".pilot") + if err := os.MkdirAll(cfgDir, 0755); err != nil { + t.Fatalf("mkdir: %v", err) + } + // Seed an existing config with an unrelated key. + if err := os.WriteFile(filepath.Join(cfgDir, "config.json"), + []byte(`{"other_key":"preserved"}`), 0600); err != nil { + t.Fatalf("seed: %v", err) + } + + if err := SetEnabled(home, false); err != nil { + t.Fatalf("SetEnabled: %v", err) + } + + body, err := os.ReadFile(filepath.Join(cfgDir, "config.json")) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + var raw map[string]json.RawMessage + if err := json.Unmarshal(body, &raw); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if _, ok := raw["other_key"]; !ok { + t.Error("other_key should be preserved across SetEnabled") + } + if _, ok := raw["skill_inject"]; !ok { + t.Error("skill_inject key missing after SetEnabled") + } +} + +func TestIsEnabled_MissingFileDefaultsTrue(t *testing.T) { + t.Parallel() + home := t.TempDir() + if !IsEnabled(home) { + t.Error("missing config: want true (opt-out)") + } +} + +func TestIsEnabled_BadJSONDefaultsTrue(t *testing.T) { + t.Parallel() + home := t.TempDir() + cfgDir := filepath.Join(home, ".pilot") + if err := os.MkdirAll(cfgDir, 0755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(cfgDir, "config.json"), []byte("not json"), 0600); err != nil { + t.Fatalf("write: %v", err) + } + if !IsEnabled(home) { + t.Error("bad JSON: want default true") + } +} + +func TestIsEnabled_MissingSubkeyDefaultsTrue(t *testing.T) { + t.Parallel() + home := t.TempDir() + cfgDir := filepath.Join(home, ".pilot") + if err := os.MkdirAll(cfgDir, 0755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(cfgDir, "config.json"), []byte(`{"unrelated":1}`), 0600); err != nil { + t.Fatalf("write: %v", err) + } + if !IsEnabled(home) { + t.Error("missing subkey: want default true") + } +} + +func TestIsEnabled_BadSubkeyDefaultsTrue(t *testing.T) { + t.Parallel() + home := t.TempDir() + cfgDir := filepath.Join(home, ".pilot") + if err := os.MkdirAll(cfgDir, 0755); err != nil { + t.Fatalf("mkdir: %v", err) + } + // skill_inject value is not the expected EnabledFlag shape. + if err := os.WriteFile(filepath.Join(cfgDir, "config.json"), + []byte(`{"skill_inject":42}`), 0600); err != nil { + t.Fatalf("write: %v", err) + } + if !IsEnabled(home) { + t.Error("bad subkey type: want default true") + } +} + +func TestService_Lifecycle(t *testing.T) { + t.Parallel() + cfg := Config{} + s := NewService(cfg) + if s == nil { + t.Fatal("NewService returned nil") + } + if s.Name() != "skillinject" { + t.Errorf("Name = %q", s.Name()) + } + if s.Order() != 200 { + t.Errorf("Order = %d, want 200", s.Order()) + } + + // Start with a context that cancels right away so Run exits cleanly. + ctx, cancel := context.WithCancel(context.Background()) + if err := s.Start(ctx, coreapi.Deps{}); err != nil { + t.Fatalf("Start: %v", err) + } + cancel() + + stopCtx, stopCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer stopCancel() + if err := s.Stop(stopCtx); err != nil { + t.Errorf("Stop: %v", err) + } +} + +func TestService_StopWithoutStart(t *testing.T) { + t.Parallel() + s := NewService(Config{}) + if err := s.Stop(context.Background()); err != nil { + t.Errorf("Stop without Start: %v", err) + } +} diff --git a/zz_pilot_bak_file_mode_test.go b/zz_pilot_bak_file_mode_test.go new file mode 100644 index 0000000..3a8b942 --- /dev/null +++ b/zz_pilot_bak_file_mode_test.go @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +package skillinject + +// Regression for P0 secret leakage: the `.pilot-bak` sidecar mirrors +// the operator's openclaw.json byte-for-byte, including any embedded +// API keys / tokens / credentials. The original implementation wrote +// it with mode 0o644 — world-readable on any multi-user host. Any +// local user (or any process running as a different UID) could read +// the operator's secrets via the backup file. +// +// Fix: write .pilot-bak with mode 0o600 (user-only read/write). +// Strictly tighter — no caller besides the same skillinject process +// reads this file, and it's removed on uninstall. + +import ( + "os" + "path/filepath" + "testing" +) + +func TestPilotBakWrittenWith0600(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + cfg := filepath.Join(dir, "openclaw.json") + + // Seed with content that triggers a real merge (and therefore a + // backup write). + original := map[string]any{ + "gateway": map[string]any{"mode": "auto"}, + "someToken": "super-secret-credential-blob", + } + writeJSON(t, cfg, original) + + if err := mergePluginAllowList(cfg, "plugins.allow", "plugins.entries", "pilot"); err != nil { + t.Fatalf("merge: %v", err) + } + + st, err := os.Stat(cfg + BackupSuffix) + if err != nil { + t.Fatalf("expected .pilot-bak at %s%s; got: %v", cfg, BackupSuffix, err) + } + + // We tolerate file modes <= 0600 (e.g. 0600 or stricter via umask). + // We REJECT anything that allows group or other to read. + mode := st.Mode().Perm() + if mode&0o077 != 0 { + t.Fatalf(".pilot-bak file mode = %o, want no group/other access (≤ 0600). "+ + "This file mirrors openclaw.json byte-for-byte and may contain "+ + "operator credentials.", mode) + } +} diff --git a/zz_prune_test.go b/zz_prune_test.go new file mode 100644 index 0000000..74d23ef --- /dev/null +++ b/zz_prune_test.go @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +//go:build !no_skillinject +// +build !no_skillinject + +package skillinject + +import ( + "os" + "path/filepath" + "testing" +) + +func TestPruneEmptyParent_BasenameMismatch(t *testing.T) { + t.Parallel() + dir := t.TempDir() + // Parent basename is the temp dir name — definitely not "pilot-protocol". + pruneEmptyParent(filepath.Join(dir, "child")) + // The dir must still exist. + if _, err := os.Stat(dir); err != nil { + t.Errorf("dir removed despite wrong basename: %v", err) + } +} + +func TestPruneEmptyParent_NonEmptyDirNotRemoved(t *testing.T) { + t.Parallel() + root := t.TempDir() + dir := filepath.Join(root, "pilot-protocol") + if err := os.MkdirAll(dir, 0700); err != nil { + t.Fatalf("mkdir: %v", err) + } + // Add a file so dir isn't empty. + if err := os.WriteFile(filepath.Join(dir, "x"), []byte("y"), 0600); err != nil { + t.Fatalf("write: %v", err) + } + pruneEmptyParent(filepath.Join(dir, "child")) + if _, err := os.Stat(dir); err != nil { + t.Errorf("non-empty dir removed: %v", err) + } +} + +func TestPruneEmptyParent_EmptyPilotProtocolDirRemoved(t *testing.T) { + t.Parallel() + root := t.TempDir() + dir := filepath.Join(root, "pilot-protocol") + if err := os.MkdirAll(dir, 0700); err != nil { + t.Fatalf("mkdir: %v", err) + } + pruneEmptyParent(filepath.Join(dir, "child")) + if _, err := os.Stat(dir); !os.IsNotExist(err) { + t.Errorf("empty pilot-protocol dir should be removed: %v", err) + } +} + +func TestDirIsEmpty_NonExistentReturnsFalse(t *testing.T) { + t.Parallel() + if dirIsEmpty("/no/such/path") { + t.Error("non-existent path: want false") + } +} + +func TestDirIsEmpty_EmptyDirReturnsTrue(t *testing.T) { + t.Parallel() + dir := t.TempDir() + if !dirIsEmpty(dir) { + t.Error("empty dir: want true") + } +} + +func TestDirIsEmpty_NonEmptyReturnsFalse(t *testing.T) { + t.Parallel() + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "x"), []byte("y"), 0600); err != nil { + t.Fatalf("write: %v", err) + } + if dirIsEmpty(dir) { + t.Error("non-empty dir: want false") + } +} diff --git a/zz_remove_entries_test.go b/zz_remove_entries_test.go new file mode 100644 index 0000000..fca8837 --- /dev/null +++ b/zz_remove_entries_test.go @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: AGPL-3.0-or-later + +//go:build !no_skillinject +// +build !no_skillinject + +package skillinject + +import ( + "testing" +) + +func TestRemoveAllowListEntry_HappyAndMissing(t *testing.T) { + t.Parallel() + obj := map[string]any{ + "allowlist": map[string]any{ + "plugins": []any{"a", "b", "c"}, + }, + } + // Remove existing entry. + if !removeAllowListEntry(obj, "allowlist.plugins", "b") { + t.Error("removeAllowListEntry should have removed 'b'") + } + // Verify b was removed. + arr, _ := obj["allowlist"].(map[string]any)["plugins"].([]any) + if len(arr) != 2 { + t.Errorf("arr len = %d, want 2", len(arr)) + } + + // Removing again should return false. + if removeAllowListEntry(obj, "allowlist.plugins", "b") { + t.Error("second removeAllowListEntry should return false") + } +} + +func TestRemoveAllowListEntry_ParentMissing(t *testing.T) { + t.Parallel() + if removeAllowListEntry(map[string]any{}, "nope.x", "id") { + t.Error("parent missing: want false") + } +} + +func TestRemoveAllowListEntry_NotAnArray(t *testing.T) { + t.Parallel() + obj := map[string]any{ + "plugins": "not-an-array", + } + if removeAllowListEntry(obj, "plugins", "id") { + t.Error("non-array: want false") + } +} + +func TestRemoveEntriesEntry_HappyAndMissing(t *testing.T) { + t.Parallel() + obj := map[string]any{ + "entries": map[string]any{ + "id-1": "value-1", + "id-2": "value-2", + }, + } + if !removeEntriesEntry(obj, "entries", "id-1") { + t.Error("removeEntriesEntry should have removed id-1") + } + entries, _ := obj["entries"].(map[string]any) + if _, ok := entries["id-1"]; ok { + t.Error("id-1 should be deleted") + } + + // Second removal returns false. + if removeEntriesEntry(obj, "entries", "id-1") { + t.Error("second removeEntriesEntry should return false") + } +} + +func TestRemoveEntriesEntry_ParentMissing(t *testing.T) { + t.Parallel() + if removeEntriesEntry(map[string]any{}, "nope.x", "id") { + t.Error("parent missing: want false") + } +} + +func TestRemoveEntriesEntry_NotAMap(t *testing.T) { + t.Parallel() + obj := map[string]any{ + "entries": []any{"not-a-map"}, + } + if removeEntriesEntry(obj, "entries", "id") { + t.Error("non-map: want false") + } +}