diff --git a/tests/api/test_client_integration.py b/tests/api/test_client_integration.py index dbfb7ea..a2105c3 100644 --- a/tests/api/test_client_integration.py +++ b/tests/api/test_client_integration.py @@ -52,6 +52,14 @@ def _log_429(fn, *args, **kwargs): raise +def _to_bytes(files: dict) -> dict: + """Convert a str-valued dict to bytes-valued for upload_file().""" + return { + k: (v.encode("utf-8") if isinstance(v, str) else v) + for k, v in files.items() + } + + # --------------------------------------------------------------------------- # Framework-specific mock file sets # --------------------------------------------------------------------------- @@ -194,7 +202,7 @@ def test_05_check_repo_bool(self): # 03. Upload + Create (nanobot — richest file set) # ----------------------------------------------------------------------- def test_06_upload_and_create(self): - file_id = _log_429(self.client.upload_file, NANOBOT_FILES) + file_id = _log_429(self.client.upload_file, _to_bytes(NANOBOT_FILES)) self.assertTrue(file_id) result = _log_429( @@ -211,7 +219,7 @@ def test_06_upload_and_create(self): def test_07_repeated_upload(self): _wait_server(3) for i in range(2): - fid = _log_429(self.client.upload_file, NANOBOT_FILES) + fid = _log_429(self.client.upload_file, _to_bytes(NANOBOT_FILES)) self.assertTrue(fid) result = _log_429( self.client.create_repo, @@ -229,7 +237,7 @@ def test_08_modify_and_reupload(self): modified["SOUL.md"] += "\n## Custom Section\nUser added this.\n" modified["new_file.md"] = "# New File\nAdded in update.\n" - fid = _log_429(self.client.upload_file, modified) + fid = _log_429(self.client.upload_file, _to_bytes(modified)) self.assertTrue(fid) result = _log_429( @@ -299,7 +307,7 @@ def test_14_repeated_download(self): # 09. E2E roundtrip # ----------------------------------------------------------------------- def test_15_e2e_roundtrip(self): - fid = _log_429(self.client.upload_file, NANOBOT_FILES) + fid = _log_429(self.client.upload_file, _to_bytes(NANOBOT_FILES)) self.assertTrue(fid) _log_429( self.client.create_repo, @@ -329,7 +337,7 @@ def test_16_multi_framework_upload(self): for fw, files in ALL_FRAMEWORK_FILES.items(): with self.subTest(framework=fw): agent = f"{AGENT_NAME}-{fw}" - fid = _log_429(self.client.upload_file, files) + fid = _log_429(self.client.upload_file, _to_bytes(files)) self.assertTrue(fid) try: result = _log_429( @@ -351,7 +359,7 @@ def test_17_cross_framework_convert(self): source_files = ALL_FRAMEWORK_FILES[source_fw] agent = f"{AGENT_NAME}-conv-{source_fw}" - fid = _log_429(self.client.upload_file, source_files) + fid = _log_429(self.client.upload_file, _to_bytes(source_files)) try: _log_429( self.client.create_repo, @@ -406,14 +414,14 @@ def test_18_empty_zip(self): fid = _log_429(self.client.upload_file, {}) # Accepted is fine; empty file_id is also acceptable except ApiError: - pass # server may reject empty zips + pass # server may reject empty uploads # ----------------------------------------------------------------------- # 13. Edge: large file # ----------------------------------------------------------------------- def test_19_large_file(self): large_content = "x" * (500 * 1024) - files = {"SOUL.md": "# Soul\nLarge file test.\n", "data/large.txt": large_content} + files = {"SOUL.md": b"# Soul\nLarge file test.\n", "data/large.txt": large_content.encode("utf-8")} fid = _log_429(self.client.upload_file, files) self.assertTrue(fid) @@ -422,9 +430,9 @@ def test_19_large_file(self): # ----------------------------------------------------------------------- def test_20_special_chars_path(self): files = { - "SOUL.md": "# Soul\nSpecial chars test.\n", - "memory/user-notes (1).md": "# Notes\nParentheses in filename.\n", - "skills/web-search-v2/SKILL.md": "# Web Search v2\nHyphen in skill name.\n", + "SOUL.md": b"# Soul\nSpecial chars test.\n", + "memory/user-notes (1).md": b"# Notes\nParentheses in filename.\n", + "skills/web-search-v2/SKILL.md": b"# Web Search v2\nHyphen in skill name.\n", } fid = _log_429(self.client.upload_file, files) self.assertTrue(fid) @@ -435,7 +443,7 @@ def test_20_special_chars_path(self): def test_21_visibility_variants(self): for vis in ["public", "private"]: with self.subTest(visibility=vis): - files = {"SOUL.md": f"# Soul\nVisibility={vis} test.\n"} + files = {"SOUL.md": f"# Soul\nVisibility={vis} test.\n".encode("utf-8")} fid = _log_429(self.client.upload_file, files) self.assertTrue(fid) result = _log_429( @@ -451,7 +459,7 @@ def test_21_visibility_variants(self): # 16. Edge: upload then immediate download # ----------------------------------------------------------------------- def test_22_immediate_download(self): - files = {"SOUL.md": "# Soul\nImmediate download test.\n", "README.md": "# README\n"} + files = {"SOUL.md": b"# Soul\nImmediate download test.\n", "README.md": b"# README\n"} fid = _log_429(self.client.upload_file, files) _log_429( self.client.create_repo, @@ -479,7 +487,7 @@ def test_23_framework_structure(self): with self.subTest(framework=fw): files = ALL_FRAMEWORK_FILES[fw] agent = f"{AGENT_NAME}-struct-{fw}" - fid = _log_429(self.client.upload_file, files) + fid = _log_429(self.client.upload_file, _to_bytes(files)) try: _log_429( self.client.create_repo, diff --git a/tests/api/test_upload_download.py b/tests/api/test_upload_download.py index a45a85d..0ace9ba 100644 --- a/tests/api/test_upload_download.py +++ b/tests/api/test_upload_download.py @@ -34,7 +34,7 @@ from types import SimpleNamespace from ultron.cli.client import ApiError, UltronClient -from ultron.cli.commands import cmd_download, cmd_upload, _repo_name +from ultron.cli.commands import cmd_download, cmd_list, cmd_upload, _repo_name from ultron.cli import config as cli_config from ultron.services.harness.allowlist import ( ALL_AGENT_NAME, @@ -161,21 +161,22 @@ def setUp(self): # ----------------------------------------------------------------------- # Helper: build args namespace for cmd_upload / cmd_download # ----------------------------------------------------------------------- - def _upload_args(self, framework, name, local_dir=None, dry_run=False, list_=False): + def _upload_args(self, framework, name, local_dir=None, dry_run=False, repo=None): return SimpleNamespace( framework=framework, name=name, + repo=repo, local_dir=local_dir, server=SERVER, token=TOKEN, message=None, - list=list_, dry_run=dry_run, ) - def _download_args(self, name, framework=None, target=None, local_dir=None, dry_run=False): + def _download_args(self, name, framework=None, target=None, local_dir=None, dry_run=False, repo=None): return SimpleNamespace( name=name, + repo=repo or _repo_name(framework or "", name or ""), framework=framework, target=target, local_dir=local_dir, @@ -190,7 +191,10 @@ def _create_local_workspace(self, files: dict) -> str: for rel, content in files.items(): fp = Path(tmpdir) / rel fp.parent.mkdir(parents=True, exist_ok=True) - fp.write_text(content, encoding="utf-8") + if isinstance(content, bytes): + fp.write_bytes(content) + else: + fp.write_text(content, encoding="utf-8") return tmpdir def _cleanup_dir(self, path: str): @@ -273,24 +277,24 @@ def test_05_upload_dry_run(self): self._cleanup_dir(local) # ----------------------------------------------------------------------- - # 06. Upload: --list + # 06. List: list sub-agents # ----------------------------------------------------------------------- def test_06_upload_list(self): - """--list should enumerate sub-agents on disk and return 0.""" + """cmd_list should enumerate sub-agents on disk and return 0.""" local = self._create_local_workspace(QODER_ALL_FILES) try: - args = self._upload_args("qoder", None, local_dir=local, list_=True) - rc = cmd_upload(args) + args = SimpleNamespace(framework="qoder", local_dir=local) + rc = cmd_list(args) self.assertEqual(rc, 0) finally: self._cleanup_dir(local) # ----------------------------------------------------------------------- - # 07. Upload: missing --name → error + # 07. Upload: missing --name with multiple agents → error # ----------------------------------------------------------------------- def test_07_upload_missing_name(self): - """Upload without --name (and not --list) should fail.""" - local = self._create_local_workspace(QODER_INDIVIDUAL_FILES) + """Upload without --name when multiple agents exist should fail.""" + local = self._create_local_workspace(QODER_ALL_FILES) try: args = self._upload_args("qoder", None, local_dir=local) rc = cmd_upload(args) diff --git a/tests/api/test_watch_sync.py b/tests/api/test_watch_sync.py index c59a04e..f3165da 100644 --- a/tests/api/test_watch_sync.py +++ b/tests/api/test_watch_sync.py @@ -137,7 +137,12 @@ def _cleanup(self, path: str): def _upload_remote(self, name: str, framework: str, files: dict): """Upload files directly to remote (simulates remote-side changes).""" - file_id = self.client.upload_file(files) + # Convert str values to bytes for the new upload_file API + byte_files = { + k: (v.encode("utf-8") if isinstance(v, str) else v) + for k, v in files.items() + } + file_id = self.client.upload_file(byte_files) self.client.create_repo(self.username, name, framework, system_prompt_files=file_id) def _start_watch(self, framework: str, agent_name: str, local_dir: str, repo_name: str, push_only: bool = True) -> multiprocessing.Process: @@ -557,10 +562,12 @@ def test_11_qoder_individual_watch_rejected(self): args = SimpleNamespace( framework="qoder", name="reviewer", + repo=None, local_dir=None, server=SERVER, token=TOKEN, interval=60, + pull=False, sessions_dir=None, ) rc = cmd_watch(args) diff --git a/tests/cli/test_download_convert.py b/tests/cli/test_download_convert.py index cb36220..eb4dec8 100644 --- a/tests/cli/test_download_convert.py +++ b/tests/cli/test_download_convert.py @@ -48,7 +48,7 @@ def tearDown(self): @mock.patch.object(commands, "UltronClient", _DownloadStub) def test_download_writes_files(self, *_): rc = _run([ - "download", "--name", "nano", "--framework", "nanobot", + "download", "--repo", "nano", "--framework", "nanobot", "--local_dir", str(self.out), ]) self.assertEqual(rc, 0) @@ -62,7 +62,7 @@ def test_download_writes_files(self, *_): def test_download_with_conversion(self, *_): # nanobot -> hermes: USER.md must land at hermes' memories/USER.md. rc = _run([ - "download", "--name", "nano", "--framework", "nanobot", + "download", "--repo", "nano", "--framework", "nanobot", "--target", "hermes", "--local_dir", str(self.out), ]) self.assertEqual(rc, 0) @@ -72,10 +72,70 @@ def test_download_with_conversion(self, *_): @mock.patch.object(commands.config, "resolve_server", return_value=None) @mock.patch.object(commands.config, "resolve_token", return_value=None) def test_download_without_login_fails(self, *_): - rc = _run(["download", "--name", "nano", "--framework", "nanobot", + rc = _run(["download", "--repo", "nano", "--framework", "nanobot", "--local_dir", str(self.out)]) self.assertEqual(rc, 1) + def test_download_repo_required(self): + """Download without --repo should fail at argparse level.""" + import sys + from io import StringIO + stderr = StringIO() + with self.assertRaises(SystemExit): + _run(["download", "--framework", "nanobot", "--local_dir", str(self.out)]) + + @mock.patch.object(commands.config, "resolve_username", return_value="u") + @mock.patch.object(commands.config, "resolve_token", return_value="tok") + @mock.patch.object(commands.config, "resolve_server", return_value="http://s") + @mock.patch.object(commands, "UltronClient", _DownloadStub) + def test_download_with_name_creates_agent(self, *_): + """Download with --name should write files for that local agent.""" + rc = _run([ + "download", "--repo", "nano", "--framework", "nanobot", + "--name", "myagent", "--local_dir", str(self.out), + ]) + self.assertEqual(rc, 0) + # Files should still be written (nanobot shared files match). + self.assertTrue((self.out / "SOUL.md").is_file()) + + @mock.patch.object(commands.config, "resolve_username", return_value="u") + @mock.patch.object(commands.config, "resolve_token", return_value="tok") + @mock.patch.object(commands.config, "resolve_server", return_value="http://s") + @mock.patch.object(commands, "UltronClient", _DownloadStub) + def test_download_filters_by_allowlist(self, *_): + """Files not matching the allowlist patterns should be skipped.""" + # Add a file that won't match any pattern. + _DownloadStub.STORE = { + "SOUL.md": "soul", + "random/junk.txt": "junk", + "memory/MEMORY.md": "mem", + } + rc = _run([ + "download", "--repo", "nano", "--framework", "nanobot", + "--local_dir", str(self.out), + ]) + self.assertEqual(rc, 0) + # random/junk.txt should NOT be written. + self.assertFalse((self.out / "random" / "junk.txt").exists()) + # Valid files should be written. + self.assertTrue((self.out / "SOUL.md").is_file()) + # Restore original store. + _DownloadStub.STORE = {"SOUL.md": "soul", "USER.md": "user", "memory/MEMORY.md": "mem"} + + @mock.patch.object(commands.config, "resolve_username", return_value="u") + @mock.patch.object(commands.config, "resolve_token", return_value="tok") + @mock.patch.object(commands.config, "resolve_server", return_value="http://s") + @mock.patch.object(commands, "UltronClient", _DownloadStub) + def test_download_repo_with_slash(self, *_): + """--repo with '/' uses the specified group instead of username.""" + rc = _run([ + "download", "--repo", "othergroup/nano", "--framework", "nanobot", + "--local_dir", str(self.out), + ]) + self.assertEqual(rc, 0) + # Should still write files (stub doesn't care about group). + self.assertTrue((self.out / "SOUL.md").is_file()) + class TestConvert(unittest.TestCase): def setUp(self): diff --git a/tests/cli/test_upload.py b/tests/cli/test_upload.py index c893245..31a7324 100644 --- a/tests/cli/test_upload.py +++ b/tests/cli/test_upload.py @@ -1,9 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """CLI argument parsing and upload flow (against a stubbed client).""" -import io import tempfile import unittest -import zipfile from pathlib import Path from unittest import mock @@ -20,20 +18,16 @@ def __init__(self, server, token=None, timeout=60): self.server = server self.token = token self.created = [] - self.uploaded_zip = None + self.uploaded_resources = None _StubClient.instances.append(self) def check_repo(self, path, name): return False def upload_file(self, resources): - """Accept either dict or bytes; return a fake file_id.""" - if isinstance(resources, dict): - from ultron.cli.sync import zip_resources - self.uploaded_zip = zip_resources(resources) - else: - self.uploaded_zip = resources - return "fake-file-id" + """Accept Dict[str, bytes]; return a fake Gid.""" + self.uploaded_resources = resources + return "fake-gid-uuid" def create_repo(self, path, name, framework, **kwargs): self.created.append((path, name, framework, kwargs.get("system_prompt_files"))) @@ -52,6 +46,8 @@ def setUp(self): (self.root / "agents").mkdir() (self.root / "agents" / "reviewer.md").write_text("reviewer") (self.root / "AGENTS.md").write_text("shared") + (self.root / "skills" / "test-skill").mkdir(parents=True) + (self.root / "skills" / "test-skill" / "SKILL.md").write_text("skill") _StubClient.instances = [] def tearDown(self): @@ -76,13 +72,6 @@ def test_no_files_fails(self): ]) self.assertEqual(rc, 1) - def test_list_agents(self): - rc = _run([ - "upload", "--framework", "qoder", "--list", - "--local_dir", str(self.root), - ]) - self.assertEqual(rc, 0) - @mock.patch.object(commands.config, "resolve_username", return_value="u") @mock.patch.object(commands.config, "resolve_token", return_value="tok") @mock.patch.object(commands.config, "resolve_server", return_value="http://s") @@ -95,16 +84,18 @@ def test_full_upload_creates_then_uploads_zip(self, *_): self.assertEqual(rc, 0) self.assertEqual(len(_StubClient.instances), 1) client = _StubClient.instances[0] - # create_repo called with (path, name, framework, system_prompt_files) + # create_repo called with (group, repo_name, framework, system_prompt_files) self.assertEqual(len(client.created), 1) self.assertEqual(client.created[0][:3], ("u", "qoder-reviewer", "qoder")) - self.assertEqual(client.created[0][3], "fake-file-id") - # Verify zip content - self.assertIsNotNone(client.uploaded_zip) - with zipfile.ZipFile(io.BytesIO(client.uploaded_zip)) as zf: - paths = {p.removeprefix("agent/") for p in zf.namelist()} - self.assertIn("agents/reviewer.md", paths) - self.assertIn("AGENTS.md", paths) + self.assertEqual(client.created[0][3], "fake-gid-uuid") + # Verify uploaded resources are bytes-valued dict + self.assertIsNotNone(client.uploaded_resources) + self.assertIsInstance(client.uploaded_resources, dict) + self.assertIn("agents/reviewer.md", client.uploaded_resources) + self.assertIn("AGENTS.md", client.uploaded_resources) + # Values should be bytes + for v in client.uploaded_resources.values(): + self.assertIsInstance(v, bytes) @mock.patch.object(commands.config, "resolve_server", return_value=None) @mock.patch.object(commands.config, "resolve_token", return_value=None) @@ -115,6 +106,112 @@ def test_upload_without_login_fails(self, *_): ]) self.assertEqual(rc, 1) + # ---------- New tests for refactored behavior ---------- + + @mock.patch.object(commands.config, "resolve_username", return_value="u") + @mock.patch.object(commands.config, "resolve_token", return_value="tok") + @mock.patch.object(commands.config, "resolve_server", return_value="http://s") + @mock.patch.object(commands, "UltronClient", _StubClient) + def test_upload_global_only_no_name(self, *_): + """When --name is not specified and multiple agents exist, should fail.""" + # Add a second agent to trigger multiple-agent error. + (self.root / "agents" / "coder.md").write_text("coder") + rc = _run([ + "upload", "--framework", "qoder", + "--local_dir", str(self.root), + ]) + # Should fail because multiple sub-agents exist. + self.assertEqual(rc, 1) + + @mock.patch.object(commands.config, "resolve_username", return_value="u") + @mock.patch.object(commands.config, "resolve_token", return_value="tok") + @mock.patch.object(commands.config, "resolve_server", return_value="http://s") + @mock.patch.object(commands, "UltronClient", _StubClient) + def test_upload_auto_select_single_agent(self, *_): + """When only one sub-agent exists, auto-select it without --name.""" + rc = _run([ + "upload", "--framework", "qoder", + "--local_dir", str(self.root), + ]) + self.assertEqual(rc, 0) + client = _StubClient.instances[0] + # Should auto-select "reviewer" and upload as qoder-reviewer. + self.assertEqual(client.created[0][1], "qoder-reviewer") + + @mock.patch.object(commands.config, "resolve_username", return_value="u") + @mock.patch.object(commands.config, "resolve_token", return_value="tok") + @mock.patch.object(commands.config, "resolve_server", return_value="http://s") + @mock.patch.object(commands, "UltronClient", _StubClient) + def test_upload_with_repo_slash(self, *_): + """--repo with '/' should use the group from repo, not username.""" + rc = _run([ + "upload", "--framework", "qoder", "--name", "reviewer", + "--repo", "mygroup/myrepo", + "--local_dir", str(self.root), + ]) + self.assertEqual(rc, 0) + client = _StubClient.instances[0] + # group should be "mygroup", repo should be "myrepo". + self.assertEqual(client.created[0][0], "mygroup") + self.assertEqual(client.created[0][1], "myrepo") + + @mock.patch.object(commands.config, "resolve_username", return_value="u") + @mock.patch.object(commands.config, "resolve_token", return_value="tok") + @mock.patch.object(commands.config, "resolve_server", return_value="http://s") + @mock.patch.object(commands, "UltronClient", _StubClient) + def test_upload_repo_defaults_to_name(self, *_): + """When --repo is omitted, remote repo name derives from --name.""" + rc = _run([ + "upload", "--framework", "qoder", "--name", "reviewer", + "--local_dir", str(self.root), + ]) + self.assertEqual(rc, 0) + client = _StubClient.instances[0] + self.assertEqual(client.created[0][0], "u") + self.assertEqual(client.created[0][1], "qoder-reviewer") + + @mock.patch.object(commands.config, "resolve_username", return_value="u") + @mock.patch.object(commands.config, "resolve_token", return_value="tok") + @mock.patch.object(commands.config, "resolve_server", return_value="http://s") + @mock.patch.object(commands, "UltronClient", _StubClient) + def test_upload_global_only_no_agents_dir(self, *_): + """When no agents/ directory exists, upload only shared (global) files.""" + import shutil + shutil.rmtree(self.root / "agents") + rc = _run([ + "upload", "--framework", "qoder", + "--local_dir", str(self.root), + ]) + self.assertEqual(rc, 0) + client = _StubClient.instances[0] + # Repo should be "qoder" (no name specified, global mode). + self.assertEqual(client.created[0][1], "qoder") + # Verify that no agents/*.md files are uploaded. + self.assertIsNotNone(client.uploaded_resources) + for p in client.uploaded_resources.keys(): + self.assertFalse(p.startswith("agents/")) + + +class TestListCli(unittest.TestCase): + def setUp(self): + self.tmp = tempfile.TemporaryDirectory() + self.root = Path(self.tmp.name) + (self.root / "agents").mkdir() + (self.root / "agents" / "reviewer.md").write_text("reviewer") + (self.root / "agents" / "coder.md").write_text("coder") + (self.root / "AGENTS.md").write_text("shared") + + def tearDown(self): + self.tmp.cleanup() + + def test_list_shows_agents(self): + rc = _run(["list", "--framework", "qoder", "--local_dir", str(self.root)]) + self.assertEqual(rc, 0) + + def test_list_unknown_framework_fails(self): + rc = _run(["list", "--framework", "nope", "--local_dir", str(self.root)]) + self.assertEqual(rc, 1) + if __name__ == "__main__": unittest.main() diff --git a/tests/services/test_allowlist_agents.py b/tests/services/test_allowlist_agents.py index e5b5bfb..bc576e9 100644 --- a/tests/services/test_allowlist_agents.py +++ b/tests/services/test_allowlist_agents.py @@ -52,7 +52,7 @@ def test_qoder_list_agents(self): (self.root / "agents" / "a.md").write_text("a") (self.root / "agents" / "b.md").write_text("b") spec = QoderWorkspaceAllowlist(local_dir=self.root) - self.assertEqual(spec.list_agents(), ["a", "b"]) + self.assertEqual(spec.list_agents(), ["default", "a", "b"]) def test_qwenpaw_default_root_uses_agent_name(self): # default_workspace_root must embed the agent name under workspaces/. diff --git a/ultron/cli/__init__.py b/ultron/cli/__init__.py index 48365f6..89a75e2 100644 --- a/ultron/cli/__init__.py +++ b/ultron/cli/__init__.py @@ -10,7 +10,7 @@ import argparse import sys -from .commands import cmd_convert, cmd_download, cmd_login, cmd_recover, cmd_stop, cmd_upload, cmd_watch +from .commands import cmd_convert, cmd_download, cmd_list, cmd_login, cmd_recover, cmd_stop, cmd_upload, cmd_watch def build_parser() -> argparse.ArgumentParser: @@ -37,7 +37,11 @@ def build_parser() -> argparse.ArgumentParser: ) p_up.add_argument( "--name", "-n", - help="Internal sub-agent name; also the repository name (agent_id).", + help="Local sub-agent name (auto-selects if only one exists).", + ) + p_up.add_argument( + "--repo", "-r", + help="Remote repository name. Supports 'group/name' format. Defaults to local name.", ) p_up.add_argument( "--local_dir", "-d", @@ -46,10 +50,6 @@ def build_parser() -> argparse.ArgumentParser: p_up.add_argument("--message", "-m", help="Commit message.") p_up.add_argument("--server", help="Server URL override.") p_up.add_argument("--token", help="API token override.") - p_up.add_argument( - "--list", action="store_true", - help="List sub-agents discovered on disk for the framework and exit.", - ) p_up.add_argument( "--dry-run", action="store_true", help="Show the files that would be uploaded without uploading.", @@ -62,12 +62,16 @@ def build_parser() -> argparse.ArgumentParser: help="Download a sub-agent's files from the agent repository to disk.", ) p_dl.add_argument( - "--name", "-n", required=True, - help="Repository / sub-agent name to download.", + "--repo", "-r", required=True, + help="Remote repository name (required). Supports 'group/name' format.", ) p_dl.add_argument( "--framework", "-f", required=True, - help="Source framework / bot type (used to derive the repo name).", + help="Source framework / bot type.", + ) + p_dl.add_argument( + "--name", "-n", + help="Local sub-agent name to write as (default: 'default').", ) p_dl.add_argument( "--target", "-t", @@ -116,6 +120,18 @@ def build_parser() -> argparse.ArgumentParser: ) p_cv.set_defaults(func=cmd_convert) + # ---- list ---- + p_list = sub.add_parser( + "list", + help="List discoverable sub-agents for a framework.", + ) + p_list.add_argument( + "--framework", "-f", required=True, + help="Agent framework / bot type.", + ) + p_list.add_argument("--local_dir", "-d", help="Override workspace root.") + p_list.set_defaults(func=cmd_list) + # ---- watch ---- p_watch = sub.add_parser( "watch", @@ -127,7 +143,11 @@ def build_parser() -> argparse.ArgumentParser: ) p_watch.add_argument( "--name", "-n", - help="Sub-agent name (default: 'all' = full-scope sync).", + help="Local sub-agent name (default: global/shared files only).", + ) + p_watch.add_argument( + "--repo", "-r", + help="Remote repository name. Supports 'group/name' format. Defaults to local name.", ) p_watch.add_argument("--local_dir", "-d", help="Override workspace root.") p_watch.add_argument("--server", help="Server URL override.") @@ -185,7 +205,7 @@ def _run_watch_daemon(param_path: str) -> int: from .cache import pid_file, log_file from .client import UltronClient - from .commands import _build_allowlist, _repo_name, ALL_AGENT_NAME + from .commands import _build_allowlist, ALL_AGENT_NAME from .config import resolve_server, resolve_token, resolve_username from .watcher import watch_loop @@ -196,19 +216,24 @@ def _run_watch_daemon(param_path: str) -> int: ppath.unlink(missing_ok=True) username = payload.get("username") or resolve_username() - name = payload.get("name") or ALL_AGENT_NAME + repo = payload.get("repo") or payload.get("name", "") # compat: fall back to legacy "name" key framework = payload.get("framework", "") interval = payload.get("interval", 120) push_only = payload.get("push_only", True) + local_name = payload.get("local_name") or ALL_AGENT_NAME + + if not repo: + repo = "default" - server = resolve_server(None) - token = resolve_token(None) + # Prefer serialized server/token (supports modelscope integration); + # fall back to ultron's own config for standalone usage. + server = payload.get("server") or resolve_server(None) + token = payload.get("token") or resolve_token(None) if not server or not token or not username: return 1 - spec = _build_allowlist(framework, name, None) + spec = _build_allowlist(framework, local_name, None) client = UltronClient(server, token) - repo = _repo_name(framework, name) # Redirect stdout/stderr to log file. import os diff --git a/ultron/cli/cache.py b/ultron/cli/cache.py index 8ad2d0a..6beef6c 100644 --- a/ultron/cli/cache.py +++ b/ultron/cli/cache.py @@ -41,6 +41,15 @@ def pid_file() -> Path: return cache_dir() / "watch.pid" +def stop_file() -> Path: + """Stop signal file: presence tells the watch loop to exit gracefully. + + Cross-platform mechanism — works on both Unix and Windows where signal + delivery is unreliable. + """ + return cache_dir() / "watch.stop" + + # ---- Sync state persistence ---- def sync_state_file(name: str) -> Path: diff --git a/ultron/cli/client.py b/ultron/cli/client.py index ade9d51..c03a196 100644 --- a/ultron/cli/client.py +++ b/ultron/cli/client.py @@ -1,25 +1,27 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """HTTP client for ultron's agent-repository API. -All endpoints go through ``/openapi/v1/`` via ``modelscope_hub.OpenAPIClient``: +Endpoints: * ``GET /openapi/v1/users/me`` → login * ``GET /openapi/v1/agents/{path}/{name}`` → repo metadata * ``POST /openapi/v1/agents`` → create/update agent * ``GET /openapi/v1/agents/{path}/{name}/repo/files`` → list files -* ``GET /openapi/v1/agents/{path}/{name}/repo`` → file download -* ``POST /openapi/v1/files/upload`` → upload zip +* ``GET /agents/{path}/{name}/resolve/{rev}/{file}`` → file download +* ``POST /api/v1/agents/repo/files/upload`` → two-step OSS upload (step1) +* ``POST /openapi/v1/agents/{path}/{name}/commit/{rev}`` → commit files """ -import io +import logging from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional +from urllib.parse import unquote import requests from modelscope_hub._openapi import OpenAPIClient from modelscope_hub.config import HubConfig from modelscope_hub.errors import HubError, NotExistError -from .sync import zip_resources +logger = logging.getLogger("ultron.cli") @dataclass @@ -80,6 +82,43 @@ def check_repo(self, path: str, name: str) -> bool: """True if the repo exists, False on 404.""" return self.repo_info(path, name) is not None + def list_agents(self, owner: Optional[str] = None, page_number: int = 1, page_size: int = 10) -> dict: + """List agent repositories (GET /agents). + + Returns a dict with 'items' (list of agent metadata dicts) and + 'total_count' (int). + """ + params = {"page_number": page_number, "page_size": page_size} + if owner: + params["owner"] = owner + try: + data = self._openapi._request( + "GET", "/agents", params=params, require_token=False) + except HubError as exc: + raise _wrap(exc) from exc + # Normalize response: server may return {Data: [...], Total: N} + # or {items: [...], total_count: N}. + if isinstance(data, dict): + items = None + for key in ("Data", "items", "data"): + if key in data: + items = data[key] + break + if items is None: + items = [] + total = data.get("Total") + if total is None: + total = data.get("total_count") + if total is None: + total = data.get("TotalCount") + if total is None: + total = len(items) + return {"items": items, "total_count": total} + # If response is a list directly + if isinstance(data, list): + return {"items": data, "total_count": len(data)} + return {"items": [], "total_count": 0} + def create_repo( self, path: str, name: str, framework: str, visibility: str = "public", @@ -127,43 +166,60 @@ def list_repo_files_detail(self, path: str, name: str, revision: str = 'master') return results def _fetch_tree_entries(self, path: str, name: str, revision: str) -> List[dict]: - """Fetch and normalize the repo file tree from the API.""" - try: - data = self._openapi._request( - "GET", f"/agents/{path}/{name}/repo/files", - params={ - "recursive": "true", - "page_size": "100", - "page_number": "1", - "revision": revision, - }, - ) - except HubError as exc: - raise _wrap(exc) from exc + """Fetch and normalize the repo file tree from the API (with pagination).""" + page = 1 + page_size = 100 + max_pages = 50 # safety cap: 5000 files max + all_entries: List[dict] = [] - raw = [] - if isinstance(data, dict): - raw = data.get("trees") or data.get("Trees") or [] - elif isinstance(data, list): - raw = data + while True: + try: + data = self._openapi._request( + "GET", f"/agents/{path}/{name}/repo/files", + params={ + "recursive": "true", + "page_size": str(page_size), + "page": str(page), + "revision": revision, + }, + ) + except HubError as exc: + raise _wrap(exc) from exc - entries: List[dict] = [] - for item in raw: - if not isinstance(item, dict): - continue - entries.append({ - "path": item.get("path") or item.get("Path") or "", - "type": item.get("type") or item.get("Type") or "", - "sha256": item.get("sha256") or item.get("Sha256") or "", - "committed_date": item.get("committed_date") or item.get("Committed_date") or 0, - }) - return entries + raw = [] + if isinstance(data, dict): + raw = data.get("trees") or data.get("Trees") or [] + elif isinstance(data, list): + raw = data + + for item in raw: + if not isinstance(item, dict): + continue + all_entries.append({ + "path": item.get("path") or item.get("Path") or "", + "type": item.get("type") or item.get("Type") or "", + "sha256": item.get("sha256") or item.get("Sha256") or "", + "committed_date": item.get("committed_date") or item.get("Committed_date") or 0, + }) + + if len(raw) < page_size: + break + page += 1 + if page > max_pages: + logger.warning( + "Pagination limit reached (%d pages) for %s/%s; results may be incomplete.", + max_pages, path, name, + ) + break + + return all_entries def download_repo_file(self, path: str, name: str, file_path: str, - revision: str = "master") -> str: + revision: str = "master", *, binary: bool = False): """Download one repo file (GET /agents/{path}/{name}/resolve/{revision}/{file_path}). This endpoint does NOT use the /openapi/v1/ prefix. + Returns bytes when *binary=True*, otherwise str. """ url = f"{self.server}/agents/{path}/{name}/resolve/{revision}/{file_path}" headers = {"Authorization": f"Bearer {self.token}"} if self.token else {} @@ -176,29 +232,120 @@ def download_repo_file(self, path: str, name: str, file_path: str, raise ApiError(status, detail) from exc except requests.RequestException as exc: raise ApiError(0, str(exc)) from exc - return resp.text + return resp.content if binary else resp.text + + # ---- upload (two-step OSS) ---- + + def _request_upload_urls(self, filenames: List[str]) -> dict: + """Step 1: POST /api/v1/agents/repo/files/upload → {Gid, Urls}. + + Uses /api/v1/ prefix (not /openapi/v1/). Response envelope uses + capitalised keys: {"Code": 200, "Data": {...}, "Success": true}. + """ + url = f"{self.server}/api/v1/agents/repo/files/upload" + headers = {"Authorization": f"Bearer {self.token}", + "Content-Type": "application/json"} + try: + resp = requests.post(url, json={"FileNames": filenames}, + headers=headers, timeout=self.timeout) + resp.raise_for_status() + except requests.HTTPError as exc: + status = exc.response.status_code if exc.response is not None else 0 + detail = exc.response.text if exc.response is not None else str(exc) + raise ApiError(status, detail) from exc + except requests.RequestException as exc: + raise ApiError(0, str(exc)) from exc + body = resp.json() + if not body.get("Success"): + raise ApiError(body.get("Code", 0), body.get("Message", "upload credential failed")) + return body["Data"] + + @staticmethod + def _normalize_oss_url(url: str) -> str: + """Decode %2F in the URL path so OSS signature verification passes. + + The server may return signed URLs with path separators encoded as %2F. + OSS computes the signature on the *decoded* resource path, so we must + send the request with real '/' in the path. We only decode the path + portion (before '?') to avoid corrupting query-string parameters. + """ + parts = url.split("?", 1) + path_part = parts[0] + if "%2F" not in path_part and "%2f" not in path_part: + return url + decoded_path = unquote(path_part) + if len(parts) == 2: + return decoded_path + "?" + parts[1] + return decoded_path + + def _upload_to_oss(self, signed_url: str, data: bytes) -> None: + """Step 2: PUT raw bytes to signed OSS URL. + + The server signs the URL with these headers in the StringToSign: + - Content-Type: application/octet-stream + - x-oss-meta-author: aliy (included in CanonicalizedOSSHeaders) + Both MUST be present for the signature to match. + + COUPLING: These headers are dictated by the server-side signing config. + If the server changes its signing parameters, these must be updated + in lockstep. + """ + url = self._normalize_oss_url(signed_url) + try: + resp = requests.put(url, data=data, + headers={ + "Content-Type": "application/octet-stream", + "x-oss-meta-author": "aliy", + }, + timeout=max(self.timeout, 300)) + resp.raise_for_status() + except requests.HTTPError as exc: + status = exc.response.status_code if exc.response is not None else 0 + detail = exc.response.text if exc.response is not None else str(exc) + raise ApiError(status, detail) from exc + except requests.RequestException as exc: + raise ApiError(0, str(exc)) from exc + + def upload_file(self, resources: Dict[str, bytes]) -> str: + """Two-step upload: get signed URLs → PUT to OSS → return Gid. + + The returned Gid (UUID) is used as ``system_prompt_files`` in + :meth:`create_repo`. + + Returns empty string if *resources* is empty (nothing to upload). + """ + if not resources: + logger.warning("upload_file called with empty resources; skipping.") + return "" + filenames = list(resources.keys()) + data = self._request_upload_urls(filenames) + gid = data["Gid"] + url_map = {item["Filename"]: item["Url"] for item in data["Urls"]} + for fname, content in resources.items(): + signed_url = url_map.get(fname) + if not signed_url: + raise ApiError( + 0, + f"Server did not return a signed URL for '{fname}'. " + f"Available: {list(url_map.keys())}", + ) + self._upload_to_oss(signed_url, content) + return gid + + # ---- commit (incremental) ---- - # ---- upload ---- + def commit_files(self, path: str, name: str, actions: List[dict], + revision: str = "master", commit_message: str = "sync") -> dict: + """Commit file changes via POST /openapi/v1/agents/{path}/{name}/commit/{revision}. - def upload_file(self, resources: Union[Dict[str, str], bytes]) -> str: - """Upload agent files (POST /files/upload), return the file ID. + *actions* example:: - Args: - resources: Either a dict {rel_path: content} that will be zipped, - or raw zip bytes. + [{"action": "create", "file_path": "a.md", + "content": "hello", "encoding": "text"}] """ - if isinstance(resources, dict): - zip_bytes = zip_resources(resources) - else: - zip_bytes = resources + body = {"commit_message": commit_message, "actions": actions} try: - files = [("file", ("agent.zip", io.BytesIO(zip_bytes), "application/zip"))] - result = self._openapi._request("POST", "/files/upload", files=files) - return ( - result.get("id") - or result.get("Id") - or result.get("file_id") - or "" - ) + return self._openapi._request( + "POST", f"/agents/{path}/{name}/commit/{revision}", json_body=body) except HubError as exc: raise _wrap(exc) from exc diff --git a/ultron/cli/commands.py b/ultron/cli/commands.py index afa6058..c702d55 100644 --- a/ultron/cli/commands.py +++ b/ultron/cli/commands.py @@ -5,9 +5,14 @@ import sys import zipfile from pathlib import Path -from typing import Dict - -from ultron.services.harness.allowlist import ALLOWLIST_REGISTRY, ALL_AGENT_NAME +from typing import Dict, Optional + +from ultron.services.harness.allowlist import ( + ALLOWLIST_REGISTRY, + ALL_AGENT_NAME, + DEFAULT_AGENT_NAME, + GLOBAL_AGENT_NAME, +) from ultron.services.harness.defaults import get_defaults from ultron.services.harness.merge import merge_resources @@ -55,6 +60,55 @@ def _repo_name(framework: str, name: str) -> str: return "default" +def _resolve_remote(repo: Optional[str] = None, name: Optional[str] = None, framework: str = "", username: str = ""): + """Resolve remote target as (group, repo_name). + + - repo contains '/' → split into (group, repo_name), ignore username + - repo without '/' → (username, repo) + - repo is None/empty → derive from name+framework using _repo_name logic + """ + if repo: + if "/" in repo: + parts = repo.split("/", 1) + return parts[0], parts[1] + return username, repo + # No explicit repo → derive from framework + name + derived = _repo_name(framework, name or "") + return username, derived + + +def _resolve_local_name(name: Optional[str], framework: str, local_dir=None): + """Resolve local agent name when --name is omitted. + + Returns (resolved_name, error_message). + - If name is given → use it directly. + - If omitted → check list_agents(): + - 0 or only 'default' → use GLOBAL_AGENT_NAME (shared files only) + - exactly 1 non-default agent → auto-select it + - multiple → return error + """ + if name: + return name, None + + # Build a temporary spec to discover agents. + spec_cls = ALLOWLIST_REGISTRY[framework] + local = Path(local_dir).expanduser() if local_dir else None + tmp_spec = spec_cls(agent_name=DEFAULT_AGENT_NAME, local_dir=local) + agents = tmp_spec.list_agents() + + # Filter out "default" to find real sub-agents. + real_agents = [a for a in agents if a != DEFAULT_AGENT_NAME] + + if len(real_agents) == 0: + return GLOBAL_AGENT_NAME, None + if len(real_agents) == 1: + return real_agents[0], None + return None, ( + f"multiple sub-agents found: {', '.join(agents)}. " + f"Please specify --name to select one." + ) + + def _frameworks() -> str: return ", ".join(sorted(ALLOWLIST_REGISTRY)) @@ -107,36 +161,51 @@ def _convert(resources: dict, source_fw: str, target_fw: str) -> dict: return result.merged_files -def cmd_upload(args) -> int: +def cmd_list(args) -> int: + """List discoverable sub-agents for a framework.""" framework = args.framework if framework not in ALLOWLIST_REGISTRY: return _fail(f"unknown framework '{framework}'. Available: {_frameworks()}") - # --list: enumerate discoverable sub-agents and exit (no name required). - if args.list: - spec = _build_allowlist(framework, args.name or "default", args.local_dir) - agents = spec.list_agents() - print(f"Sub-agents for {framework}:") - for a in agents: - print(f" {a}") - return 0 + spec = _build_allowlist(framework, DEFAULT_AGENT_NAME, getattr(args, 'local_dir', None)) + agents = spec.list_agents() + print(f"Sub-agents for {framework}:") + for a in agents: + # Show file count for each agent. + if a == DEFAULT_AGENT_NAME: + tmp = _build_allowlist(framework, GLOBAL_AGENT_NAME, getattr(args, 'local_dir', None)) + else: + tmp = _build_allowlist(framework, a, getattr(args, 'local_dir', None)) + count = len(tmp.collect_bytes()) + label = " (global/shared files only)" if a == DEFAULT_AGENT_NAME else "" + print(f" {a} — {count} file(s){label}") + return 0 - if not args.name: - return _fail("--name is required (the internal sub-agent name)") - spec = _build_allowlist(framework, args.name, args.local_dir) +def cmd_upload(args) -> int: + framework = args.framework + if framework not in ALLOWLIST_REGISTRY: + return _fail(f"unknown framework '{framework}'. Available: {_frameworks()}") + + # Resolve local agent name (auto-select if only one). + local_name, err = _resolve_local_name(args.name, framework, args.local_dir) + if err: + return _fail(err) + + spec = _build_allowlist(framework, local_name, args.local_dir) root = spec.workspace_root - resources: Dict[str, str] = spec.collect() + resources: Dict[str, bytes] = spec.collect_bytes() if not resources: + display_name = local_name if local_name != GLOBAL_AGENT_NAME else "global" return _fail( - f"no files found for {framework}/{args.name} under {root}. " + f"no files found for {framework}/{display_name} under {root}. " f"Check the path or pass --local_dir." ) - total_bytes = sum(len(c.encode("utf-8")) for c in resources.values()) + total_bytes = sum(len(v) for v in resources.values()) print(f"Found {len(resources)} file(s) ({total_bytes} bytes) under {root}:") for rel in sorted(resources): - print(f" {rel} ({len(resources[rel].encode('utf-8'))} B)") + print(f" {rel} ({len(resources[rel])} B)") if args.dry_run: print("\n[dry-run] nothing uploaded.") @@ -152,13 +221,22 @@ def cmd_upload(args) -> int: client = UltronClient(server, token) - repo = _repo_name(framework, args.name) + # Resolve remote target. + # Use the resolved local_name for remote derivation (handles auto-select). + effective_name = local_name if local_name != GLOBAL_AGENT_NAME else None + group, repo = _resolve_remote( + repo=getattr(args, 'repo', None), + name=effective_name, + framework=framework, + username=username, + ) + # Step 1: upload files -> get file_id try: file_id = client.upload_file(resources) # Step 2: create/update agent with file_id result = client.create_repo( - username, repo, framework, + group, repo, framework, system_prompt_files=file_id, ) except ApiError as e: @@ -166,16 +244,16 @@ def cmd_upload(args) -> int: print( f"\nUploaded {len(resources)} file(s) to " - f"{username}/{repo}." + f"{group}/{repo}." ) return 0 def cmd_download(args) -> int: - if not args.name: - return _fail("--name is required (the repository / sub-agent name)") + if not getattr(args, 'repo', None): + return _fail("--repo is required for download (the remote repository name)") if not args.framework: - return _fail("--framework is required for download (to derive repo name)") + return _fail("--framework is required for download") framework = args.framework if framework not in ALLOWLIST_REGISTRY: @@ -189,18 +267,27 @@ def cmd_download(args) -> int: if not username: return _fail("missing username; run 'ultron login' again.") - repo = _repo_name(framework, args.name) + # Resolve remote target. + group, repo = _resolve_remote( + repo=args.repo, + name=args.name, + framework=framework, + username=username, + ) + client = UltronClient(server, token) try: - info = client.repo_info(username, repo) + info = client.repo_info(group, repo) if info is None: - return _fail(f"repository {username}/{repo} not found.") - paths = client.list_repo_files(username, repo) + return _fail(f"repository {group}/{repo} not found.") + paths = client.list_repo_files(group, repo) if not paths: - return _fail(f"repository {username}/{repo} has no files.") - # List then fetch each file via its download link, one at a time. + return _fail(f"repository {group}/{repo} has no files.") + # NOTE: downloads as text (str). Binary files (images, etc.) may lose + # fidelity when decoded as text. A future binary-aware path is needed + # for full parity with upload's collect_bytes. resources = { - p: client.download_repo_file(username, repo, p) for p in paths + p: client.download_repo_file(group, repo, p) for p in paths } except ApiError as e: return _fail(_api_error_message(e, "download")) @@ -213,17 +300,32 @@ def cmd_download(args) -> int: resources = _convert(resources, framework, target_fw) print(f"Converted {framework} -> {target_fw} ({len(resources)} file(s)).") - spec = _build_allowlist(target_fw, args.name, args.local_dir) + # Resolve local agent name for writing. + local_name = args.name or DEFAULT_AGENT_NAME + spec = _build_allowlist(target_fw, local_name, args.local_dir) root = spec.workspace_root - print(f"{len(resources)} file(s) for {username}/{repo} (framework={target_fw}):") - for rel in sorted(resources): + + # Filter downloaded resources by allowlist patterns. + patterns = spec.resolved_patterns() + filtered = {k: v for k, v in resources.items() if spec.matches(k, patterns)} + skipped = set(resources.keys()) - set(filtered.keys()) + if skipped: + print(f"Skipped {len(skipped)} file(s) not matching allowlist:") + for s in sorted(skipped): + print(f" [skip] {s}") + + if not filtered: + return _fail("no downloaded files match the local allowlist patterns.") + + print(f"{len(filtered)} file(s) for {group}/{repo} (framework={target_fw}):") + for rel in sorted(filtered): print(f" {rel} -> {root / rel}") if args.dry_run: print("\n[dry-run] nothing written.") return 0 - written = spec.apply(resources) + written = spec.apply(filtered) print(f"\nWrote {len(written)} file(s) under {root}.") return 0 @@ -279,8 +381,13 @@ def cmd_watch(args) -> int: if framework not in ALLOWLIST_REGISTRY: return _fail(f"unknown framework '{framework}'. Available: {_frameworks()}") - # Default --name to "all" (full-scope sync). - name = args.name or ALL_AGENT_NAME + # Resolve local agent name: if --name not given, default to ALL mode. + if args.name: + local_name, err = _resolve_local_name(args.name, framework, args.local_dir) + if err: + return _fail(err) + else: + local_name = ALL_AGENT_NAME server = config.resolve_server(args.server) token = config.resolve_token(args.token) @@ -296,22 +403,30 @@ def cmd_watch(args) -> int: from .watcher import stop_daemon stop_daemon() - spec = _build_allowlist(framework, name, args.local_dir) + spec = _build_allowlist(framework, local_name, args.local_dir) client = UltronClient(server, token) - # Guard: file-per-agent frameworks must use --name all for watch. - if not spec.supports_individual_watch and name != ALL_AGENT_NAME: + # Guard: file-per-agent frameworks with a specific agent name. + if (not spec.supports_individual_watch + and local_name not in (GLOBAL_AGENT_NAME, ALL_AGENT_NAME, DEFAULT_AGENT_NAME)): return _fail( f"'{framework}' has shared files across sub-agents; " - f"watch only supports '--name all' to avoid sync conflicts. " - f"Use 'ultron upload/download -n {name}' for individual sub-agent operations." + f"watch only supports global/default mode to avoid sync conflicts. " + f"Use 'ultron upload/download -n {local_name}' for individual sub-agent operations." ) - repo = _repo_name(framework, name) + # Resolve remote target. + effective_name = args.name if args.name else None + group, repo = _resolve_remote( + repo=getattr(args, 'repo', None), + name=effective_name, + framework=framework, + username=username, + ) # Guard: check remote repo framework matches local. try: - info = client.repo_info(username, repo) + info = client.repo_info(group, repo) if info: remote_fw = info.get("Framework") or info.get("framework") or "" if remote_fw and remote_fw != framework: @@ -322,17 +437,17 @@ def cmd_watch(args) -> int: except ApiError as e: if e.status in (403, 401): return _fail(_api_error_message(e, "watch")) - pass # repo not found or unreachable — proceed, first push will create it + # repo not found or unreachable — proceed, first push will create it interval = 120 push_only = not getattr(args, "pull", False) - print(f"Starting sync for {username}/{repo} (interval={interval}s)...") + print(f"Starting sync for {group}/{repo} (interval={interval}s)...") print(f" Framework: {framework}") print(f" Root: {spec.workspace_root}") if push_only: - print(f" Mode: push-only (local → remote, will NOT pull remote changes)") + print(f" Mode: push-only (local \u2192 remote, will NOT pull remote changes)") else: - print(f" Mode: bidirectional (local ↔ remote, WILL pull remote changes to local)") + print(f" Mode: bidirectional (local \u2194 remote, WILL pull remote changes to local)") print(f" Logs: {pid_file().parent / 'logs' / 'watch.log'}") print(f" Stop: ultron stop") diff --git a/ultron/cli/sync.py b/ultron/cli/sync.py index 54ba7cb..36f8950 100644 --- a/ultron/cli/sync.py +++ b/ultron/cli/sync.py @@ -1,12 +1,13 @@ # Copyright (c) ModelScope Contributors. All rights reserved. """Core sync logic: backup, zip, bidirectional sync helpers.""" +import base64 import hashlib import io import logging import zipfile from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Dict, List, Union from .cache import cache_dir @@ -17,7 +18,7 @@ -def zip_resources(resources: Dict[str, str], wrapper: str = "agent") -> bytes: +def zip_resources(resources: Dict[str, Union[str, bytes]], wrapper: str = "agent") -> bytes: """Pack resources into a deterministic in-memory zip. The server always strips the first directory level from zip entries, so we @@ -25,10 +26,8 @@ def zip_resources(resources: Dict[str, str], wrapper: str = "agent") -> bytes: after stripping, the remaining path matches the original ``rel_path``. Args: - resources: A dict {rel_path: content_or_filepath}. If a value is a - short string that points to an existing file on disk, its - content is read; otherwise the value is treated as literal - text content. + resources: A dict {rel_path: content}. Values are written directly + via ``ZipFile.writestr`` (accepts both str and bytes). wrapper: Name of the top-level wrapper directory (default: "agent"). """ buf = io.BytesIO() @@ -43,7 +42,7 @@ def backup_local(spec, name: str) -> Path: Returns the path to the created zip file. """ - resources: Dict[str, str] = spec.collect() + resources: Dict[str, bytes] = spec.collect_bytes() timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") zip_path = cache_dir() / f"{name}_{timestamp}.zip" zip_path.write_bytes(zip_resources(resources)) @@ -53,22 +52,24 @@ def backup_local(spec, name: str) -> Path: # ---- Bidirectional sync helpers ---- -def sha256_content(content: str) -> str: - """Compute sha256 of text content (utf-8 encoded, no BOM).""" - return hashlib.sha256(content.encode("utf-8")).hexdigest() +def sha256_content(content: Union[str, bytes]) -> str: + """Compute sha256 of content (accepts str or bytes).""" + if isinstance(content, str): + content = content.encode("utf-8") + return hashlib.sha256(content).hexdigest() def detect_local_changes( - local_resources: Dict[str, str], + local_resources: Dict[str, bytes], baseline_sha256: Dict[str, str], -) -> Dict[str, str]: +) -> Dict[str, Union[bytes, None]]: """Compare local files against the sync baseline sha256 map. Returns a dict of files that differ: - - key present with non-empty value: content changed or file is new locally - - key present with empty string value: file was deleted locally (in baseline but not local) + - key present with bytes value: content changed or file is new locally + - key present with None value: file was deleted locally (in baseline but not local) """ - changed: Dict[str, str] = {} + changed: Dict[str, Union[bytes, None]] = {} # Modified or new files. for rel, content in local_resources.items(): local_sha = sha256_content(content) @@ -77,7 +78,7 @@ def detect_local_changes( # Deleted files (in baseline but not in local). for rel in baseline_sha256: if rel not in local_resources: - changed[rel] = "" + changed[rel] = None return changed @@ -86,16 +87,58 @@ def push_resources( username: str, name: str, framework: str, - resources: Dict[str, str], + resources: Dict[str, bytes], ) -> None: - """Zip, upload, and create/update the remote agent repo. + """Full upload via two-step OSS, then create/update agent repo. Raises on failure (caller should NOT update baseline on exception). + Does nothing if *resources* is empty. """ - zip_bytes = zip_resources(resources) - file_id = client.upload_file(zip_bytes) - client.create_repo(username, name, framework, system_prompt_files=file_id) - logger.info("Pushed %d file(s) (%d bytes zip).", len(resources), len(zip_bytes)) + if not resources: + logger.warning("push_resources called with empty resources; skipping.") + return + gid = client.upload_file(resources) + if not gid: + logger.warning("upload_file returned empty gid; skipping create_repo.") + return + client.create_repo(username, name, framework, system_prompt_files=gid) + for rel in sorted(resources): + logger.info(" UPLOAD: %s (%d B)", rel, len(resources[rel])) + logger.info("Pushed %d file(s) via OSS (gid=%s).", len(resources), gid) + + +def push_incremental( + client: "UltronClient", + username: str, + name: str, + changed: Dict[str, Union[bytes, None]], + remote_paths: set, +) -> None: + """Incremental push via commit interface. + + Builds create/update/delete actions and commits in one request. + Raises on failure (caller should NOT update baseline on exception). + """ + actions: List[dict] = [] + for fpath, content in changed.items(): + if content is None: # None = delete + actions.append({"action": "delete", "file_path": fpath}) + else: + action_type = "update" if fpath in remote_paths else "create" + # Try UTF-8 decode; fall back to base64 for binary. + try: + text = content.decode("utf-8") + actions.append({"action": action_type, "file_path": fpath, + "content": text, "encoding": "text"}) + except UnicodeDecodeError: + b64 = base64.b64encode(content).decode("ascii") + actions.append({"action": action_type, "file_path": fpath, + "content": b64, "encoding": "base64"}) + if actions: + for a in actions: + logger.info(" %s: %s", a["action"].upper(), a["file_path"]) + client.commit_files(username, name, actions, commit_message="watch sync") + logger.info("Committed %d action(s) incrementally.", len(actions)) def pull_incremental( @@ -104,7 +147,7 @@ def pull_incremental( name: str, spec, remote_files: "List[RemoteFileInfo]", - local_resources: Dict[str, str], + local_resources: Dict[str, bytes], ) -> int: """Incrementally pull remote changes to local workspace. @@ -116,6 +159,7 @@ def pull_incremental( (caller should NOT update baseline on exception). """ root: Path = spec.workspace_root + resolved_root = root.resolve() remote_sha_map = {f.path: f.sha256 for f in remote_files} remote_paths = set(remote_sha_map.keys()) local_paths = set(local_resources.keys()) @@ -123,22 +167,27 @@ def pull_incremental( # Download files that are new or changed on remote. for rfile in remote_files: + target = (root / rfile.path).resolve() + if not target.is_relative_to(resolved_root): + logger.warning(" Skipped (path traversal): %s", rfile.path) + continue local_content = local_resources.get(rfile.path) if local_content is not None: local_sha = sha256_content(local_content) if local_sha == rfile.sha256: continue # identical, skip # Need to download. - content = client.download_repo_file(username, name, rfile.path) - target = root / rfile.path + content = client.download_repo_file(username, name, rfile.path, binary=True) target.parent.mkdir(parents=True, exist_ok=True) - target.write_text(content, encoding="utf-8") + target.write_bytes(content) changes += 1 logger.info(" Downloaded: %s", rfile.path) # Delete local files that no longer exist on remote. for rel in sorted(local_paths - remote_paths): - target = root / rel + target = (root / rel).resolve() + if not target.is_relative_to(resolved_root): + continue if target.exists(): target.unlink() changes += 1 diff --git a/ultron/cli/watcher.py b/ultron/cli/watcher.py index 4da8a16..fbec77f 100644 --- a/ultron/cli/watcher.py +++ b/ultron/cli/watcher.py @@ -5,16 +5,18 @@ import signal import subprocess import sys +import threading import time from logging.handlers import RotatingFileHandler from typing import List, Optional -from .cache import load_sync_state, log_file, pid_file, save_sync_state +from .cache import load_sync_state, log_file, pid_file, save_sync_state, stop_file from .client import ApiError from .sync import ( backup_local, detect_local_changes, pull_incremental, + push_incremental, push_resources, ) @@ -35,34 +37,58 @@ def _get_logger() -> logging.Logger: return _logger -def watch_loop(spec, client, username: str, name: str, framework: str, interval: int = 120, *, push_only: bool = True): +def watch_loop(spec, client, username: str, repo: str, framework: str, interval: int = 120, *, push_only: bool = True): """Sync loop: push local changes, optionally pull remote changes. - push_only=True (default): only pushes, never modifies local files. - push_only=False: full bidirectional sync (remote wins on conflict). + Args: + repo: Remote repository name (used as the API path component). + push_only: True (default) = only pushes, never modifies local files. + False = full bidirectional sync (remote wins on conflict). """ logger = _get_logger() logger.info("Watch started for %s/%s (root=%s, interval=%ds, push_only=%s)", - username, name, spec.workspace_root, interval, push_only) + username, repo, spec.workspace_root, interval, push_only) - state = load_sync_state(name) + state = load_sync_state(repo) running = True + stop_event = threading.Event() + sf = stop_file() def _handle_term(signum, frame): nonlocal running running = False + stop_event.set() - signal.signal(signal.SIGTERM, _handle_term) + # Unix: register signal handlers for graceful stop via kill(1). + # Windows: SIGTERM triggers TerminateProcess (hard kill), so signals are + # unreliable; the stop-file mechanism below is the primary channel. + if hasattr(signal, "SIGTERM"): + signal.signal(signal.SIGTERM, _handle_term) signal.signal(signal.SIGINT, _handle_term) + # Remove any stale stop file from a previous session. + sf.unlink(missing_ok=True) + while running: - time.sleep(interval) + # Wait with periodic wake-ups to poll the stop file (Windows compat). + # On Unix, the signal handler sets stop_event immediately. + elapsed = 0 + poll_interval = min(interval, 5) # check stop file every 5s + while elapsed < interval and running: + stop_event.wait(timeout=poll_interval) + if stop_event.is_set(): + running = False + break + if sf.exists(): + running = False + break + elapsed += poll_interval if not running: break # ---- Fetch remote file list ---- try: - remote_files = client.list_repo_files_detail(username, name) + remote_files = client.list_repo_files_detail(username, repo) except ApiError as e: if e.status in (404, 500): remote_files = [] @@ -71,7 +97,7 @@ def _handle_term(signum, frame): continue # ---- Collect local resources & detect changes ---- - local_resources = spec.collect() + local_resources = spec.collect_bytes() scope = set(local_resources.keys()) | set(state.get("remote_files", {}).keys()) remote_sha_map = {f.path: f.sha256 for f in remote_files if f.path in scope} @@ -86,35 +112,60 @@ def _handle_term(signum, frame): try: did_sync = _sync_action( push_only, remote_changed, local_changed, - client, username, name, framework, spec, + client, username, repo, framework, spec, remote_files, local_resources, logger, + state, ) except Exception as exc: logger.error("Sync failed (will retry): %s", exc) # ---- Update baseline on successful sync ---- if did_sync: - _refresh_baseline(client, username, name, spec, state, logger) - save_sync_state(name, state["last_commit_date"], state["remote_files"]) + # After a pull, local files may have changed — re-collect. + if not push_only: + local_resources = spec.collect_bytes() + _refresh_baseline(client, username, repo, local_resources, state, logger) + save_sync_state(repo, state["last_commit_date"], state["remote_files"]) logger.info("Watch stopped (signal received).") pf = pid_file() if pf.exists(): pf.unlink(missing_ok=True) + sf.unlink(missing_ok=True) + + +def _push_local(client, username, name, framework, local_resources, state, logger) -> bool: + """Push local changes: full upload on first time, incremental thereafter. + + Returns True if something was actually pushed, False otherwise. + """ + if not local_resources: + logger.debug("No local resources to push — skipping.") + return False + if not state.get("remote_files"): + push_resources(client, username, name, framework, local_resources) + logger.info("Pushed local changes (full upload — first time).") + return True + else: + changed = detect_local_changes(local_resources, state["remote_files"]) + if changed: + push_incremental(client, username, name, changed, set(state["remote_files"].keys())) + logger.info("Pushed local changes (incremental commit).") + return True + return False def _sync_action( push_only, remote_changed, local_changed, client, username, name, framework, spec, remote_files, local_resources, logger, + state, ) -> bool: """Execute the appropriate sync action. Returns True if something changed.""" if push_only: if not local_changed: return False - push_resources(client, username, name, framework, local_resources) - logger.info("Pushed local changes.") - return True + return _push_local(client, username, name, framework, local_resources, state, logger) if remote_changed and local_changed: backup_path = backup_local(spec, name) @@ -125,19 +176,22 @@ def _sync_action( pull_incremental(client, username, name, spec, remote_files, local_resources) logger.info("Pulled remote changes (backup: %s).", backup_path) elif local_changed: - push_resources(client, username, name, framework, local_resources) - logger.info("Pushed local changes.") + _push_local(client, username, name, framework, local_resources, state, logger) else: return False return True -def _refresh_baseline(client, username: str, name: str, spec, state: dict, logger) -> None: - """Re-fetch remote file list and update state in-place.""" +def _refresh_baseline(client, username: str, name: str, local_resources: dict, state: dict, logger) -> None: + """Re-fetch remote file list and update state in-place. + + Uses *local_resources* keys as the managed-file scope to avoid a redundant + disk scan (caller already collected them this cycle). + """ + managed = set(local_resources.keys()) for attempt in range(3): try: fresh = client.list_repo_files_detail(username, name) - managed = set(spec.collect().keys()) state["last_commit_date"] = max((f.committed_date for f in fresh), default=0) state["remote_files"] = {f.path: f.sha256 for f in fresh if f.path in managed} return @@ -207,12 +261,19 @@ def _daemonize_windows(target, *args, **kwargs): import tempfile # Serialize the arguments that watch_loop needs. + # spec (args[0]) carries the agent_name used to build the allowlist scope. + # client (args[1]) carries server/token for the child process. + spec_obj = args[0] if len(args) > 0 else None + client_obj = args[1] if len(args) > 1 else None payload = { "username": args[2] if len(args) > 2 else kwargs.get("username", ""), - "name": args[3] if len(args) > 3 else kwargs.get("name", ""), + "repo": args[3] if len(args) > 3 else kwargs.get("repo", ""), "framework": args[4] if len(args) > 4 else kwargs.get("framework", ""), "interval": args[5] if len(args) > 5 else kwargs.get("interval", 120), "push_only": kwargs.get("push_only", True), + "local_name": getattr(spec_obj, "agent_name", "") if spec_obj else "", + "server": getattr(client_obj, "server", "") if client_obj else "", + "token": getattr(client_obj, "token", "") if client_obj else "", # spec and client are rebuilt in the child from stored config. } # Write to a temp file that the child will read and delete. @@ -236,30 +297,45 @@ def _daemonize_windows(target, *args, **kwargs): pf.write_text(str(proc.pid), encoding="utf-8") -def stop_daemon() -> bool: - """Stop ALL running watch daemon processes. +_DEFAULT_WATCH_PATTERNS = [ + "ultron watch --framework", +] + - Kills the PID-file-tracked process, then scans for orphaned processes. - Waits briefly for graceful shutdown before returning. +def stop_daemon(extra_patterns: Optional[List[str]] = None) -> bool: + """Stop ALL running watch daemon processes (cross-platform). + + Primary mechanism: write a stop-file that the watch loop polls. + Secondary: send SIGTERM (Unix) or taskkill (Windows) as a backup. + Cleans up PID file and stop file on return. """ stopped = False pf = pid_file() + sf = stop_file() + + # 1. Write the stop file — the watch loop will notice within 5 seconds. + sf.write_text("stop", encoding="utf-8") - # 1. Kill PID-file-tracked process. + # 2. Also send SIGTERM to PID-tracked process (Unix only). + # On Windows, os.kill(SIGTERM) = TerminateProcess (hard kill), which + # bypasses graceful shutdown entirely. Rely on stop-file instead. tracked_pid = None if pf.exists(): try: tracked_pid = int(pf.read_text().strip()) - os.kill(tracked_pid, signal.SIGTERM) + if hasattr(os, "fork"): + # Unix: SIGTERM triggers the handler → sets running=False. + os.kill(tracked_pid, signal.SIGTERM) + # On Windows, the stop file (written above) is the sole signal. stopped = True except (ValueError, OSError, ProcessLookupError): tracked_pid = None - pf.unlink(missing_ok=True) - # 2. Kill orphaned watch processes (Unix only; pgrep unavailable on Windows). + # 3. Kill orphaned watch processes. if hasattr(os, "fork"): + # Unix: use pgrep. my_pid = os.getpid() - for found_pid in _find_watch_pids(): + for found_pid in _find_watch_pids(extra_patterns): if found_pid in (my_pid, tracked_pid): continue try: @@ -267,23 +343,118 @@ def stop_daemon() -> bool: stopped = True except (ProcessLookupError, PermissionError): pass + else: + # Windows: use wmic/tasklist to find orphaned processes. + for found_pid in _find_watch_pids_windows(extra_patterns): + if found_pid == tracked_pid: + continue + _terminate_pid_windows(found_pid) + stopped = True + + # 4. Wait for graceful exit (stop-file polling interval is 5s max). + if stopped or tracked_pid: + _wait_for_exit(tracked_pid, timeout=8) + + # 5. Force kill if still alive. + if tracked_pid and _is_alive(tracked_pid): + _force_kill(tracked_pid) + + # 6. Clean up. + pf.unlink(missing_ok=True) + sf.unlink(missing_ok=True) - # 3. Wait for processes to exit (up to 3s). - if stopped: - time.sleep(1) + return stopped or tracked_pid is not None - return stopped +def _find_watch_pids(extra_patterns: Optional[List[str]] = None) -> List[int]: + """Find PIDs of running watch daemon processes via pgrep (Unix only). -def _find_watch_pids() -> List[int]: - """Find PIDs of running 'ultron watch' daemon processes via pgrep.""" + Searches default patterns plus any *extra_patterns* provided by the caller. + """ + patterns = list(dict.fromkeys(_DEFAULT_WATCH_PATTERNS + (extra_patterns or []))) + pids: set = set() + for pattern in patterns: + try: + result = subprocess.run( + ["pgrep", "-f", "--", pattern], + capture_output=True, text=True, timeout=5, + ) + if result.returncode == 0: + for p in result.stdout.strip().split("\n"): + if p.strip().isdigit(): + pids.add(int(p.strip())) + except (OSError, subprocess.TimeoutExpired, ValueError): + pass + return list(pids) + + +def _find_watch_pids_windows(extra_patterns: Optional[List[str]] = None) -> List[int]: + """Find PIDs of running watch daemon processes on Windows via wmic/tasklist. + + Searches for python processes whose command line matches watch patterns. + """ + patterns = list(dict.fromkeys(_DEFAULT_WATCH_PATTERNS + (extra_patterns or []))) + pids: set = set() try: + # Use wmic to get process command lines. result = subprocess.run( - ["pgrep", "-f", "ultron watch --framework"], - capture_output=True, text=True, timeout=5, + ["wmic", "process", "where", "name like '%python%'", + "get", "processid,commandline"], + capture_output=True, text=True, timeout=10, ) - if result.returncode == 0: - return [int(p) for p in result.stdout.strip().split("\n") if p.strip().isdigit()] + if result.returncode != 0: + return [] + for line in result.stdout.splitlines(): + line_lower = line.lower() + for pattern in patterns: + if pattern.lower() in line_lower: + # Extract PID (last number on the line). + parts = line.strip().split() + if parts and parts[-1].isdigit(): + pids.add(int(parts[-1])) + break except (OSError, subprocess.TimeoutExpired, ValueError): pass - return [] + return list(pids) + + +def _terminate_pid_windows(pid: int) -> None: + """Terminate a process on Windows using taskkill.""" + try: + subprocess.run( + ["taskkill", "/PID", str(pid), "/F"], + capture_output=True, timeout=5, + ) + except (OSError, subprocess.TimeoutExpired): + pass + + +def _is_alive(pid: int) -> bool: + """Check if a process with the given PID is still running.""" + try: + os.kill(pid, 0) + return True + except (ProcessLookupError, PermissionError, OSError): + return False + + +def _wait_for_exit(pid: Optional[int], timeout: int = 8) -> None: + """Wait up to *timeout* seconds for a process to exit.""" + if pid is None: + time.sleep(2) + return + for _ in range(timeout * 2): # check every 0.5s + if not _is_alive(pid): + return + time.sleep(0.5) + + +def _force_kill(pid: int) -> None: + """Force-kill a process (SIGKILL on Unix, taskkill /F on Windows).""" + if hasattr(os, "fork"): + try: + os.kill(pid, getattr(signal, "SIGKILL", signal.SIGTERM)) + except (ProcessLookupError, PermissionError, OSError): + pass + else: + _terminate_pid_windows(pid) diff --git a/ultron/services/harness/allowlist.py b/ultron/services/harness/allowlist.py index 9a131ce..5a4ce14 100644 --- a/ultron/services/harness/allowlist.py +++ b/ultron/services/harness/allowlist.py @@ -32,7 +32,7 @@ import os from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, List, Optional, Type +from typing import Dict, List, Optional, Tuple, Type logger = logging.getLogger(__name__) @@ -40,6 +40,7 @@ DEFAULT_AGENT_NAME = "default" ALL_AGENT_NAME = "all" +GLOBAL_AGENT_NAME = "__global__" class ClawWorkspaceAllowlist(ABC): @@ -118,23 +119,39 @@ def workspace_root(self) -> Path: """Effective root: ``local_dir`` override, else the product default.""" return self._local_dir if self._local_dir is not None else self.default_workspace_root - def _resolved_patterns(self) -> List[str]: + def _is_global(self) -> bool: + """Whether we are in global-only mode (shared files only, no sub-agent).""" + return self.agent_name == GLOBAL_AGENT_NAME + + def resolved_patterns(self) -> List[str]: + """Resolve glob patterns for the current agent mode. + + Convention: In global mode (``GLOBAL_AGENT_NAME``), patterns containing + the ``{name}`` placeholder are excluded because they target specific + sub-agents. Shared/framework-level patterns (those without ``{name}``) + remain. New frameworks MUST follow this convention: use ``{name}`` in + patterns that are per-agent and omit it for shared/global patterns. + """ + if self._is_global(): + # Global mode: exclude patterns containing {name} placeholder. + return [p for p in self._effective_patterns() if "{name}" not in p] name = "*" if self._is_all() else self.agent_name return [p.format(name=name) for p in self._effective_patterns()] - def _matches(self, rel_path: str, patterns: List[str]) -> bool: + def matches(self, rel_path: str, patterns: List[str]) -> bool: + """Return True if *rel_path* matches any of the given glob *patterns*.""" for pattern in patterns: if fnmatch.fnmatch(rel_path, pattern): return True return False - def collect(self) -> Dict[str, str]: - """Gather allowed workspace files as {relative_path: text_content}.""" + def _walk_matched(self) -> List[Tuple[str, Path]]: + """Walk workspace and return (rel_path, Path) for matched files.""" root = self.workspace_root if not root.is_dir(): - return {} - patterns = self._resolved_patterns() - result: Dict[str, str] = {} + return [] + patterns = self.resolved_patterns() + matched: List[Tuple[str, Path]] = [] for dirpath, dirnames, filenames in os.walk(root): # Skip hidden directories in-place (prevents descending into them). dirnames[:] = sorted(d for d in dirnames if not d.startswith(".")) @@ -148,14 +165,38 @@ def collect(self) -> Dict[str, str]: rel = f.relative_to(root).as_posix() except ValueError: continue - if not self._matches(rel, patterns): + if not self.matches(rel, patterns): continue try: if f.stat().st_size > MAX_FILE_SIZE: continue - result[rel] = f.read_text(encoding="utf-8") - except (OSError, UnicodeDecodeError) as e: - logger.debug("Skip %s: %s", f, e) + except OSError: + continue + matched.append((rel, f)) + return matched + + def collect(self) -> Dict[str, str]: + """Gather allowed workspace files as {relative_path: text_content}.""" + result: Dict[str, str] = {} + for rel, f in self._walk_matched(): + try: + result[rel] = f.read_text(encoding="utf-8") + except (OSError, UnicodeDecodeError) as e: + logger.warning("Skip %s: %s", f, e) + return result + + def collect_bytes(self) -> Dict[str, bytes]: + """Gather allowed workspace files as {relative_path: raw_bytes}. + + Unlike :meth:`collect`, this includes binary files (images, PDFs, etc.) + and does not skip on UnicodeDecodeError. + """ + result: Dict[str, bytes] = {} + for rel, f in self._walk_matched(): + try: + result[rel] = f.read_bytes() + except OSError as e: + logger.warning("Skip %s: %s", f, e) return result def list_agents(self) -> List[str]: @@ -166,13 +207,20 @@ def list_agents(self) -> List[str]: """ return [DEFAULT_AGENT_NAME] + def _list_agents_from_dir(self, agents_dir: Path) -> List[str]: + """List agents from a directory, prepending DEFAULT if not present.""" + agents = _list_agent_files(agents_dir) + if DEFAULT_AGENT_NAME not in agents: + agents = [DEFAULT_AGENT_NAME] + agents + return agents + def apply(self, resources: Dict[str, str]) -> List[str]: """Write resource files back to the workspace. Returns list of written paths.""" root = self.workspace_root.resolve() written: List[str] = [] for rel_path, content in resources.items(): target = (root / rel_path).resolve() - if not str(target).startswith(str(root)): + if not target.is_relative_to(root): logger.warning("Path traversal blocked: %s", rel_path) continue target.parent.mkdir(parents=True, exist_ok=True) @@ -216,7 +264,7 @@ def patterns(self) -> List[str]: ] def list_agents(self) -> List[str]: - return _list_agent_files(self.workspace_root / "agents") + return self._list_agents_from_dir(self.workspace_root / "agents") class OpenclawWorkspaceAllowlist(ClawWorkspaceAllowlist): @@ -423,7 +471,7 @@ def patterns(self) -> List[str]: ] def list_agents(self) -> List[str]: - return _list_agent_files(self.workspace_root / "agents") + return self._list_agents_from_dir(self.workspace_root / "agents") def _list_agent_files(agents_dir: Path) -> List[str]: