diff --git a/README.md b/README.md index 88a761d..54351e9 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,22 @@ Renaming ./alphabet to betabet * Files and directories are renamed. * Searches are performed recursively from the current working directory. * Searches are case sensitive. -* `.git/` directories are skipped. -* Binary files are ignored. +* `.git` is always skipped (whether it's a directory or a worktree-linkage file). +* Binary files (those with a NUL byte or invalid UTF-8 in the first 1 KiB) are ignored. +* Symbolic links are not followed and not rewritten. +* Files with the setuid or setgid bit set are not rewritten. +* Original mode, owner/group, and modification time are preserved across rewrites. +* Exits non-zero if any file failed to be rewritten or renamed (the rest of the tree is still processed on a best-effort basis). + +## Security model + +`find-replace` is designed to be safe to run inside an untrusted directory: + +* It will never traverse out of the working tree through a symbolic link. +* Temp files used during rewrites are created with `O_EXCL` under a randomly generated name (via `os.CreateTemp`), so a co-resident attacker cannot pre-create the temp-file path to redirect the write. +* Renames refuse to overwrite an existing destination (implemented as a hardlink + remove pair where hardlinks are supported, so the existence check and the directory-entry creation are a single step; otherwise a best-effort existence check followed by a rename is used). + +When run as root, `find-replace` preserves the original uid/gid of every rewritten file. Files with `setuid`/`setgid` bits are skipped to avoid producing a setuid binary owned by the wrong user. ## Goal diff --git a/build.sh b/build.sh index f8b9fd6..27271c6 100755 --- a/build.sh +++ b/build.sh @@ -2,7 +2,7 @@ set -e # Vet -go vet +go vet ./... # Build GIT_COMMIT="$(git rev-parse --short $(git rev-list -1 HEAD))" @@ -27,5 +27,5 @@ go build \ -X 'main.BuildTainted=$BUILD_TAINTED'" \ ./... -# Test -go test -cover -v ./... +# Test (with race detector to catch concurrency bugs).
+go test -race -cover -v ./... diff --git a/chown_unix.go b/chown_unix.go new file mode 100644 index 0000000..701731e --- /dev/null +++ b/chown_unix.go @@ -0,0 +1,32 @@ +//go:build unix + +package main + +import ( + "errors" + "io/fs" + "os" + "syscall" +) + +// chownToOriginal applies the original uid/gid from info to path. Best-effort: +// permission errors (typical for non-root users) are silently ignored. Other +// errors are returned. +func chownToOriginal(path string, info os.FileInfo) error { + st, ok := info.Sys().(*syscall.Stat_t) + if !ok { + return nil + } + err := os.Chown(path, int(st.Uid), int(st.Gid)) + if err == nil { + return nil + } + // Non-root processes lack CAP_CHOWN and will get EPERM here. That's + // expected: the temp file is owned by the running user, who is usually + // (but not necessarily: replacing a file only requires write access to + // its directory) the owner of the original file too. Don't fail. + if errors.Is(err, fs.ErrPermission) { + return nil + } + return err +} diff --git a/chown_windows.go b/chown_windows.go new file mode 100644 index 0000000..ff74877 --- /dev/null +++ b/chown_windows.go @@ -0,0 +1,10 @@ +//go:build windows + +package main + +import "os" + +func chownToOriginal(path string, info os.FileInfo) error { + // Windows does not have a meaningful equivalent. + return nil +} diff --git a/file_handling.go b/file_handling.go index 9844c3a..a65234f 100644 --- a/file_handling.go +++ b/file_handling.go @@ -1,88 +1,265 @@ package main import ( + "bytes" + "errors" + "fmt" "io" - "log" "os" "path/filepath" - "strings" - - "golang.org/x/tools/godoc/util" + "unicode/utf8" ) -type File struct { - Path string - info os.FileInfo -} +// sniffSize is the number of leading bytes inspected when deciding whether a +// file looks textual. Matches the historical behavior.
+const sniffSize = 1024 -func NewFile(path string) *File { - absPath, err := filepath.Abs(path) - if err != nil { - log.Fatalf("Unable to resolve absolute path of %v: %v", path, err) +// rewriteBufSize is the working buffer size used by the streaming rewriter. +// Sized to balance syscall count against per-call memory. +const rewriteBufSize = 64 * 1024 + +// looksBinary reports whether the given prefix appears to be binary content. +// A prefix is considered binary if it contains a NUL byte or is not valid +// UTF-8; a multi-byte rune cut short by the end of the prefix counts as +// invalid too. Empty content is treated as text (a no-op rewrite is harmless). +func looksBinary(prefix []byte) bool { + if bytes.IndexByte(prefix, 0) >= 0 { + return true } - return &File{Path: absPath} + return !utf8.Valid(prefix) } -func (f *File) Base() string { - return filepath.Base(f.Path) +// renameNoReplace renames src to dst, returning os.ErrExist if dst already +// exists. Implemented as a hardlink-then-remove so the existence check and +// the directory-entry creation are a single atomic step. Falls back to a +// best-effort exists-check + Rename when hardlinks are not supported (for +// example renaming across filesystems or onto filesystems without link +// support). +func renameNoReplace(src, dst string) error { + err := os.Link(src, dst) + if err == nil { + // Link succeeded, source can be unlinked. + if rmErr := os.Remove(src); rmErr != nil { + // Try to undo the link so we don't end up with two copies. + _ = os.Remove(dst) + return rmErr + } + return nil + } + if errors.Is(err, os.ErrExist) { + return err + } + // Hardlinks unsupported (EXDEV across filesystems, EPERM on certain + // filesystems). Fall back to a TOCTOU-prone but functional rename. + if _, statErr := os.Lstat(dst); statErr == nil { + return os.ErrExist + } + return os.Rename(src, dst) } -func (f *File) Dir() string { - return filepath.Dir(f.Path) -} +// rewriteFile rewrites the contents of path, replacing every occurrence of +// find with replace.
Returns true if the file was modified. The rewrite is +// atomic with respect to readers: a temp file is written under O_EXCL in the +// same directory and renamed over the original on success. +// +// Files that look binary in their first sniffSize bytes are skipped. Files +// that contain no occurrences of find are also skipped (no temp file is +// written). +// +// The original mode (including the sticky bit), owner/group (best-effort +// when not running as root), and modification time are copied from info +// onto the temp file before it is renamed over the original. Files with +// the setuid or setgid bit set are not rewritten at all: they are rejected +// with an error, because those bits cannot be preserved reliably. +func rewriteFile(path string, find, replace []byte, info os.FileInfo) (changed bool, err error) { + // Refuse to rewrite files with setuid/setgid bits. We cannot reliably + // preserve those bits across a tempfile-and-rename: the kernel strips + // them on any uid change, and producing a setuid binary owned by the + // running user from one we found owned by someone else would be a + // privilege footgun. Better to leave such files alone. + if info.Mode()&(os.ModeSetuid|os.ModeSetgid) != 0 { + return false, fmt.Errorf("refusing to rewrite %s: setuid/setgid bit is set", path) + } -func (f *File) Info() os.FileInfo { - if f.info == nil { - stat, err := os.Stat(f.Path) - if err != nil { - log.Fatalf("Failed to stat %v: %v", f.Path, err) - } - f.info = stat + // We have to know whether the content needs rewriting before we create a + // temp file (otherwise we'd thrash the filesystem on every no-op file). + // Read the prefix, sniff it, and remember it so we don't have to seek + // back later.
+ in, err := os.Open(path) + if err != nil { + return false, err } - return f.info -} + defer in.Close() -func (f *File) Mode() os.FileMode { - return f.Info().Mode() -} + prefix := make([]byte, sniffSize) + n, err := io.ReadFull(in, prefix) + prefix = prefix[:n] + switch { + case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF): + // Short file; prefix contains the entire content. + err = nil + case err != nil: + return false, err + } + if looksBinary(prefix) { + return false, nil + } + + // Cheap pre-check: if the file is entirely contained in our prefix and + // it doesn't contain `find`, there's nothing to do. The streaming path + // below would also detect this, but it would still create a temp file. + // For larger files we let the streaming pass discover "no match" by + // observing that no replacement was performed. + if n < sniffSize && !bytes.Contains(prefix, find) { + return false, nil + } -// Read the file into a string. -func (f *File) Read() string { - handle, err := os.Open(f.Path) + // Atomic write: create a temp file with O_EXCL via os.CreateTemp, stream + // the rewrite into it, fsync, rename over the original. If anything + // fails, the deferred remove cleans up the temp file. + tmp, err := os.CreateTemp(filepath.Dir(path), ".find-replace-*") if err != nil { - log.Fatalf("Unable to open %v: %v", f.Path, err) + return false, err } - defer handle.Close() + tmpName := tmp.Name() + // Always attempt to clean up the temp file. A successful rename consumes + // the path, so the remove becomes a harmless no-op in that case. + cleanup := true + defer func() { + if cleanup { + _ = os.Remove(tmpName) + } + }() - // Check if the file looks like text before reading the entire file. - var buf [1024]byte - n, err := handle.Read(buf[0:]) - if err != nil || !util.IsText(buf[0:n]) { - return "" + // Stream the rest of the file, prepending the prefix we already read. 
+ rest := io.MultiReader(bytes.NewReader(prefix), in) + wrote, err := streamReplace(tmp, rest, find, replace) + if err != nil { + _ = tmp.Close() + return false, err + } + if err = tmp.Sync(); err != nil { + _ = tmp.Close() + return false, err + } + if err = tmp.Close(); err != nil { + return false, err + } + if !wrote.changed { + // Nothing actually replaced; leave the original alone. + return false, nil } - // Reset file handle so we can read the entire file. - if _, err := handle.Seek(0, io.SeekStart); err != nil { - log.Fatalf("Failed to seek back to beginning of %v: %v", f.Path, err) + // Preserve metadata before the rename. Order: chown first (so that the + // later chmod doesn't get its setuid/setgid bits cleared by a uid + // change), then chmod, then chtimes. + if err = preserveMetadata(tmpName, info); err != nil { + return false, err } - builder := new(strings.Builder) - if _, err := io.Copy(builder, handle); err != nil { - log.Fatalf("Failed to read %v to a string: %v", f.Path, err) + if err = os.Rename(tmpName, path); err != nil { + return false, fmt.Errorf("rename %s -> %s: %w", tmpName, path, err) } - return builder.String() + cleanup = false + return true, nil +} + +type rewriteStats struct { + changed bool } -// Write content to file atomically, by writing it to a temporary file first, -// and then moving it to the destination, overwriting the original. -func (f *File) Write(content string) { - tempName := filepath.Join(f.Dir(), RandomString(20)) - if err := os.WriteFile(tempName, []byte(content), f.Mode()); err != nil { - log.Fatalf("Error creating tempfile in %v: %v", f.Dir(), err) +// streamReplace copies r to w, replacing every occurrence of find with +// replace. Memory usage is bounded by the size of the working buffer plus the +// length of `find`. Returns rewriteStats.changed=true if at least one +// replacement was made. 
+func streamReplace(w io.Writer, r io.Reader, find, replace []byte) (rewriteStats, error) { + var stats rewriteStats + if len(find) == 0 { + // Defensive: callers should reject empty find before this point. + // Behave as a plain copy to avoid pathological output. + _, err := io.Copy(w, r) + return stats, err + } + + // Buffer is sized so that a full `find` plus a non-trivial chunk fits + // after we carry up to (len(find)-1) bytes from the previous iteration. + bufSize := rewriteBufSize + if bufSize < 2*len(find) { + bufSize = 2 * len(find) + } + buf := make([]byte, bufSize) + keep := 0 + for { + n, readErr := io.ReadFull(r, buf[keep:]) + end := keep + n + eof := errors.Is(readErr, io.EOF) || errors.Is(readErr, io.ErrUnexpectedEOF) + if readErr != nil && !eof { + return stats, readErr + } + + // Scan the whole buffer for matches. A match found here is fully + // contained in buf[0:end] (bytes.Index requires the entire pattern + // to fit inside the search slice). + i := 0 + for i < end { + j := bytes.Index(buf[i:end], find) + if j < 0 { + break + } + if _, err := w.Write(buf[i : i+j]); err != nil { + return stats, err + } + if _, err := w.Write(replace); err != nil { + return stats, err + } + stats.changed = true + i += j + len(find) + } + + if eof { + // Emit anything left over and we're done. + if i < end { + if _, err := w.Write(buf[i:end]); err != nil { + return stats, err + } + } + return stats, nil + } + + // Determine how much of the unmatched tail is safe to emit. The + // last (len(find)-1) bytes might be the start of a match that + // completes after the next read, so they must be carried forward. 
+ safeEnd := end - (len(find) - 1) + if safeEnd < i { + safeEnd = i + } + if i < safeEnd { + if _, err := w.Write(buf[i:safeEnd]); err != nil { + return stats, err + } + } + + copy(buf, buf[safeEnd:end]) + keep = end - safeEnd } +} - log.Printf("Rewriting %v", f.Path) - if err := os.Rename(tempName, f.Path); err != nil { - log.Fatalf("Unable to atomically move temp file %v to %v: %v", tempName, f.Path, err) +// preserveMetadata copies mode, ownership, and modification time from info +// onto path. Best-effort: a failure to chown when not running as root is +// expected (no CAP_CHOWN) and not treated as an error; other failures abort +// the rewrite. +func preserveMetadata(path string, info os.FileInfo) error { + // chown first: changing ownership can clear setuid/setgid bits on some + // systems, so we want to apply the mode after chown. + if err := chownToOriginal(path, info); err != nil { + return err + } + // Mode includes Perm and the special bits (setuid, setgid, sticky). + if err := os.Chmod(path, info.Mode()&os.ModePerm|info.Mode()&(os.ModeSetuid|os.ModeSetgid|os.ModeSticky)); err != nil { + return err + } + if err := os.Chtimes(path, info.ModTime(), info.ModTime()); err != nil { + return err } + return nil } diff --git a/file_handling_test.go b/file_handling_test.go index 9dfcaee..a3d6545 100644 --- a/file_handling_test.go +++ b/file_handling_test.go @@ -1,6 +1,252 @@ package main -import "testing" +import ( + "bytes" + "errors" + "os" + "path/filepath" + "strings" + "testing" +) -func TestNewFile(t *testing.T) { +func statOrFail(t testing.TB, path string) os.FileInfo { + t.Helper() + info, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + return info +} + +func TestLooksBinary(t *testing.T) { + tests := []struct { + name string + input []byte + want bool + }{ + {"empty", []byte{}, false}, + {"plain text", []byte("hello world\n"), false}, + {"utf8", []byte("héllo wörld"), false}, + {"NUL byte", []byte("hello\x00world"), true}, + {"invalid UTF-8", 
[]byte{0xff, 0xfe, 0xfd}, true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := looksBinary(tc.input); got != tc.want { + t.Errorf("looksBinary(%q) = %v, want %v", tc.input, got, tc.want) + } + }) + } +} + +func TestStreamReplace(t *testing.T) { + tests := []struct { + name string + in string + find string + replace string + want string + changed bool + }{ + {"no match", "alpha bravo", "zulu", "yankee", "alpha bravo", false}, + {"single match", "alpha", "ph", "f", "alfa", true}, + {"multiple matches", "alphaalpha", "ph", "f", "alfaalfa", true}, + {"match at start", "alpha bravo", "alpha", "delta", "delta bravo", true}, + {"match at end", "alpha bravo", "bravo", "delta", "alpha delta", true}, + {"longer replacement", "ab", "a", "xxxxxx", "xxxxxxb", true}, + {"shorter replacement", "abcabc", "abc", "x", "xx", true}, + {"newlines preserved", "foo\nbar\nfoo", "foo", "BAZ", "BAZ\nbar\nBAZ", true}, + {"replace with empty", "remove this please", "this ", "", "remove please", true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var out bytes.Buffer + stats, err := streamReplace(&out, strings.NewReader(tc.in), + []byte(tc.find), []byte(tc.replace)) + if err != nil { + t.Fatal(err) + } + if got := out.String(); got != tc.want { + t.Errorf("output = %q, want %q", got, tc.want) + } + if stats.changed != tc.changed { + t.Errorf("changed = %v, want %v", stats.changed, tc.changed) + } + }) + } +} + +// TestStreamReplaceMatchAcrossBuffers exercises a match that spans the +// boundary between two buffer-fills. Builds an input that's larger than the +// rewrite buffer with the match deliberately straddling the boundary. +func TestStreamReplaceMatchAcrossBuffers(t *testing.T) { + find := "needle" + replace := "PIN" + + // Place the find string at offset rewriteBufSize - 3 so the first 3 + // bytes are in the first buffer fill and the last 3 are in the second. 
+ prefix := strings.Repeat("a", rewriteBufSize-3) + tail := strings.Repeat("b", 100) + in := prefix + find + tail + want := prefix + replace + tail + + var out bytes.Buffer + stats, err := streamReplace(&out, strings.NewReader(in), []byte(find), []byte(replace)) + if err != nil { + t.Fatal(err) + } + if !stats.changed { + t.Fatal("expected changed=true") + } + if got := out.String(); got != want { + t.Errorf("boundary match not handled correctly (got len=%d, want len=%d)", + len(got), len(want)) + } +} + +func TestRewriteFileSkipsBinary(t *testing.T) { + d := t.TempDir() + path := filepath.Join(d, "binary") + original := []byte{0x00, 0x01, 0x02, 'a', 'l', 'p', 'h', 'a'} + if err := os.WriteFile(path, original, 0o600); err != nil { + t.Fatal(err) + } + + changed, err := rewriteFile(path, []byte("alpha"), []byte("BETA"), statOrFail(t, path)) + if err != nil { + t.Fatal(err) + } + if changed { + t.Error("expected binary file to be skipped") + } + got, _ := os.ReadFile(path) + if !bytes.Equal(got, original) { + t.Errorf("binary file was modified: %q", got) + } +} + +func TestRewriteFileNoMatchLeavesOriginalUntouched(t *testing.T) { + d := t.TempDir() + path := filepath.Join(d, "f.txt") + if err := os.WriteFile(path, []byte("hello world"), 0o600); err != nil { + t.Fatal(err) + } + stat0, _ := os.Stat(path) + + changed, err := rewriteFile(path, []byte("xyz"), []byte("abc"), statOrFail(t, path)) + if err != nil { + t.Fatal(err) + } + if changed { + t.Error("expected no change") + } + + stat1, _ := os.Stat(path) + if !stat0.ModTime().Equal(stat1.ModTime()) { + t.Error("file was rewritten despite no match (mtime changed)") + } + + // And no temp files left behind. 
+ entries, _ := os.ReadDir(d) + if len(entries) != 1 { + var names []string + for _, e := range entries { + names = append(names, e.Name()) + } + t.Errorf("expected 1 file in dir, got %d: %v", len(entries), names) + } +} + +func TestRewriteFileLeavesNoTempfile(t *testing.T) { + d := t.TempDir() + path := filepath.Join(d, "f.txt") + if err := os.WriteFile(path, []byte("hello alpha world"), 0o600); err != nil { + t.Fatal(err) + } + + changed, err := rewriteFile(path, []byte("alpha"), []byte("BETA"), statOrFail(t, path)) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Error("expected change=true") + } + + entries, _ := os.ReadDir(d) + if len(entries) != 1 { + var names []string + for _, e := range entries { + names = append(names, e.Name()) + } + t.Errorf("expected exactly one file (no temp files leaked), got %v", names) + } +} + +func TestRewriteFilePreservesMode(t *testing.T) { + d := t.TempDir() + path := filepath.Join(d, "f.txt") + if err := os.WriteFile(path, []byte("alpha"), 0o640); err != nil { + t.Fatal(err) + } + + changed, err := rewriteFile(path, []byte("alpha"), []byte("BETA"), statOrFail(t, path)) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Fatal("expected change=true") + } + stat, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if stat.Mode().Perm() != 0o640 { + t.Errorf("mode = %v, want %v", stat.Mode().Perm(), os.FileMode(0o640)) + } +} + +func TestRenameNoReplaceFailsOnExist(t *testing.T) { + d := t.TempDir() + src := filepath.Join(d, "src") + dst := filepath.Join(d, "dst") + if err := os.WriteFile(src, []byte("source"), 0o600); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(dst, []byte("destination"), 0o600); err != nil { + t.Fatal(err) + } + + err := renameNoReplace(src, dst) + if !errors.Is(err, os.ErrExist) { + t.Errorf("expected ErrExist, got %v", err) + } + + // Source must still exist; dst content untouched. 
+ if _, err := os.Stat(src); err != nil { + t.Errorf("source was removed: %v", err) + } + got, _ := os.ReadFile(dst) + if string(got) != "destination" { + t.Errorf("dst was overwritten: %q", got) + } +} + +func TestRenameNoReplaceSucceeds(t *testing.T) { + d := t.TempDir() + src := filepath.Join(d, "src") + dst := filepath.Join(d, "dst") + if err := os.WriteFile(src, []byte("source"), 0o600); err != nil { + t.Fatal(err) + } + + if err := renameNoReplace(src, dst); err != nil { + t.Fatal(err) + } + if _, err := os.Stat(src); !errors.Is(err, os.ErrNotExist) { + t.Errorf("source still exists: %v", err) + } + got, _ := os.ReadFile(dst) + if string(got) != "source" { + t.Errorf("dst content = %q, want source", got) + } } diff --git a/find_replace.go b/find_replace.go index a2f2c65..91fe49a 100644 --- a/find_replace.go +++ b/find_replace.go @@ -2,34 +2,36 @@ package main import ( "errors" + "fmt" + "io/fs" "log" - "math/rand" "os" "path/filepath" "strings" - "sync" - "time" ) -// findReplace is a struct used to provide context to all find & replace -// operations, including the strings to search for & replace. +// findReplace carries the parameters and counters for a single run. type findReplace struct { - find string - replace string + find string + replace string + findBytes []byte // pre-computed []byte form of find, reused per file + replaceBytes []byte // pre-computed []byte form of replace, reused per file + errors int +} + +// Reserved tempfile prefix used by rewriteFile. Skipped during traversal so +// orphans from a crashed prior run don't get picked up as targets. +const tempPrefix = ".find-replace-" + +// Skipped base names. These are skipped whether the entry is a file or a +// directory: `.git` in particular can be either a directory (a normal repo) +// or a file (a worktree/submodule pointing at the real `.git` elsewhere). 
+var skipNames = map[string]struct{}{ + ".git": {}, } -// main processes command line arguments, builds the context struct, and begins -// the process of walking the current working directory. -// -// Variable terminology used throughout this module: -// -// • dirName: the name of a directory, without a trailing separator -// • baseName: the relative name of a file, without a directory -// • path: the relative path to a specific file or directory, including both dirName and baseName func main() { - // Remove date/time from logging output log.SetFlags(0) - rand.Seed(time.Now().UnixNano()) if len(os.Args) != 3 { log.Fatal("Usage: find-replace FIND REPLACE") @@ -38,87 +40,134 @@ func main() { find := os.Args[1] replace := os.Args[2] - fr := findReplace{find: find, replace: replace} - - // Recursively explore the hierarchy depth first, rewrite files as needed, - // and rename files last (after we don't have to revisit them). - // path.filepath.WalkDir() won't work here because it walks files - // alphabetically, breadth-first (and you'd be renaming files that you - // haven't explored yet). + if find == "" { + log.Fatal("FIND must be non-empty") + } + if find == replace { + // Nothing to do, but it's a user error worth flagging. + log.Fatal("FIND and REPLACE are identical; nothing to do") + } + fr := findReplace{ + find: find, + replace: replace, + findBytes: []byte(find), + replaceBytes: []byte(replace), + } fr.WalkDir(NewFile(".")) + + if fr.errors > 0 { + os.Exit(1) + } } -// Walks files in the directory given by dirName, which is a relative path to a -// directory. Calls HandleFile for each file it finds, if it's not ignored. +// WalkDir traverses the directory tree rooted at f depth-first, rewriting file +// contents and renaming entries on the way back up. It runs single-threaded +// to keep memory and file-descriptor usage bounded and to avoid races between +// concurrent renames in the same parent directory. 
func (fr *findReplace) WalkDir(f *File) { - var wg sync.WaitGroup - - // List the files in this directory. - files, err := os.ReadDir(f.Path) + // Pre-compute byte forms once so we don't reallocate them per file. + if fr.findBytes == nil { + fr.findBytes = []byte(fr.find) + } + if fr.replaceBytes == nil { + fr.replaceBytes = []byte(fr.replace) + } + entries, err := os.ReadDir(f.Path) if err != nil { - log.Fatalf("Unable to read directory: %v", err) + fr.recordErr(fmt.Errorf("read directory %s: %w", f.Path, err)) + return } - for _, file := range files { - childFile := NewFile(filepath.Join(f.Path, file.Name())) - wg.Add(1) - go func() { - defer wg.Done() - fr.HandleFile(childFile) - }() + for _, entry := range entries { + // Use the DirEntry directly instead of stat'ing each child; this + // avoids both a redundant syscall and the symlink-following side + // effect of os.Stat. (Issues #2, #13.) + fr.handleEntry(f, entry) } +} + +func (fr *findReplace) handleEntry(parent *File, entry fs.DirEntry) { + name := entry.Name() + mode := entry.Type() - wg.Wait() // for (potentially recursive) calls to return + // Skip our own orphaned tempfiles from a prior crashed run. + if strings.HasPrefix(name, tempPrefix) { + return + } + // Skip names like `.git` regardless of file vs directory: a `.git` file + // is the worktree/submodule linkage and rewriting it corrupts the link. + if _, skip := skipNames[name]; skip { + return + } + // Symlinks are skipped entirely (issue #2). We never follow them and we + // don't rewrite or rename them — renaming a symlink only renames the + // link itself, which is harmless, but skipping is clearer and avoids + // subtle interactions with the rename phase below. + if mode&os.ModeSymlink != 0 { + return + } + + child := newChildFile(parent.Path, name) + + if entry.IsDir() { + fr.WalkDir(child) + } else if mode.IsRegular() { + fr.rewriteContents(child, entry) + } else { + // Devices, named pipes, sockets, etc.: leave alone. 
+ return + } + + fr.RenameFile(child) } -// HandleFile immediately recurses depth-first into directories it finds, -// otherwise calls ReplaceContents for regular files. When either operation is -// complete, the file is renamed (if necessary) since no subsequent operations -// will need to access it again. -func (fr *findReplace) HandleFile(f *File) { - // If file is a directory, recurse immediately (depth-first). - if f.Info().IsDir() { - // Ignore certain directories - if f.Base() == ".git" { +// rewriteContents replaces fr.find with fr.replace in the file's bytes, +// streaming through a bounded buffer. The original file's mode is preserved. +func (fr *findReplace) rewriteContents(f *File, entry fs.DirEntry) { + info, err := entry.Info() + if err != nil { + // Race with another process; treat as missing and continue. + if errors.Is(err, fs.ErrNotExist) { return } - fr.WalkDir(f) - } else { - // Replace the contents of regular files - fr.ReplaceContents(f) + fr.recordErr(fmt.Errorf("stat %s: %w", f.Path, err)) + return } - // Rename the file now that we're otherwise done with it - fr.RenameFile(f) + changed, err := rewriteFile(f.Path, fr.findBytes, fr.replaceBytes, info) + if err != nil { + fr.recordErr(fmt.Errorf("rewrite %s: %w", f.Path, err)) + return + } + if changed { + log.Printf("Rewriting %v", f.Path) + } } -// RenameFile renames a file if the destination file name does not already -// exist. +// RenameFile renames f to the same path with fr.find replaced by fr.replace +// in the basename, only if the destination does not already exist. Uses +// renameNoReplace to close the TOCTOU window in the existence check. 
func (fr *findReplace) RenameFile(f *File) { - newBaseName := strings.Replace(f.Base(), fr.find, fr.replace, -1) - newPath := filepath.Join(f.Dir(), newBaseName) - - if f.Base() != newBaseName { - if _, err := os.Stat(newPath); errors.Is(err, os.ErrNotExist) { - log.Printf("Renaming %v to %v", f.Path, newBaseName) - if err := os.Rename(f.Path, newPath); err != nil { - log.Fatalf("Unable to rename %v to %v: %v", f.Path, newBaseName, err) - } - } else { - log.Fatalf("Refusing to rename %v to %v: %v already exists", f.Path, newBaseName, newPath) + newBase := strings.Replace(f.Base(), fr.find, fr.replace, -1) + if newBase == f.Base() { + return + } + newPath := filepath.Join(f.Dir(), newBase) + + if err := renameNoReplace(f.Path, newPath); err != nil { + if errors.Is(err, os.ErrExist) { + fr.recordErr(fmt.Errorf("refusing to rename %s to %s: %s already exists", + f.Path, newBase, newPath)) + return } + fr.recordErr(fmt.Errorf("rename %s to %s: %w", f.Path, newBase, err)) + return } + log.Printf("Renaming %v to %v", f.Path, newBase) } -// Replaces the contents of the given file, using the find & replace values in -// context. -func (fr *findReplace) ReplaceContents(f *File) { - // Find & replace the contents of text files. Binary-looking files return - // an empty string and will be skipped here. - content := f.Read() - if strings.Contains(content, fr.find) { - newContent := strings.Replace(content, fr.find, fr.replace, -1) - f.Write(newContent) - } +func (fr *findReplace) recordErr(err error) { + fr.errors++ + log.Printf("error: %v", err) } diff --git a/find_replace_test.go b/find_replace_test.go index 5bfda8a..5f1f955 100644 --- a/find_replace_test.go +++ b/find_replace_test.go @@ -2,9 +2,8 @@ package main import ( "errors" - "log" + "fmt" "os" - "os/exec" "path/filepath" "strings" "testing" @@ -14,37 +13,39 @@ import ( * Testing utilities */ -// newTestFile creates a file in the given directory path, with the given name -// and content. 
If a directory path is not provided, a temp directory is used. -// If a baseName is not provided, a random file name is generated. Returns the -// directory where the file was created, the file's directory entry, and the -// actual name of the file. -func newTestFile(path string, baseName string, content string) *File { - f, err := os.CreateTemp(path, baseName) +// newTestFile creates a file in the given directory with the given name and +// content. If baseName is empty a unique random name is used. +func newTestFile(t testing.TB, dir, baseName, content string) *File { + t.Helper() + pattern := baseName + if pattern == "" { + pattern = "*" + } + f, err := os.CreateTemp(dir, pattern) if err != nil { - log.Fatal(err) + t.Fatal(err) } if _, err := f.Write([]byte(content)); err != nil { - defer os.Remove(f.Name()) - log.Fatal(err) + _ = f.Close() + t.Fatal(err) } if err := f.Close(); err != nil { - defer os.Remove(f.Name()) - log.Fatal(err) + t.Fatal(err) } - return NewFile(f.Name()) } -// newTestDir creates a directory in the given directory path, with the given -// base name. If a directory path is not provided, a temp directory is used. If -// a baseName is not provided, a random file name is generated. Returns the -// directory where the file was created, the file's directory entry, and the -// actual name of the file. -func newTestDir(path string, baseName string) *File { - dirPath, err := os.MkdirTemp(path, baseName) +// newTestDir creates a directory in the given directory path with the given +// base name (or a random name if empty). 
+func newTestDir(t testing.TB, dir, baseName string) *File { + t.Helper() + pattern := baseName + if pattern == "" { + pattern = "*" + } + dirPath, err := os.MkdirTemp(dir, pattern) if err != nil { - log.Fatal(err) + t.Fatal(err) } return NewFile(dirPath) } @@ -57,15 +58,15 @@ func expectedPathAfterRename(f *File, fr *findReplace) string { * Assertions */ -// assertFileExists ensures that the given File exists func assertFileExists(t *testing.T, f *File) { + t.Helper() if _, err := os.Stat(f.Path); errors.Is(err, os.ErrNotExist) { t.Errorf("test file %v does not exist", f.Path) } } -// assertFileNonexistent ensures that the File does not exist func assertFileNonexistent(t *testing.T, f *File) { + t.Helper() if _, err := os.Stat(f.Path); !errors.Is(err, os.ErrNotExist) { if err == nil { t.Errorf("test file %v exists", f.Path) @@ -76,73 +77,83 @@ func assertFileNonexistent(t *testing.T, f *File) { } func assertPathExistsAfterRename(t *testing.T, f *File, expectedPath string) *File { + t.Helper() assertFileNonexistent(t, f) newFile := NewFile(expectedPath) assertFileExists(t, newFile) return newFile } +func readFile(t *testing.T, path string) string { + t.Helper() + b, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read %v: %v", path, err) + } + return string(b) +} + +func assertNewContentsOfFile(t *testing.T, path, initial, find, replace, want string) { + t.Helper() + got := readFile(t, path) + if got != want { + t.Errorf("replace %v with %v in %v, but got %v; want %v", find, replace, initial, got, want) + } +} + /* * Tests */ // TestWalkDir is the most important test of the entire suite, because it -// exercises all the basic functionality of the app. It builds a directory tree -// of temporary files and directories, walks the entire tree, and ensures that -// all files and directories are appropriately renamed at at the end, and all -// files contain the correct contents. +// exercises all the basic functionality of the app. 
It builds a directory +// tree of temporary files and directories, walks the entire tree, and ensures +// that all files and directories are appropriately renamed and contain the +// correct contents. func TestWalkDir(t *testing.T) { find := "wh" replace := "f" - d := newTestDir("", "*") - defer os.Remove(d.Path) + d := NewFile(t.TempDir()) // d1: who/ - d1 := newTestDir(d.Path, "who") - defer os.Remove(d1.Path) + d1 := newTestDir(t, d.Path, "who") // d1d1: who/what/ - d1d1 := newTestDir(d1.Path, "what") - defer os.Remove(d1d1.Path) + d1d1 := newTestDir(t, d1.Path, "what") // d1d1f1: who/what/when (contains "where") d1d1f1Contents := "where" - d1d1f1 := newTestFile(d1d1.Path, "when", d1d1f1Contents) - defer os.Remove(d1d1f1.Path) + d1d1f1 := newTestFile(t, d1d1.Path, "when", d1d1f1Contents) // d2: what/ - d2 := newTestDir(d.Path, "what") - defer os.Remove(d2.Path) + d2 := newTestDir(t, d.Path, "what") // d2d1: what/when/ - d2d1 := newTestDir(d2.Path, "when") - defer os.Remove(d2d1.Path) + d2d1 := newTestDir(t, d2.Path, "when") // d2d1d1: what/when/where (directories with no files) - d2d1d1 := newTestDir(d2d1.Path, "where") - defer os.Remove(d2d1d1.Path) + d2d1d1 := newTestDir(t, d2d1.Path, "where") // d3: when/ - d3 := newTestDir(d.Path, "when") - defer os.Remove(d3.Path) + d3 := newTestDir(t, d.Path, "when") // d3f1: when/where (contains "why") d3f1Contents := "why" - d3f1 := newTestFile(d3.Path, "where", d3f1Contents) - defer os.Remove(d3f1.Path) + d3f1 := newTestFile(t, d3.Path, "where", d3f1Contents) // d4: where/ (empty directory in base dir) - d4 := newTestDir(d.Path, "where") - defer os.Remove(d4.Path) + d4 := newTestDir(t, d.Path, "where") // f1: why (file in base dir contains "wh") f1Contents := "wh\nwh\nwh\n" - f1 := newTestFile(d.Path, "why", f1Contents) - defer os.Remove(f1.Path) + f1 := newTestFile(t, d.Path, "why", f1Contents) fr := findReplace{find: find, replace: replace} fr.WalkDir(d) + if fr.errors != 0 { + t.Fatalf("walk reported %d errors", 
fr.errors) + } // d1: who/ > fo/ d1ExpectedPath := expectedPathAfterRename(d1, &fr) @@ -188,179 +199,153 @@ func TestWalkDir(t *testing.T) { assertNewContentsOfFile(t, f1ExpectedPath, f1Contents, find, replace, "f\nf\nf\n") } -func TestHandleFileWithDir(t *testing.T) { - initial := "alpha" - find := "ph" - replace := "f" - - f := newTestDir("", initial) - defer os.Remove(f.Path) - expectedPath := filepath.Join(f.Dir(), strings.Replace(f.Base(), find, replace, -1)) - defer os.Remove(expectedPath) - fr := findReplace{find: find, replace: replace} - - assertFileExists(t, f) - fr.HandleFile(f) - assertPathExistsAfterRename(t, f, expectedPath) -} - -func TestHandleFileWithIgnoredDir(t *testing.T) { - initial := ".git" - find := "git" - replace := "got" - - dirPath := filepath.Join(os.TempDir(), initial) - if err := os.Mkdir(dirPath, 0700); err != nil { - log.Fatal(err) +func TestIgnoredGitDir(t *testing.T) { + root := t.TempDir() + gitDir := filepath.Join(root, ".git") + if err := os.Mkdir(gitDir, 0o700); err != nil { + t.Fatal(err) + } + gitFile := filepath.Join(gitDir, "config") + if err := os.WriteFile(gitFile, []byte("git contents"), 0o600); err != nil { + t.Fatal(err) } - f := NewFile(dirPath) - defer os.Remove(f.Path) - // Just in case it's unexpectedly renamed, let's make sure we cleanup the - // anticipated name. 
- unexpectedName := strings.Replace(f.Base(), find, replace, -1) - unexpectedPath := filepath.Join(f.Dir(), unexpectedName) - defer os.Remove(unexpectedPath) - fr := findReplace{find: find, replace: replace} - - assertFileExists(t, f) - fr.HandleFile(f) - assertFileExists(t, f) -} - -func TestHandleFileWithFile(t *testing.T) { - initial := "alpha" - find := "ph" - replace := "f" - want := "alfa" - - f := newTestFile("", initial, initial) - defer os.Remove(f.Path) - expectedName := strings.Replace(f.Base(), find, replace, -1) - expectedPath := filepath.Join(f.Dir(), expectedName) - defer os.Remove(expectedPath) - fr := findReplace{find: find, replace: replace} - assertFileExists(t, f) - fr.HandleFile(f) - assertPathExistsAfterRename(t, f, expectedPath) + fr := findReplace{find: "git", replace: "got"} + fr.WalkDir(NewFile(root)) + if fr.errors != 0 { + t.Fatalf("walk reported %d errors", fr.errors) + } - got := NewFile(expectedPath).Read() - if got != want { - t.Errorf("replace %v with %v in %v, but got %v; want %v", find, replace, initial, got, want) + // The .git directory must remain unchanged. 
+ if _, err := os.Stat(gitDir); err != nil { + t.Errorf(".git was renamed/removed: %v", err) + } + if got := readFile(t, gitFile); got != "git contents" { + t.Errorf(".git/config was rewritten: %q", got) } } -func TestRenameFile(t *testing.T) { - initial := "alpha" - find := "ph" - replace := "f" +func TestRenameSingleFile(t *testing.T) { + d := t.TempDir() + f := newTestFile(t, d, "alpha", "") - f := newTestFile("", initial, "") - defer os.Remove(f.Path) - expectedName := strings.Replace(f.Base(), find, replace, -1) - expectedPath := filepath.Join(f.Dir(), expectedName) - defer os.Remove(expectedPath) - fr := findReplace{find: find, replace: replace} + fr := findReplace{find: "ph", replace: "f"} + expectedPath := filepath.Join(f.Dir(), strings.Replace(f.Base(), fr.find, fr.replace, -1)) - assertFileExists(t, f) - fr.RenameFile(f) + fr.WalkDir(NewFile(d)) assertPathExistsAfterRename(t, f, expectedPath) } -// assertNewContentsOfFile ensures that the contents of the file at the given -// path exactly match the desired string. 
-func assertNewContentsOfFile(t *testing.T, path string, initial string, find string, replace string, want string) { - got := NewFile(path).Read() - if got != want { - t.Errorf("replace %v with %v in %v, but got %v; want %v", find, replace, initial, got, want) - } -} - func TestReplaceContents(t *testing.T) { - initial := "alpha" - find := "ph" - replace := "f" - want := "alfa" + d := t.TempDir() + f := newTestFile(t, d, "*", "alpha") - f := newTestFile("", "*", initial) - defer os.Remove(f.Path) - fr := findReplace{find: find, replace: replace} - fr.ReplaceContents(f) - assertNewContentsOfFile(t, f.Path, initial, find, replace, want) + fr := findReplace{find: "ph", replace: "f"} + fr.WalkDir(NewFile(d)) + assertNewContentsOfFile(t, f.Path, "alpha", "ph", "f", "alfa") } func TestReplaceContentsEntireFile(t *testing.T) { - initial := "alpha" - find := "alpha" - replace := "beta" - want := "beta" + d := t.TempDir() + f := newTestFile(t, d, "*", "alpha") - f := newTestFile("", "*", initial) - defer os.Remove(f.Path) - fr := findReplace{find: find, replace: replace} - fr.ReplaceContents(f) - assertNewContentsOfFile(t, f.Path, initial, find, replace, want) + fr := findReplace{find: "alpha", replace: "beta"} + fr.WalkDir(NewFile(d)) + assertNewContentsOfFile(t, f.Path, "alpha", "alpha", "beta", "beta") } func TestReplaceContentsMultipleMatchesSingleLine(t *testing.T) { - initial := "alphaalpha" - find := "ph" - replace := "f" - want := "alfaalfa" + d := t.TempDir() + f := newTestFile(t, d, "*", "alphaalpha") - f := newTestFile("", "*", initial) - defer os.Remove(f.Path) - fr := findReplace{find: find, replace: replace} - fr.ReplaceContents(f) - assertNewContentsOfFile(t, f.Path, initial, find, replace, want) + fr := findReplace{find: "ph", replace: "f"} + fr.WalkDir(NewFile(d)) + assertNewContentsOfFile(t, f.Path, "alphaalpha", "ph", "f", "alfaalfa") } func TestReplaceContentsMultipleMatchesMultipleLines(t *testing.T) { - initial := "alpha\nalpha" - find := "ph" - 
replace := "f" - want := "alfa\nalfa" + d := t.TempDir() + f := newTestFile(t, d, "*", "alpha\nalpha") - f := newTestFile("", "*", initial) - defer os.Remove(f.Path) - fr := findReplace{find: find, replace: replace} - fr.ReplaceContents(f) - assertNewContentsOfFile(t, f.Path, initial, find, replace, want) + fr := findReplace{find: "ph", replace: "f"} + fr.WalkDir(NewFile(d)) + assertNewContentsOfFile(t, f.Path, "alpha\nalpha", "ph", "f", "alfa\nalfa") } func TestReplaceContentsNoMatches(t *testing.T) { - initial := "alpha" - find := "abc" - replace := "xyz" - want := "alpha" + d := t.TempDir() + f := newTestFile(t, d, "*", "alpha") - f := newTestFile("", "*", initial) - defer os.Remove(f.Path) - fr := findReplace{find: find, replace: replace} - fr.ReplaceContents(f) - assertNewContentsOfFile(t, f.Path, initial, find, replace, want) + fr := findReplace{find: "abc", replace: "xyz"} + fr.WalkDir(NewFile(d)) + assertNewContentsOfFile(t, f.Path, "alpha", "abc", "xyz", "alpha") } -func CloneRepoToTestDir(b *testing.B, repoUrl string) *File { - d := newTestDir("", "*") - defer os.Remove(d.Path) - - cmd := exec.Command("git", "clone", "--depth=1", "--single-branch", repoUrl, ".") - cmd.Dir = d.Path - out, err := cmd.CombinedOutput() - if err != nil { - b.Errorf("failed to clone repo: %s", out) +// BenchmarkSyntheticTree benchmarks find-replace against a synthetic tree of +// files with controlled size so the benchmark is reproducible and not network- +// bound. 
+func BenchmarkSyntheticTree(b *testing.B) { + const dirs = 10 + const filesPerDir = 100 + const fileBytes = 4 * 1024 + + root := b.TempDir() + body := strings.Repeat("alpha beta gamma\n", fileBytes/16) + for i := 0; i < dirs; i++ { + dir := filepath.Join(root, fmt.Sprintf("dir-alpha-%d", i)) + if err := os.Mkdir(dir, 0o755); err != nil { + b.Fatal(err) + } + for j := 0; j < filesPerDir; j++ { + path := filepath.Join(dir, fmt.Sprintf("file-alpha-%d.txt", j)) + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + b.Fatal(err) + } + } } - return d -} - -func BenchmarkNova(b *testing.B) { + b.ResetTimer() for n := 0; n < b.N; n++ { + // We can't safely rerun with rename in place, so use a fresh + // subtree per iteration. b.StopTimer() - d := CloneRepoToTestDir(b, "git@github.com:openstack/nova.git") - fr := findReplace{find: RandomString(2), replace: RandomString(2)} + work := filepath.Join(b.TempDir(), "tree") + if err := copyTree(root, work); err != nil { + b.Fatal(err) + } b.StartTimer() - fr.WalkDir(d) + + fr := findReplace{find: "alpha", replace: "BETA"} + fr.WalkDir(NewFile(work)) + if fr.errors != 0 { + b.Fatalf("benchmark reported %d errors", fr.errors) + } } } + +// copyTree recursively copies the regular files and directories under src +// into dst. Used by benchmarks to set up a fresh working tree without +// requiring os.CopyFS (Go 1.23+). 
+func copyTree(src, dst string) error { + return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + rel, err := filepath.Rel(src, path) + if err != nil { + return err + } + target := filepath.Join(dst, rel) + if info.IsDir() { + return os.MkdirAll(target, info.Mode().Perm()) + } + if !info.Mode().IsRegular() { + return nil + } + data, err := os.ReadFile(path) + if err != nil { + return err + } + return os.WriteFile(target, data, info.Mode().Perm()) + }) +} diff --git a/go.mod b/go.mod index 96a1ddb..beb4122 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,3 @@ module github.com/dolph/find-replace go 1.19 - -require golang.org/x/tools v0.7.0 diff --git a/go.sum b/go.sum index b522ba0..e69de29 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +0,0 @@ -golang.org/x/tools v0.1.9 h1:j9KsMiaP1c3B0OTQGth0/k+miLGTgLsAFUCrF2vLcF8= -golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= -golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4= -golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= diff --git a/reliability_test.go b/reliability_test.go new file mode 100644 index 0000000..ac8b776 --- /dev/null +++ b/reliability_test.go @@ -0,0 +1,202 @@ +package main + +import ( + "bytes" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" +) + +// TestUnreadableSubdirSkipped verifies that an unreadable subdirectory does +// not abort the entire walk — see issue #6. The remaining tree must still be +// processed and the walker must record the error so the caller knows. +func TestUnreadableSubdirSkipped(t *testing.T) { + if os.Getuid() == 0 { + t.Skip("running as root; permission bits don't apply") + } + if runtime.GOOS == "windows" { + t.Skip("permission semantics differ on windows") + } + + root := t.TempDir() + + // Subdir with no read permission. 
+ denied := filepath.Join(root, "denied-alpha") + if err := os.Mkdir(denied, 0o000); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = os.Chmod(denied, 0o755) }) + + // Sibling that should still be processed. + openable := filepath.Join(root, "open-alpha.txt") + if err := os.WriteFile(openable, []byte("alpha"), 0o600); err != nil { + t.Fatal(err) + } + + fr := findReplace{find: "alpha", replace: "BETA"} + fr.WalkDir(NewFile(root)) + + // The sibling file content was rewritten. + got, _ := os.ReadFile(filepath.Join(root, "open-BETA.txt")) + if string(got) != "BETA" { + t.Errorf("sibling not rewritten; got %q", got) + } + + // And the walker recorded at least one error. + if fr.errors == 0 { + t.Error("expected walker to record an error for the unreadable directory") + } +} + +// TestRejectsEmptyFind ensures the CLI fails fast when given an empty FIND +// argument — see issue #10. Runs as a subprocess against the built binary. +func TestRejectsEmptyFind(t *testing.T) { + bin := buildBinary(t) + cmd := exec.Command(bin, "", "anything") + cmd.Dir = t.TempDir() + out, err := cmd.CombinedOutput() + if err == nil { + t.Fatalf("expected non-zero exit, got success; output: %s", out) + } + if !bytes.Contains(out, []byte("FIND")) { + t.Errorf("expected error message to mention FIND, got %q", out) + } +} + +// TestRejectsIdenticalFindReplace ensures the CLI fails fast when given +// FIND==REPLACE. +func TestRejectsIdenticalFindReplace(t *testing.T) { + bin := buildBinary(t) + cmd := exec.Command(bin, "alpha", "alpha") + cmd.Dir = t.TempDir() + out, err := cmd.CombinedOutput() + if err == nil { + t.Fatalf("expected non-zero exit, got success; output: %s", out) + } +} + +// TestExitCodeNonZeroOnError ensures that recording at least one error during +// a walk causes the binary to exit non-zero — see issue #11. +func TestExitCodeNonZeroOnError(t *testing.T) { + bin := buildBinary(t) + work := t.TempDir() + + // Create a rename collision so the run records an error. 
+ if err := os.WriteFile(filepath.Join(work, "alpha"), []byte(""), 0o600); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(work, "BETA"), []byte(""), 0o600); err != nil { + t.Fatal(err) + } + + cmd := exec.Command(bin, "alpha", "BETA") + cmd.Dir = work + out, err := cmd.CombinedOutput() + if err == nil { + t.Fatalf("expected non-zero exit, got success; output: %s", out) + } +} + +// TestBigFileStreamsThroughBoundedMemory exercises the streaming rewriter on +// a file substantially larger than the buffer, ensuring correctness without +// loading the file into memory all at once. (See issue #8 — full validation +// of memory bounds requires runtime.MemStats and is brittle, so we settle +// for correctness on a multi-megabyte file plus the targeted boundary tests +// in TestStreamReplaceMatchAcrossBuffers.) +func TestBigFileStreamsThroughBoundedMemory(t *testing.T) { + d := t.TempDir() + path := filepath.Join(d, "big.txt") + + // 5 MB of text with the find string sprinkled throughout. 
+ chunk := strings.Repeat("alpha bravo charlie delta\n", 1024) // ~25 KB + body := strings.Repeat(chunk, 200) // ~5 MB + if err := os.WriteFile(path, []byte(body), 0o600); err != nil { + t.Fatal(err) + } + + changed, err := rewriteFile(path, []byte("bravo"), []byte("BB"), statOrFail(t, path)) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Fatal("expected change=true") + } + + got, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + want := strings.ReplaceAll(body, "bravo", "BB") + if !bytes.Equal(got, []byte(want)) { + t.Errorf("rewrite mismatch (got len=%d, want len=%d)", len(got), len(want)) + } +} + +func buildBinary(t *testing.T) string { + t.Helper() + dir := t.TempDir() + bin := filepath.Join(dir, "find-replace") + cmd := exec.Command("go", "build", "-o", bin, ".") + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("build failed: %v\n%s", err, out) + } + return bin +} + +// TestGitAsFileNotRewritten verifies that a `.git` file (worktree/submodule +// linkage) is skipped, not rewritten. See issue #19. +func TestGitAsFileNotRewritten(t *testing.T) { + root := t.TempDir() + gitFile := filepath.Join(root, ".git") + original := "gitdir: ../some-other-dir/.git/worktrees/me\n" + if err := os.WriteFile(gitFile, []byte(original), 0o644); err != nil { + t.Fatal(err) + } + + fr := findReplace{find: "gitdir", replace: "WAS_REWRITTEN"} + fr.WalkDir(NewFile(root)) + if fr.errors != 0 { + t.Fatalf("walk reported %d errors", fr.errors) + } + + got, err := os.ReadFile(gitFile) + if err != nil { + t.Fatalf(".git file disappeared: %v", err) + } + if string(got) != original { + t.Errorf(".git file was rewritten: %q", got) + } +} + +// TestStaleTempFilesSkipped verifies that orphan .find-replace-* files from a +// crashed prior run are not picked up as targets. See issue #21. 
+func TestStaleTempFilesSkipped(t *testing.T) { + root := t.TempDir() + + stale := filepath.Join(root, ".find-replace-orphan-alpha") + if err := os.WriteFile(stale, []byte("alpha"), 0o600); err != nil { + t.Fatal(err) + } + regular := filepath.Join(root, "alpha.txt") + if err := os.WriteFile(regular, []byte("alpha"), 0o600); err != nil { + t.Fatal(err) + } + + fr := findReplace{find: "alpha", replace: "BETA"} + fr.WalkDir(NewFile(root)) + + // The stale temp file must remain — neither rewritten nor renamed. + if got, _ := os.ReadFile(stale); string(got) != "alpha" { + t.Errorf("stale temp file was rewritten: %q", got) + } + + // The regular file was rewritten and renamed. + got, _ := os.ReadFile(filepath.Join(root, "BETA.txt")) + if string(got) != "BETA" { + t.Errorf("regular file not rewritten/renamed: %q", got) + } +} diff --git a/security_test.go b/security_test.go new file mode 100644 index 0000000..8a9d7a9 --- /dev/null +++ b/security_test.go @@ -0,0 +1,296 @@ +package main + +import ( + "bytes" + "os" + "path/filepath" + "runtime" + "strings" + "syscall" + "testing" + "time" +) + +func statSysOrSkip(t *testing.T, path string) *syscall.Stat_t { + t.Helper() + info, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + st, ok := info.Sys().(*syscall.Stat_t) + if !ok { + t.Skip("syscall.Stat_t not available on this platform") + } + return st +} + +// TestSymlinkNotFollowed verifies that find-replace does not follow symbolic +// links out of the current working tree. Following symlinks would let any +// directory entry rewrite or rename files anywhere on the filesystem the +// running user can write — see issue #2. +func TestSymlinkNotFollowed(t *testing.T) { + root := t.TempDir() + + // Build a "victim" tree outside of the area find-replace will be told to + // process. 
+ victimDir := t.TempDir() + victimFile := filepath.Join(victimDir, "secret.txt") + const victimContents = "secret data" + if err := os.WriteFile(victimFile, []byte(victimContents), 0o600); err != nil { + t.Fatal(err) + } + + // Inside the work tree, plant a symlink that points at the victim + // directory. + if err := os.Symlink(victimDir, filepath.Join(root, "escape")); err != nil { + t.Fatal(err) + } + + // Also include a regular file with content that should be rewritten so we + // can confirm the run did something inside `root`. + regularFile := filepath.Join(root, "regular.txt") + if err := os.WriteFile(regularFile, []byte("secret in root"), 0o600); err != nil { + t.Fatal(err) + } + + fr := findReplace{find: "secret", replace: "PWNED"} + fr.WalkDir(NewFile(root)) + + // The victim file must be untouched. + got, err := os.ReadFile(victimFile) + if err != nil { + t.Fatalf("victim file disappeared: %v", err) + } + if string(got) != victimContents { + t.Errorf("victim file was rewritten through symlink: got %q want %q", + string(got), victimContents) + } + + // The regular file inside root should have been rewritten. + got, err = os.ReadFile(regularFile) + if err != nil { + t.Fatalf("regular file disappeared: %v", err) + } + if !strings.Contains(string(got), "PWNED") { + t.Errorf("regular file was not rewritten: got %q", string(got)) + } +} + +// TestSymlinkTargetNotRenamed verifies that a symlink to a directory is not +// followed during the rename phase: a file inside the symlink's target whose +// name matches the find string must keep its original name. +func TestSymlinkTargetNotRenamed(t *testing.T) { + root := t.TempDir() + + victimDir := t.TempDir() + victimFile := filepath.Join(victimDir, "alpha-target") + if err := os.WriteFile(victimFile, []byte(""), 0o600); err != nil { + t.Fatal(err) + } + + // Symlink whose own name does NOT match the find string but whose target + // directory contains a file whose name DOES match the find string.
+ if err := os.Symlink(victimDir, filepath.Join(root, "via")); err != nil { + t.Fatal(err) + } + + fr := findReplace{find: "alpha", replace: "beta"} + fr.WalkDir(NewFile(root)) + + // File inside the symlinked target must not have been renamed. + if _, err := os.Stat(victimFile); err != nil { + t.Errorf("victim file %v was renamed/removed: %v", victimFile, err) + } + renamed := filepath.Join(victimDir, "beta-target") + if _, err := os.Stat(renamed); err == nil { + t.Errorf("symlink target was renamed to %v", renamed) + } +} + +// TestTempfileSymlinkAttackRefuses verifies that even if the same directory contains +// pre-planted files with names matching the temp-file pattern, the rewrite +// uses an O_EXCL-style creation that does not follow attacker-planted +// symlinks. See issue #3. +func TestTempfileSymlinkAttackRefuses(t *testing.T) { + root := t.TempDir() + target := filepath.Join(root, "target.txt") + if err := os.WriteFile(target, []byte("alpha"), 0o600); err != nil { + t.Fatal(err) + } + + // Plant a victim file outside of the working tree. + victimDir := t.TempDir() + victim := filepath.Join(victimDir, "victim.txt") + if err := os.WriteFile(victim, []byte("victim contents"), 0o600); err != nil { + t.Fatal(err) + } + + // Pre-plant 200 symlinks named like our temp-file pattern that point at + // the victim. With O_EXCL the temp-file creation must reject every + // pre-existing name and ultimately settle on a unique one (or error + // out); without O_EXCL the open would follow the symlink and clobber + // the victim. + for i := 0; i < 200; i++ { + linkName := filepath.Join(root, ".find-replace-attack-"+filepath.Base(target)+"-"+strings.Repeat("x", i+1)) + _ = os.Symlink(victim, linkName) + } + + changed, err := rewriteFile(target, []byte("alpha"), []byte("BETA"), statOrFail(t, target)) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Error("expected changed=true") + } + + // Victim must be untouched.
+ got, err := os.ReadFile(victim) + if err != nil { + t.Fatalf("victim missing: %v", err) + } + if !bytes.Equal(got, []byte("victim contents")) { + t.Errorf("victim was rewritten via tempfile: %q", got) + } + + // Target was rewritten correctly. + got, _ = os.ReadFile(target) + if string(got) != "BETA" { + t.Errorf("target = %q, want BETA", got) + } +} + +// TestRenameRefusesOverwrite verifies that RenameFile does not silently +// overwrite an existing destination, even one created concurrently. +// See issue #4. +func TestRenameRefusesOverwrite(t *testing.T) { + root := t.TempDir() + + src := filepath.Join(root, "alpha") + if err := os.WriteFile(src, []byte("source"), 0o600); err != nil { + t.Fatal(err) + } + // Pre-plant the destination so the rename must refuse. + dst := filepath.Join(root, "BETA") + if err := os.WriteFile(dst, []byte("destination"), 0o600); err != nil { + t.Fatal(err) + } + + fr := findReplace{find: "alpha", replace: "BETA"} + fr.WalkDir(NewFile(root)) + + // The destination must not have been overwritten. + got, _ := os.ReadFile(dst) + if string(got) != "destination" { + t.Errorf("destination overwritten: %q", got) + } + // And we should have recorded an error. + if fr.errors == 0 { + t.Error("expected fr.errors > 0 when refusing rename") + } +} + +// TestRewritePreservesMtime verifies the rewrite preserves the original +// modification time. See issue #23. +func TestRewritePreservesMtime(t *testing.T) { + d := t.TempDir() + path := filepath.Join(d, "f.txt") + if err := os.WriteFile(path, []byte("alpha"), 0o600); err != nil { + t.Fatal(err) + } + stat0, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + // Roll mtime back so the test catches a real preservation rather than + // a coincidence (writing now and rewriting now are likely close enough + // to pass a coarser-grained check). 
+ past := stat0.ModTime().Add(-2 * time.Hour) + if err := os.Chtimes(path, past, past); err != nil { + t.Fatal(err) + } + + changed, err := rewriteFile(path, []byte("alpha"), []byte("BETA"), statOrFail(t, path)) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Fatal("expected change=true") + } + + stat1, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if !stat1.ModTime().Equal(past) { + t.Errorf("mtime not preserved: got %v want %v", stat1.ModTime(), past) + } +} + +// TestRewritePreservesOwner verifies that on Linux the rewrite preserves the +// original Uid and Gid. Only meaningful when running as root over files owned +// by other uids; otherwise we just check that the uid/gid are unchanged. +// See issue #17. +func TestRewritePreservesOwner(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("ownership semantics differ on windows") + } + d := t.TempDir() + path := filepath.Join(d, "f.txt") + if err := os.WriteFile(path, []byte("alpha"), 0o600); err != nil { + t.Fatal(err) + } + + stat0 := statSysOrSkip(t, path) + + if os.Getuid() == 0 { + // Chown to a non-root uid we can use to make the test meaningful. + // 'nobody' (65534) is universally available. + if err := os.Chown(path, 65534, 65534); err != nil { + t.Skipf("cannot chown to nobody: %v", err) + } + stat0 = statSysOrSkip(t, path) + } + + changed, err := rewriteFile(path, []byte("alpha"), []byte("BETA"), statOrFail(t, path)) + if err != nil { + t.Fatal(err) + } + if !changed { + t.Fatal("expected change=true") + } + + stat1 := statSysOrSkip(t, path) + if stat1.Uid != stat0.Uid || stat1.Gid != stat0.Gid { + t.Errorf("ownership not preserved: got uid=%d gid=%d want uid=%d gid=%d", + stat1.Uid, stat1.Gid, stat0.Uid, stat0.Gid) + } +} + +// TestRefusesSetuidFile verifies that files with setuid/setgid bits are not +// rewritten. See issue #18. 
+func TestRefusesSetuidFile(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("setuid not applicable on windows") + } + d := t.TempDir() + path := filepath.Join(d, "setuid") + if err := os.WriteFile(path, []byte("alpha"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.Chmod(path, 0o755|os.ModeSetuid); err != nil { + t.Skipf("cannot set setuid bit: %v", err) + } + // Verify the bit actually stuck (some sandboxes silently strip setuid). + if info, err := os.Stat(path); err != nil || info.Mode()&os.ModeSetuid == 0 { + t.Skipf("setuid bit was stripped by the filesystem") + } + + _, err := rewriteFile(path, []byte("alpha"), []byte("BETA"), statOrFail(t, path)) + if err == nil { + t.Error("expected an error for setuid file") + } + got, _ := os.ReadFile(path) + if string(got) != "alpha" { + t.Errorf("setuid file was rewritten: %q", got) + } +} diff --git a/strings.go b/strings.go index b546a31..079c97e 100644 --- a/strings.go +++ b/strings.go @@ -1,18 +1,42 @@ package main -import "math/rand" +import ( + "log" + "path/filepath" +) -var characters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") +// File holds a single absolute path. It exists primarily to keep call sites +// readable (`f.Base()` reads better than `filepath.Base(path)`) and is +// otherwise a thin wrapper around a string. +type File struct { + Path string +} -// randomString generates a random base-62 string of the given length (or returns an -// empty string). -func RandomString(n int) string { - if n <= 0 { - return "" +// NewFile resolves path to an absolute path and returns a *File for it. Used +// for the initial root passed in from main(). 
+func NewFile(path string) *File { + if filepath.IsAbs(path) { + return &File{Path: filepath.Clean(path)} } - b := make([]rune, n) - for i := range b { - b[i] = characters[rand.Intn(len(characters))] + abs, err := filepath.Abs(path) + if err != nil { + log.Fatalf("Unable to resolve absolute path of %v: %v", path, err) } - return string(b) + return &File{Path: abs} +} + +// newChildFile builds a *File for the given child of an already-absolute +// parent path, skipping the redundant filepath.Abs call performed by NewFile. +func newChildFile(parentAbs, name string) *File { + return &File{Path: filepath.Join(parentAbs, name)} +} + +// Base returns the base name of f's path. +func (f *File) Base() string { + return filepath.Base(f.Path) +} + +// Dir returns the directory of f's path. +func (f *File) Dir() string { + return filepath.Dir(f.Path) } diff --git a/strings_test.go b/strings_test.go index e5d5d55..87ac47f 100644 --- a/strings_test.go +++ b/strings_test.go @@ -1,34 +1,38 @@ package main import ( + "path/filepath" "testing" ) -// assertRandomStringLength ensures that the generated string matches the -// desired length. 
-func assertRandomStringLength(t *testing.T, ask int, want int) { - got := len(RandomString(ask)) - if got != want { - t.Errorf("len(RandomString(%v)) = %v; want %v", ask, got, want) +func TestNewFileAbsolutePath(t *testing.T) { + f := NewFile("/tmp/find-replace/example") + if !filepath.IsAbs(f.Path) { + t.Errorf("expected absolute path, got %v", f.Path) } } -func TestRandomStringLengthNegativeOne(t *testing.T) { - assertRandomStringLength(t, -1, 0) -} - -func TestRandomStringLengthZero(t *testing.T) { - assertRandomStringLength(t, 0, 0) -} - -func TestRandomStringLengthOne(t *testing.T) { - assertRandomStringLength(t, 1, 1) +func TestNewFileRelativePath(t *testing.T) { + f := NewFile("example") + if !filepath.IsAbs(f.Path) { + t.Errorf("expected absolute path, got %v", f.Path) + } } -func TestRandomStringLengthTen(t *testing.T) { - assertRandomStringLength(t, 10, 10) +func TestNewChildFileSkipsAbs(t *testing.T) { + parent := "/tmp/find-replace" + child := newChildFile(parent, "kid") + if child.Path != "/tmp/find-replace/kid" { + t.Errorf("unexpected path: %v", child.Path) + } } -func TestRandomStringLengthTwenty(t *testing.T) { - assertRandomStringLength(t, 20, 20) +func TestBaseDir(t *testing.T) { + f := NewFile("/tmp/find-replace/example") + if f.Base() != "example" { + t.Errorf("Base = %v, want example", f.Base()) + } + if f.Dir() != "/tmp/find-replace" { + t.Errorf("Dir = %v, want /tmp/find-replace", f.Dir()) + } }