diff --git a/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf.handler.py b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf.handler.py new file mode 100644 index 000000000..eba118e6d --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf.handler.py @@ -0,0 +1,1689 @@ +from __future__ import annotations + +import asyncio +import base64 +import json +import os +import ssl +import urllib.parse +import uuid +from http.client import HTTPConnection, HTTPSConnection, RemoteDisconnected +from pathlib import Path +from typing import Any, Optional + +from flocks.config.config_writer import ConfigWriter +from flocks.security import get_secret_manager +from flocks.tool.registry import ToolContext, ToolResult + + +SERVICE_ID = "360_waf" +STORAGE_KEY = "360_waf_v5_5" +PRODUCT_VERSION = "5.5" + +BLOCKED_DEVICE_STATE_MUTATIONS: dict[str, set[str]] = { + "/rest/api/reboot_system": {"POST"}, + "/rest/api/mgmt_image": {"POST", "DELETE"}, + "/rest/api/signature": {"POST", "PUT"}, + "/rest/api/configfile": {"POST", "DELETE"}, + "/rest/api/waf_deploy_mode": {"PUT"}, + "/rest/api/licenseManagementAgent": {"PUT"}, + "/rest/api/interface": {"POST", "PUT", "DELETE"}, + "/rest/api/zone": {"POST", "PUT", "DELETE"}, +} + +BLOCKED_FILE_MUTATIONS: dict[str, set[str]] = { + "/rest/file/signature_import": {"POST"}, + "/rest/file/mgmt_import": {"POST"}, + "/rest/file/admind_image_upgrade": {"POST"}, + "/rest/file?fileName=tmp": {"DELETE"}, +} + +DOCUMENTED_API_METHODS: dict[str, list[str]] = { + "/rest/api/login": ["DELETE", "GET", "POST"], + "/rest/api/waf_attack_source_client_ip": ["GET"], + "/rest/api/waf_attack_source_map": ["GET"], + "/rest/api/waf_protection_type": ["GET"], + "/rest/api/waf_site_attack": ["GET"], + "/rest/api/website": ["DELETE", "GET", "POST", "PUT"], + "/rest/api/wafacpolicy": ["DELETE", "GET", "POST", "PUT"], + "/rest/api/blacklist": ["DELETE", "GET", "POST"], + "/rest/api/exceptionlist": ["DELETE", "GET", "POST", "PUT"], + "/rest/api/site_global_blacklist": ["DELETE", "GET", "POST"], + "/rest/api/site_global_whitelist": ["DELETE", "GET", "POST"], + "/rest/api/wafpolicy": ["DELETE", "GET", "POST", "PUT"], + "/rest/api/whitelist": ["DELETE", "GET", "POST"], + "/rest/api/configurationlog": ["GET", "POST"], + "/rest/api/loggerconfiguration": ["DELETE", "GET", "POST", "PUT"], + "/rest/api/websecuritylog": ["GET"], + "/rest/api/ad": ["GET", "POST", "PUT"], + "/rest/api/interface": ["DELETE", "GET", "POST", "PUT"], + "/rest/api/zone": ["DELETE", "GET", "POST", "PUT"], + "/rest/api/configfile": ["DELETE", "GET", "POST"], + "/rest/api/licenseManagementAgent": ["GET", "PUT"], + "/rest/api/mgmt_image": ["DELETE", "GET", "POST"], + "/rest/api/sysinfo": ["GET"], + "/rest/api/waf_custom_error_page": ["DELETE", "GET", "POST", "PUT"], + "/rest/api/waf_deploy_mode": ["GET", "PUT"], + "/rest/api/signature": ["GET", "POST", "PUT"], + "/rest/api/capacity": ["GET"], + "/rest/api/disk_usage": ["GET"], + "/rest/api/reboot_system": ["POST"], + "/rest/api/file_exists_on_device": ["GET"], +} + +DOCUMENTED_FILE_ENDPOINTS: dict[str, list[str]] = { + "/rest/file/signature_import": ["POST"], + "/rest/file/mgmt_import": ["POST"], + "/rest/file/wafd_error_page": ["POST"], + "/rest/file/admind_image_upgrade": ["POST"], + "/rest/file?fileName=tmp": ["DELETE"], +} + + +class WafApiError(RuntimeError): + pass + + +class RuntimeConfig: + def __init__( + self, + *, + base_url: str, + username: str, + password: str, + verify_ssl: bool, + timeout: int, + ) -> None: + self.base_url = base_url + self.username = username + self.password = password + self.verify_ssl = verify_ssl + self.timeout = timeout + + +def _resolve_ref(value: Any) -> str: + if value is None: + return "" + if not isinstance(value, str): + return str(value) + if value.startswith("{secret:") and value.endswith("}"): + return get_secret_manager().get(value[len("{secret:") : -1]) or "" + if value.startswith("{env:") and value.endswith("}"): + return os.getenv(value[len("{env:") : -1], "") + return value + + +def _raw_service_config() -> dict[str, Any]: + raw = ConfigWriter.get_api_service_raw(SERVICE_ID) + if not isinstance(raw, dict): + raw = ConfigWriter.get_api_service_raw(STORAGE_KEY) + return raw if isinstance(raw, dict) else {} + + +def _as_bool(value: Any, default: bool) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, str): + text = value.strip().lower() + if text in {"1", "true", "yes", "on"}: + return True + if text in {"0", "false", "no", "off"}: + return False + return bool(value) + + +def _config_value(raw: dict[str, Any], *keys: str) -> Any: + for key in keys: + if key in raw and raw[key] is not None: + return raw[key] + custom_settings = raw.get("custom_settings") + if isinstance(custom_settings, dict): + for key in keys: + if key in custom_settings and custom_settings[key] is not None: + return custom_settings[key] + return None + + +def _resolve_verify_ssl(raw: dict[str, Any]) -> bool: + value = _config_value(raw, "verify_ssl", "ssl_verify") + if value is None: + value = os.getenv("WAF_VERIFY_SSL") + return _as_bool(value, False) + + +def _load_runtime_config() -> RuntimeConfig: + raw = _raw_service_config() + sm = get_secret_manager() + + base_url = ( + _resolve_ref(raw.get("base_url")) + or _resolve_ref(raw.get("baseUrl")) + or os.getenv("WAF_BASE_URL", "") + ).rstrip("/") + username = ( + _resolve_ref(raw.get("username")) + or sm.get("360_waf_v5_5_username") + or sm.get("360_waf_username") + or os.getenv("WAF_USERNAME", "") + ) + password = ( + _resolve_ref(raw.get("password")) + or sm.get("360_waf_v5_5_password") + or sm.get("360_waf_password") + or os.getenv("WAF_PASSWORD", "") + ) + timeout_value = raw.get("timeout") or os.getenv("WAF_TIMEOUT") or 30 + try: + timeout = int(timeout_value) + except (TypeError, ValueError): + timeout = 30 + + if not base_url: + raise WafApiError("360 WAF base_url is required") + if not username: + raise WafApiError("360 WAF username is required") + if not password: + raise WafApiError("360 WAF password is required") + + return RuntimeConfig( + base_url=base_url, + username=username, + password=password, + verify_ssl=_resolve_verify_ssl(raw), + timeout=timeout, + ) + + +class WafClient: + def __init__(self, config: RuntimeConfig) -> None: + self.config = config + self.base_url = config.base_url + self.username = config.username + self.password = config.password + self.verify_ssl = config.verify_ssl + self.timeout = config.timeout + self._session: Optional[dict[str, Any]] = None + + parsed = urllib.parse.urlparse(self.base_url) + if parsed.scheme not in {"http", "https"} or not parsed.hostname: + raise WafApiError("360 WAF base_url must be like https://YOUR_360_WAF_HOST") + self.scheme = parsed.scheme + self.host = parsed.hostname + self.port = parsed.port or (443 if parsed.scheme == "https" else 80) + + if self.scheme == "https": + if self.verify_ssl: + self.ssl_context = ssl.create_default_context() + else: + self.ssl_context = ssl._create_unverified_context() + try: + self.ssl_context.set_ciphers("DEFAULT:@SECLEVEL=0") + except Exception: + pass + else: + self.ssl_context = None + + def login(self) -> dict[str, Any]: + body = { + "lang": "zh_CN", + "username": self._b64(self.username), + "password": self._b64(self.password), + } + data = self._raw_request("POST", "/rest/api/login", body=body, cookie=None) + if data.get("success") is not True: + raise WafApiError(self._format_exception("login failed", data)) + result = data.get("result") + if not isinstance(result, list) or not result: + raise WafApiError("login response did not contain result[0]") + session = result[0] + for key in ("token", "fromrootvsys", "role", "vsysId"): + if key not in session or session[key] in (None, ""): + raise WafApiError(f"login response missing {key}") + self._session = session + return self._public_session(session) + + def check_login(self) -> dict[str, Any]: + return self.get("/rest/api/login") + + def logout(self) -> dict[str, Any]: + if not self._session: + self.login() + body = { + "username": self.username, + "protocol": self.scheme, + "token": self._session["token"], + "role": self._session["role"], + } + data = self._raw_request( + "DELETE", "/rest/api/login", body=body, cookie=self._cookie_header() + ) + self._session = None + return data + + def get( + self, path: str, query: Optional[dict[str, Any]] = None, retry: bool = True + ) -> dict[str, Any]: + if not self._session: + self.login() + request_path = self._build_path(path, query) + try: + data = self._raw_request("GET", request_path, cookie=self._cookie_header()) + except (RemoteDisconnected, ConnectionError, OSError): + if not retry: + raise + self.login() + data = self._raw_request("GET", request_path, cookie=self._cookie_header()) + + if self._is_invalid_login(data) and retry: + self.login() + data = self._raw_request("GET", request_path, cookie=self._cookie_header()) + return data + + def call_readonly( + self, resource: str, query: Optional[dict[str, Any]] = None + ) -> dict[str, Any]: + if not resource.startswith("/rest/api/"): + raise WafApiError("only /rest/api/... resources are allowed") + return self.get(resource, query=query) + + def request( + self, + method: str, + path: str, + query: Optional[dict[str, Any]] = None, + body: Optional[Any] = None, + retry: bool = True, + ) -> dict[str, Any]: + method = method.upper() + if method == "GET": + return self.get(path, query=query, retry=retry) + if method not in {"POST", "PUT", "DELETE"}: + raise WafApiError("method must be GET, POST, PUT, or DELETE") + if not self._session: + self.login() + request_path = self._build_path(path, query) + data = self._raw_request( + method, request_path, body=body, cookie=self._cookie_header() + ) + if self._is_invalid_login(data) and retry: + self.login() + data = self._raw_request( + method, request_path, body=body, cookie=self._cookie_header() + ) + return data + + def upload_file( + self, + path: str, + file_path: str, + fields: Optional[dict[str, Any]] = None, + retry: bool = True, + ) -> dict[str, Any]: + if not self._session: + self.login() + if not path.startswith("/rest/file/"): + raise WafApiError("file upload path must start with /rest/file/") + data = self._raw_file_upload( + path, file_path, fields=fields or {}, cookie=self._cookie_header() + ) + if self._is_invalid_login(data) and retry: + self.login() + data = self._raw_file_upload( + path, file_path, fields=fields or {}, cookie=self._cookie_header() + ) + return data + + def file_request(self, method: str, path: str, retry: bool = True) -> dict[str, Any]: + method = method.upper() + if method not in {"GET", "DELETE"}: + raise WafApiError("file request method must be GET or DELETE") + if not path.startswith("/rest/file"): + raise WafApiError("file request path must start with /rest/file") + if not self._session: + self.login() + data = self._raw_request(method, path, cookie=self._cookie_header()) + if self._is_invalid_login(data) and retry: + self.login() + data = self._raw_request(method, path, cookie=self._cookie_header()) + return data + + def download_file(self, path: str, save_path: str, retry: bool = True) -> dict[str, Any]: + if not path.startswith("/download/"): + raise WafApiError("download path must start with /download/") + if not self._session: + self.login() + status, headers, raw = self._raw_binary_request( + "GET", path, cookie=self._cookie_header() + ) + if status in {401, 403} and retry: + self.login() + status, headers, raw = self._raw_binary_request( + "GET", path, cookie=self._cookie_header() + ) + if status < 200 or status >= 300: + text = raw.decode("utf-8", errors="replace") + raise WafApiError(f"HTTP {status} from GET {path}: {text[:300]}") + target = Path(save_path).expanduser().resolve() + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(raw) + return { + "success": True, + "result": { + "path": str(target), + "bytes": len(raw), + "content_type": headers.get("content-type"), + }, + } + + def _raw_request( + self, method: str, path: str, body: Optional[Any] = None, cookie: Optional[str] = None + ) -> dict[str, Any]: + headers = { + "Host": self.host, + "Accept": "application/json", + "Connection": "close", + } + payload = None + if cookie: + headers["Cookie"] = cookie + if body is not None: + payload = json.dumps(body, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + headers["Content-Type"] = "application/json" + headers["Content-Length"] = str(len(payload)) + + if self.scheme == "https": + conn = HTTPSConnection( + self.host, self.port, context=self.ssl_context, timeout=self.timeout + ) + else: + conn = HTTPConnection(self.host, self.port, timeout=self.timeout) + + try: + conn.request(method, path, body=payload, headers=headers) + resp = conn.getresponse() + raw = resp.read() + finally: + conn.close() + + text = raw.decode("utf-8", errors="replace") + try: + data = json.loads(text) if text else {} + except json.JSONDecodeError as exc: + raise WafApiError( + f"non-json response from {method} {path}: HTTP {resp.status}, {text[:200]}" + ) from exc + + if resp.status < 200 or resp.status >= 300: + raise WafApiError(f"HTTP {resp.status} from {method} {path}: {text[:300]}") + return data + + def _raw_file_upload( + self, + path: str, + file_path: str, + fields: dict[str, Any], + cookie: Optional[str] = None, + ) -> dict[str, Any]: + source = Path(file_path).expanduser().resolve() + if not source.is_file(): + raise WafApiError(f"upload file not found: {source}") + boundary = "----wafmcp-" + uuid.uuid4().hex + parts: list[bytes] = [] + for key, value in fields.items(): + if value is None: + continue + parts.append(f"--{boundary}\r\n".encode("utf-8")) + parts.append(f'Content-Disposition: form-data; name="{key}"\r\n\r\n'.encode("utf-8")) + parts.append(str(value).encode("utf-8")) + parts.append(b"\r\n") + upload_name = str(fields.get("filename") or fields.get("clientFileName") or source.name) + parts.append(f"--{boundary}\r\n".encode("utf-8")) + parts.append( + ( + f'Content-Disposition: form-data; name="upload"; filename="{upload_name}"\r\n' + "Content-Type: application/octet-stream\r\n\r\n" + ).encode("utf-8") + ) + parts.append(source.read_bytes()) + parts.append(b"\r\n") + parts.append(f"--{boundary}--\r\n".encode("utf-8")) + payload = b"".join(parts) + + headers = { + "Host": self.host, + "Accept": "application/json", + "Connection": "close", + "Content-Type": f"multipart/form-data; boundary={boundary}", + "Content-Length": str(len(payload)), + } + if cookie: + headers["Cookie"] = cookie + status, _, raw = self._send_raw("POST", path, payload=payload, headers=headers) + text = raw.decode("utf-8", errors="replace") + try: + data = json.loads(text) if text else {} + except json.JSONDecodeError as exc: + raise WafApiError( + f"non-json response from POST {path}: HTTP {status}, {text[:200]}" + ) from exc + if status < 200 or status >= 300: + raise WafApiError(f"HTTP {status} from POST {path}: {text[:300]}") + return data + + def _raw_binary_request( + self, method: str, path: str, cookie: Optional[str] = None + ) -> tuple[int, dict[str, str], bytes]: + headers = { + "Host": self.host, + "Accept": "*/*", + "Connection": "close", + } + if cookie: + headers["Cookie"] = cookie + return self._send_raw(method, path, headers=headers) + + def _send_raw( + self, + method: str, + path: str, + payload: Optional[bytes] = None, + headers: Optional[dict[str, str]] = None, + ) -> tuple[int, dict[str, str], bytes]: + if self.scheme == "https": + conn = HTTPSConnection( + self.host, self.port, context=self.ssl_context, timeout=self.timeout + ) + else: + conn = HTTPConnection(self.host, self.port, timeout=self.timeout) + try: + conn.request(method, path, body=payload, headers=headers or {}) + resp = conn.getresponse() + raw = resp.read() + resp_headers = {key.lower(): value for key, value in resp.getheaders()} + status = resp.status + finally: + conn.close() + return status, resp_headers, raw + + def _cookie_header(self) -> str: + if not self._session: + raise WafApiError("not logged in") + session = self._session + pairs = { + "username": self.username, + "token": session["token"], + "fromrootvsys": session["fromrootvsys"], + "role": session["role"], + "vsysId": session["vsysId"], + } + if session.get("platform"): + pairs["platform"] = session["platform"] + return "; ".join(f"{key}={value}" for key, value in pairs.items()) + + def _build_path(self, path: str, query: Optional[dict[str, Any]]) -> str: + if not path.startswith("/"): + path = "/" + path + if not path.startswith("/rest/api/"): + raise WafApiError("only /rest/api/... paths are allowed") + if query is None: + return path + sep = "&" if "?" in path else "?" + encoded = urllib.parse.quote( + json.dumps(query, ensure_ascii=False, separators=(",", ":")), safe="" + ) + return f"{path}{sep}query={encoded}" + + @staticmethod + def _b64(value: str) -> str: + return base64.b64encode(value.encode("utf-8")).decode("ascii") + + @staticmethod + def _is_invalid_login(data: dict[str, Any]) -> bool: + exception = data.get("exception") + if isinstance(exception, dict): + code = str(exception.get("code", "")) + msg = str(exception.get("message", "")).lower() + return code in {"400000005", "loginError_1002"} or "invalid login" in msg + return False + + @staticmethod + def _format_exception(prefix: str, data: dict[str, Any]) -> str: + exception = data.get("exception") + if exception: + return f"{prefix}: {exception}" + return f"{prefix}: {data}" + + @staticmethod + def _public_session(session: dict[str, Any]) -> dict[str, Any]: + token = str(session.get("token", "")) + masked = token[:6] + "..." + token[-4:] if len(token) > 10 else "***" + return { + "token": masked, + "fromrootvsys": session.get("fromrootvsys"), + "vsysId": session.get("vsysId"), + "vsysName": session.get("vsysName"), + "role": session.get("role"), + "isLocalAuth": session.get("isLocalAuth"), + } + + +_CLIENTS: dict[tuple[str, str, bool], WafClient] = {} + + +def _client_cache_key(config: RuntimeConfig) -> tuple[str, str, bool]: + return (config.base_url, config.username, config.verify_ssl) + + +def get_client() -> WafClient: + config = _load_runtime_config() + key = _client_cache_key(config) + client = _CLIENTS.get(key) + if client is None or client.password != config.password: + client = WafClient(config) + _CLIENTS[key] = client + return client + + +def ok(content: Any) -> ToolResult: + return ToolResult( + success=True, + output=content, + metadata={"source": "360 WAF", "version": PRODUCT_VERSION}, + ) + + +def require_int(value: Any, name: str) -> int: + try: + return int(value) + except Exception as exc: + raise WafApiError(f"{name} must be an integer") from exc + + +def build_conditions(args: dict[str, Any], allowed: dict[str, str]) -> list[dict[str, Any]]: + conditions: list[dict[str, Any]] = [] + for arg_name, field_name in allowed.items(): + if arg_name in args and args[arg_name] not in (None, ""): + conditions.append({"field": field_name, "operator": 0, "value": args[arg_name]}) + return conditions + + +def add_paging(query: dict[str, Any], args: dict[str, Any], default_limit: int = 50) -> None: + query["start"] = require_int(args.get("start", 0), "start") + query["limit"] = require_int(args.get("limit", default_limit), "limit") + if query["limit"] > 500: + raise WafApiError("limit must be <= 500") + + +def waf_check_login(args: dict[str, Any]) -> ToolResult: + return ok(get_client().check_login()) + + +def waf_system_info_get(args: dict[str, Any]) -> ToolResult: + return ok(get_client().get("/rest/api/sysinfo")) + + +def waf_site_list(args: dict[str, Any]) -> ToolResult: + query: Optional[dict[str, Any]] = None + conditions = [] + if args.get("id") not in (None, ""): + conditions.append({"field": "id", "value": str(args["id"])}) + if args.get("name") not in (None, ""): + conditions.append({"field": "name", "operator": 6, "value": args["name"]}) + if conditions: + query = {"conditions": conditions} + return ok(get_client().get("/rest/api/website", query=query)) + + +def waf_policy_list(args: dict[str, Any]) -> ToolResult: + return ok(get_client().get("/rest/api/wafpolicy")) + + +def waf_ac_policy_list(args: dict[str, Any]) -> ToolResult: + return ok(get_client().get("/rest/api/wafacpolicy")) + + +def waf_interface_list(args: dict[str, Any]) -> ToolResult: + return ok(get_client().get("/rest/api/interface")) + + +def waf_zone_list(args: dict[str, Any]) -> ToolResult: + return ok(get_client().get("/rest/api/zone")) + + +def waf_blacklist_list(args: dict[str, Any]) -> ToolResult: + site_id = require_int(args.get("siteId"), "siteId") + list_type = require_int(args.get("type"), "type") + query = { + "conditions": [ + {"field": "siteId", "value": site_id}, + {"field": "type", "value": list_type}, + ] + } + return ok(get_client().get("/rest/api/blacklist", query=query)) + + +def waf_whitelist_list(args: dict[str, Any]) -> ToolResult: + site_id = require_int(args.get("id"), "id") + query: dict[str, Any] = {"conditions": [{"field": "id", "value": site_id}]} + if args.get("type") not in (None, ""): + query["conditions"].append({"field": "type", "value": require_int(args["type"], "type")}) + return ok(get_client().get("/rest/api/whitelist", query=query)) + + +def waf_whitelist_check_ip(args: dict[str, Any]) -> ToolResult: + site_id = require_int(args.get("id"), "id") + ip = args.get("ip") + if not ip: + raise WafApiError("ip is required") + query = { + "conditions": [ + {"field": "id", "value": site_id}, + {"field": "is_ip_in_whitelist.ip", "value": ip}, + ] + } + return ok(get_client().get("/rest/api/whitelist", query=query)) + + +def waf_security_log_search(args: dict[str, Any]) -> ToolResult: + allowed_intervals = {"realtime", "hour", "day", "week", "month"} + has_custom_time = args.get("time_start") not in (None, "") or args.get("time_end") not in (None, "") + interval = args.get("interval", None if has_custom_time else "hour") + conditions: list[dict[str, Any]] = [] + if interval: + if interval not in allowed_intervals: + raise WafApiError(f"interval must be one of {sorted(allowed_intervals)}") + conditions.append({"field": "interval", "operator": 0, "value": interval}) + + for arg_name, field_name in (("time_start", "time_start"), ("time_end", "time_end")): + if args.get(arg_name) not in (None, ""): + conditions.append({"field": field_name, "operator": 0, "value": args[arg_name]}) + + if args.get("severity") not in (None, ""): + conditions.append( + {"field": "severity", "operator": 0, "value": require_int(args["severity"], "severity")} + ) + + conditions.extend( + build_conditions( + args, + { + "client_ip": "client_ip", + "server_ip": "server_ip", + "site_name": "site_name", + "policy_name": "policy_name", + "domain_name": "domain_name", + "http_url": "http_url", + "http_method": "http_method", + "rule_id": "rule_id", + "protection_type": "protection_type", + "protection_sub_type": "protection_sub_type", + }, + ) + ) + action_filter = first_present(args, "action_filter", "log_action") + if action_filter not in (None, ""): + conditions.append({"field": "action", "operator": 0, "value": action_filter}) + query: dict[str, Any] = {"conditions": conditions} + add_paging(query, args, default_limit=50) + return ok(get_client().get("/rest/api/websecuritylog", query=query)) + + +def waf_configuration_log_search(args: dict[str, Any]) -> ToolResult: + query: dict[str, Any] = {} + time_start = args.get("time_start") + time_end = args.get("time_end") + if time_start not in (None, "") or time_end not in (None, ""): + life_time: dict[str, Any] = {"interval": "custom"} + if time_start not in (None, ""): + life_time["start"] = time_start + if time_end not in (None, ""): + life_time["end"] = time_end + query["lifeTime"] = life_time + elif args.get("interval") not in (None, ""): + query["lifeTime"] = {"interval": args["interval"]} + + conditions = build_conditions(args, {"msg": "msg"}) + if conditions: + query["conditions"] = conditions + add_paging(query, args, default_limit=50) + return ok(get_client().get("/rest/api/configurationlog", query=query)) + + +def waf_dashboard_stats(args: dict[str, Any]) -> ToolResult: + kind = args.get("kind") + mapping = { + "attack_source_ip": "/rest/api/waf_attack_source_client_ip", + "attack_source_country": "/rest/api/waf_attack_source_map", + "threat_category": "/rest/api/waf_protection_type", + "site_attack": "/rest/api/waf_site_attack", + } + if kind not in mapping: + raise WafApiError(f"kind must be one of {sorted(mapping)}") + interval = args.get("interval") + query = None + if interval: + query = {"conditions": [{"field": "interval", "value": interval}]} + return ok(get_client().get(mapping[kind], query=query)) + + +def waf_configfile_list(args: dict[str, Any]) -> ToolResult: + return ok(get_client().get("/rest/api/configfile")) + + +def waf_signature_status(args: dict[str, Any]) -> ToolResult: + query_status = bool(args.get("queryStatus", False)) + path = "/rest/api/signature?query=%7B%22conditions%22%3A%5B%7B%22field%22%3A%22index%22%2C%22value%22%3A9%7D%5D%7D" + if query_status: + path += "&queryStatus=1" + return ok(get_client().get(path)) + + +def waf_deploy_mode_get(args: dict[str, Any]) -> ToolResult: + return ok(get_client().get("/rest/api/waf_deploy_mode")) + + +def waf_license_get(args: dict[str, Any]) -> ToolResult: + return ok(get_client().get("/rest/api/licenseManagementAgent")) + + +def waf_custom_error_page_list(args: dict[str, Any]) -> ToolResult: + return ok(get_client().get("/rest/api/waf_custom_error_page")) + + +def waf_mgmt_image_get(args: dict[str, Any]) -> ToolResult: + version = require_int(args.get("version", 1), "version") + query = {"conditions": [{"field": "version", "value": version}]} + return ok(get_client().get("/rest/api/mgmt_image", query=query)) + + +def waf_disk_usage_get(args: dict[str, Any]) -> ToolResult: + return ok(get_client().get("/rest/api/disk_usage")) + + +def waf_capacity_get(args: dict[str, Any]) -> ToolResult: + return ok(get_client().get("/rest/api/capacity")) + + +def waf_logout(args: dict[str, Any]) -> ToolResult: + return ok(get_client().logout()) + + +def require_text(value: Any, name: str) -> str: + if value in (None, ""): + raise WafApiError(f"{name} is required") + text = str(value).strip() + if not text: + raise WafApiError(f"{name} is required") + if len(text) > 127: + raise WafApiError(f"{name} must be <= 127 characters") + return text + + +def require_uri_path(value: Any) -> str: + uri_path = require_text(value, "uri_path") + if not uri_path.startswith("/"): + raise WafApiError("uri_path must start with '/'") + if "://" in uri_path: + raise WafApiError("uri_path must be a path, not a full URL") + return uri_path + + +def require_policy_id(value: Any, name: str) -> str: + if value in (None, ""): + raise WafApiError(f"{name} is required") + policy_id = str(value).strip() + if not policy_id.isdigit(): + raise WafApiError(f"{name} must be a numeric id") + return policy_id + + +def require_status_code(value: Any) -> int: + status_code = require_int(value, "status_code") + allowed = {400, 403, 404, 405, 500, 501, 505} + if status_code not in allowed: + raise WafApiError(f"status_code must be one of {sorted(allowed)}") + return status_code + + +def require_flag(value: Any, name: str) -> int: + flag = require_int(value, name) + if flag not in {0, 1}: + raise WafApiError(f"{name} must be 0 or 1") + return flag + + +def first_present(args: dict[str, Any], *names: str) -> Any: + for name in names: + value = args.get(name) + if value not in (None, ""): + return value + return None + + +def require_choice_int(value: Any, name: str, allowed: set[int]) -> int: + number = require_int(value, name) + if number not in allowed: + raise WafApiError(f"{name} must be one of {sorted(allowed)}") + return number + + +def require_payload(value: Any, name: str = "body") -> Any: + if isinstance(value, str): + try: + value = json.loads(value) + except (TypeError, ValueError, json.JSONDecodeError) as exc: + raise WafApiError(f"{name} must be valid JSON") from exc + if not isinstance(value, (dict, list)): + raise WafApiError(f"{name} must be a JSON object or array") + return value + + +def optional_payload(value: Any, name: str = "body") -> Any: + if value in (None, ""): + return None + return require_payload(value, name) + + +def infer_ip_version(ip_start: str) -> int: + return 1 if ":" in ip_start else 0 + + +def infer_blacklist_type(content: str) -> int: + if "-" in content: + return 3 + if "/" in content: + return 4 + return 1 + + +def require_block_time(value: Any) -> int: + block_time = require_int(value, "block_time") + if block_time < 1 or block_time > 1440: + raise WafApiError("block_time must be between 1 and 1440 minutes") + return block_time + + +def blacklist_body(args: dict[str, Any], *, include_site: bool, include_is_permanent: bool) -> list[dict[str, Any]]: + content = require_text(first_present(args, "content", "ip"), "content") + list_type = require_choice_int(args.get("type", infer_blacklist_type(content)), "type", {1, 3, 4}) + entry: dict[str, Any] = {"type": list_type, "content": content} + if include_site: + entry["siteId"] = require_int(first_present(args, "siteId", "site_id"), "siteId") + if include_is_permanent: + is_permanent = str(require_flag(args.get("is_permanent", 1), "is_permanent")) + entry["is_permanent"] = is_permanent + if is_permanent == "0" and args.get("block_time") not in (None, ""): + entry["block_time"] = require_block_time(args["block_time"]) + if include_site: + return [{"siteId": entry.pop("siteId"), **entry}] + return [entry] + + +def whitelist_ip_entry(args: dict[str, Any], *, include_site: bool, for_delete: bool) -> dict[str, Any]: + ip_start = require_text(first_present(args, "ip_start", "ip"), "ip_start") + ip_ver = require_choice_int(args.get("ip_ver", infer_ip_version(ip_start)), "ip_ver", {0, 1}) + list_type = require_choice_int(args.get("type", 0), "type", {0, 1, 2}) + entry: dict[str, Any] = { + "ip_ver": str(ip_ver), + "type": str(list_type), + "ip_start": ip_start, + } + if list_type == 1: + default_netmask = 128 if ip_ver == 1 else 32 + entry["netmask"] = require_int(args.get("netmask", default_netmask), "netmask") + elif args.get("netmask") not in (None, ""): + entry["netmask"] = require_int(args["netmask"], "netmask") + + if list_type == 2: + entry["ip_end"] = require_text(args.get("ip_end"), "ip_end") + elif args.get("ip_end") not in (None, ""): + entry["ip_end"] = require_text(args.get("ip_end"), "ip_end") + + if for_delete: + entry.setdefault("ip_end", "::" if ip_ver == 1 else "0") + entry.setdefault("netmask", 128 if ip_ver == 1 else 32) + elif args.get("desc") not in (None, ""): + entry["desc"] = str(args["desc"]) + + if not include_site: + return entry + site_id = require_int(first_present(args, "id", "site_id", "siteId"), "id") + return {"id": site_id, "ip_whitelist": entry} + + +def whitelist_body(args: dict[str, Any], *, for_delete: bool) -> dict[str, Any]: + return whitelist_ip_entry(args, include_site=True, for_delete=for_delete) + + +def global_whitelist_body(args: dict[str, Any], *, for_delete: bool) -> list[dict[str, Any]]: + return [whitelist_ip_entry(args, include_site=False, for_delete=for_delete)] + + +def build_deny_uri_policy_body( + args: dict[str, Any], + policy_name: str, + uri_path: str, + status_code: int, +) -> list[dict[str, Any]]: + operator = str(args.get("operator", "location")) + if operator not in {"location", "rx"}: + raise WafApiError("operator must be 'location' or 'rx'") + body: dict[str, Any] = { + "name": policy_name, + "action": "deny", + "status_code": status_code, + "description": str(args.get("description") or "Created by 360_waf Flocks integration"), + "capture_pkt": require_flag(args.get("capture_pkt", 1), "capture_pkt"), + "log": require_flag(args.get("log", 1), "log"), + "uri_path_list": { + "enable": 1, + "operator": operator, + "is_negative": 0, + "encode": str(args.get("encode", "UTF-8")), + "no_case": require_flag(args.get("no_case", 1), "no_case"), + "uri_path": [{"pattern": uri_path}], + }, + } + http_method = args.get("http_method") + if http_method not in (None, ""): + body["http_method"] = {"enable": 1, "is_negative": 0, "pattern": str(http_method).lower()} + return [body] + + +def result_as_list(data: dict[str, Any]) -> list[Any]: + result = data.get("result") + if isinstance(result, list): + return result + if isinstance(result, dict): + return [result] + return [] + + +def ensure_api_success(label: str, data: dict[str, Any]) -> None: + if data.get("success") is not True: + raise WafApiError(f"{label} failed: {json.dumps(data, ensure_ascii=False)}") + + +def create_deny_uri_policy( + client_obj: WafClient, body: list[dict[str, Any]], policy_name: str +) -> dict[str, Any]: + post_result = client_obj.request("POST", "/rest/api/wafacpolicy", body=body) + ensure_api_success("POST /rest/api/wafacpolicy", post_result) + result_items = result_as_list(post_result) + policy = result_items[0] if result_items else None + if not isinstance(policy, dict) or not policy.get("id"): + policy = find_ac_policy_by_name(client_obj, policy_name) + if not isinstance(policy, dict) or not policy.get("id"): + raise WafApiError("POST /rest/api/wafacpolicy returned success but no created policy id was found") + policy_id = require_policy_id(policy.get("id"), "created policy id") + verify = client_obj.get("/rest/api/wafacpolicy", query={"conditions": [{"field": "id", "value": policy_id}]}) + return {"policy_id": policy_id, "policy": policy, "post": post_result, "verify": verify} + + +def resolve_site(client_obj: WafClient, args: dict[str, Any]) -> dict[str, Any]: + conditions: list[dict[str, Any]] = [] + if args.get("site_id") not in (None, ""): + conditions.append({"field": "id", "value": str(args["site_id"])}) + else: + site_name = str(args.get("site_name") or "default") + conditions.append({"field": "name", "operator": 0, "value": site_name}) + data = client_obj.get("/rest/api/website", query={"conditions": conditions}) + result_items = result_as_list(data) + if not result_items: + raise WafApiError(f"site not found for conditions: {conditions}") + site = result_items[0] + if not isinstance(site, dict) or not site.get("id"): + raise WafApiError("site lookup did not return a valid site object") + return site + + +def bind_ac_policy_to_site( + client_obj: WafClient, + site: dict[str, Any], + policy_id: str, + position: str = "append", +) -> dict[str, Any]: + before_ids = site_ac_policy_ids(site) + after_ids = insert_policy_id(before_ids, policy_id, position) + if after_ids == before_ids: + return {"changed": False, "before": before_ids, "after": after_ids, "verify": site} + put_result = update_site_ac_policy(client_obj, site, after_ids) + return {"changed": True, "before": before_ids, "after": after_ids, **put_result} + + +def unbind_ac_policy_from_site(client_obj: WafClient, site: dict[str, Any], policy_id: str) -> dict[str, Any]: + before_ids = site_ac_policy_ids(site) + after_ids = [item for item in before_ids if item != policy_id] + if after_ids == before_ids: + return {"changed": False, "before": before_ids, "after": after_ids, "verify": site} + put_result = update_site_ac_policy(client_obj, site, after_ids) + return {"changed": True, "before": before_ids, "after": after_ids, **put_result} + + +def update_site_ac_policy(client_obj: WafClient, site: dict[str, Any], policy_ids: list[str]) -> dict[str, Any]: + site_id = str(site.get("id")) + body = { + "id": site_id, + "name": str(site.get("name") or ""), + "ac_policy": ";".join(policy_ids), + } + put_result = client_obj.request("PUT", "/rest/api/website", body=body) + ensure_api_success("PUT /rest/api/website", put_result) + verify = client_obj.get("/rest/api/website", query={"conditions": [{"field": "id", "value": site_id}]}) + return {"body": body, "put": put_result, "verify": verify} + + +def site_ac_policy_ids(site: dict[str, Any]) -> list[str]: + raw = site.get("ac_policy") + ids: list[str] = [] + if raw in (None, ""): + return ids + if isinstance(raw, str): + for item in raw.replace(",", ";").split(";"): + add_policy_id(ids, item) + elif isinstance(raw, list): + for item in raw: + if isinstance(item, dict): + add_policy_id(ids, item.get("id")) + else: + add_policy_id(ids, item) + else: + add_policy_id(ids, raw) + return ids + + +def add_policy_id(ids: list[str], value: Any) -> None: + if value in (None, ""): + return + policy_id = str(value).strip() + if policy_id and policy_id not in ids: + ids.append(policy_id) + + +def insert_policy_id(values: list[str], value: str, position: str) -> list[str]: + if position not in {"append", "prepend"}: + raise WafApiError("position must be 'append' or 'prepend'") + output = [item for item in values if item != value] + if position == "prepend": + output.insert(0, value) + else: + output.append(value) + return output + + +def find_ac_policy_by_name(client_obj: WafClient, policy_name: str) -> Optional[dict[str, Any]]: + data = client_obj.get("/rest/api/wafacpolicy") + for item in result_as_list(data): + if isinstance(item, dict) and item.get("name") == policy_name: + return item + return None + + +def policy_is_deny_uri(policy: dict[str, Any], uri_path: str) -> bool: + if str(policy.get("action", "")).lower() != "deny": + return False + uri_path_list = policy.get("uri_path_list") + if not isinstance(uri_path_list, dict): + return False + if str(uri_path_list.get("enable", "0")) != "1": + return False + uri_entries = uri_path_list.get("uri_path") + if isinstance(uri_entries, dict): + return uri_entries.get("pattern") == uri_path + if isinstance(uri_entries, list): + return any(isinstance(item, dict) and item.get("pattern") == uri_path for item in uri_entries) + return False + + +def delete_ac_policy(client_obj: WafClient, policy_id: str) -> dict[str, Any]: + data = client_obj.request("DELETE", "/rest/api/wafacpolicy", body=[{"id": policy_id}]) + ensure_api_success("DELETE /rest/api/wafacpolicy", data) + return data + + +def default_deny_policy_name(site_name: str, uri_path: str) -> str: + raw = f"{site_name}_{uri_path.strip('/') or 'root'}".lower() + chars = [] + for char in raw: + if char.isalnum(): + chars.append(char) + else: + chars.append("_") + compact = "_".join(part for part in "".join(chars).split("_") if part) + return ("mcp_deny_" + compact)[:127] + + +def waf_ac_policy_create_deny_uri(args: dict[str, Any]) -> ToolResult: + client_obj = get_client() + policy_name = require_text(args.get("name"), "name") + uri_path = require_uri_path(args.get("uri_path")) + status_code = require_status_code(args.get("status_code", 403)) + body = build_deny_uri_policy_body(args, policy_name, uri_path, status_code) + created = create_deny_uri_policy(client_obj, body, policy_name) + return ok(created) + + +def waf_site_bind_ac_policy(args: dict[str, Any]) -> ToolResult: + client_obj = get_client() + policy_id = require_policy_id(args.get("policy_id"), "policy_id") + site = resolve_site(client_obj, args) + position = str(args.get("position") or "append") + return ok(bind_ac_policy_to_site(client_obj, site, policy_id, position=position)) + + +def waf_site_unbind_ac_policy(args: dict[str, Any]) -> ToolResult: + client_obj = get_client() + policy_id = require_policy_id(args.get("policy_id"), "policy_id") + site = resolve_site(client_obj, args) + return ok(unbind_ac_policy_from_site(client_obj, site, policy_id)) + + +def waf_ac_policy_delete(args: dict[str, Any]) -> ToolResult: + client_obj = get_client() + policy_id = require_policy_id(args.get("policy_id"), "policy_id") + if policy_id == "1": + raise WafApiError("refusing to delete built-in access-control policy id=1") + delete_result = delete_ac_policy(client_obj, policy_id) + verify = client_obj.get("/rest/api/wafacpolicy", query={"conditions": [{"field": "id", "value": policy_id}]}) + return ok({"policy_id": policy_id, "delete": delete_result, "verify": verify}) + + +def waf_blacklist_create(args: dict[str, Any]) -> ToolResult: + body = blacklist_body(args, include_site=True, include_is_permanent=True) + data = get_client().request("POST", "/rest/api/blacklist", body=body) + ensure_api_success("POST /rest/api/blacklist", data) + return ok(data) + + +def waf_blacklist_delete(args: dict[str, Any]) -> ToolResult: + body = blacklist_body(args, include_site=True, include_is_permanent=False) + data = get_client().request("DELETE", "/rest/api/blacklist", body=body) + ensure_api_success("DELETE /rest/api/blacklist", data) + return ok(data) + + +def waf_site_global_blacklist_create(args: dict[str, Any]) -> ToolResult: + body = blacklist_body(args, include_site=False, include_is_permanent=True) + data = get_client().request("POST", "/rest/api/site_global_blacklist", body=body) + ensure_api_success("POST /rest/api/site_global_blacklist", data) + return ok(data) + + +def waf_site_global_blacklist_delete(args: dict[str, Any]) -> ToolResult: + body = blacklist_body(args, include_site=False, include_is_permanent=False) + data = get_client().request("DELETE", "/rest/api/site_global_blacklist", body=body) + ensure_api_success("DELETE /rest/api/site_global_blacklist", data) + return ok(data) + + +def waf_whitelist_create(args: dict[str, Any]) -> ToolResult: + body = whitelist_body(args, for_delete=False) + data = get_client().request("POST", "/rest/api/whitelist", body=body) + ensure_api_success("POST /rest/api/whitelist", data) + return ok(data) + + +def waf_whitelist_delete(args: dict[str, Any]) -> ToolResult: + body = whitelist_body(args, for_delete=True) + data = get_client().request("DELETE", "/rest/api/whitelist", body=body) + ensure_api_success("DELETE /rest/api/whitelist", data) + return ok(data) + + +def waf_site_global_whitelist_create(args: dict[str, Any]) -> ToolResult: + body = global_whitelist_body(args, for_delete=False) + data = get_client().request("POST", "/rest/api/site_global_whitelist", body=body) + ensure_api_success("POST /rest/api/site_global_whitelist", data) + return ok(data) + + +def waf_site_global_whitelist_delete(args: dict[str, Any]) -> ToolResult: + body = global_whitelist_body(args, for_delete=True) + data = get_client().request("DELETE", "/rest/api/site_global_whitelist", body=body) + ensure_api_success("DELETE /rest/api/site_global_whitelist", data) + return ok(data) + + +def waf_exception_rule_create(args: dict[str, Any]) -> ToolResult: + body = require_payload(args.get("body")) + data = get_client().request("POST", "/rest/api/exceptionlist", body=body) + ensure_api_success("POST /rest/api/exceptionlist", data) + return ok(data) + + +def waf_exception_rule_update(args: dict[str, Any]) -> ToolResult: + body = require_payload(args.get("body")) + data = get_client().request("PUT", "/rest/api/exceptionlist", body=body) + ensure_api_success("PUT /rest/api/exceptionlist", data) + return ok(data) + + +def waf_exception_rule_delete(args: dict[str, Any]) -> ToolResult: + body = require_payload(args.get("body")) + data = get_client().request("DELETE", "/rest/api/exceptionlist", body=body) + ensure_api_success("DELETE /rest/api/exceptionlist", data) + return ok(data) + + +def waf_uri_block_on_site(args: dict[str, Any]) -> ToolResult: + client_obj = get_client() + uri_path = require_uri_path(args.get("uri_path")) + site = resolve_site(client_obj, args) + site_name = str(site.get("name") or args.get("site_name") or "site") + policy_name = str(args.get("policy_name") or default_deny_policy_name(site_name, uri_path)) + status_code = require_status_code(args.get("status_code", 403)) + reuse_existing = bool(args.get("reuse_existing", True)) + + existing = find_ac_policy_by_name(client_obj, policy_name) if reuse_existing else None + if existing is not None and not policy_is_deny_uri(existing, uri_path): + raise WafApiError(f"access-control policy named {policy_name!r} exists but is not a deny policy for {uri_path}") + created_new = existing is None + if existing is None: + body = build_deny_uri_policy_body(args, policy_name, uri_path, status_code) + created = create_deny_uri_policy(client_obj, body, policy_name) + policy = created["policy"] + else: + created = {"policy_id": str(existing["id"]), "policy": existing, "post": None, "verify": None} + policy = existing + + policy_id = require_policy_id(policy.get("id"), "created policy id") + try: + bind_result = bind_ac_policy_to_site(client_obj, site, policy_id, position="prepend") + except Exception: + if created_new: + try: + delete_ac_policy(client_obj, policy_id) + except Exception: + pass + raise + + return ok( + { + "site": {"id": str(site.get("id")), "name": site.get("name")}, + "uri_path": uri_path, + "policy_id": policy_id, + "policy_name": policy_name, + "created_new_policy": created_new, + "create": created, + "bind": bind_result, + } + ) + + +def waf_uri_unblock_on_site(args: dict[str, Any]) -> ToolResult: + client_obj = get_client() + site = resolve_site(client_obj, args) + policy_id = args.get("policy_id") + if policy_id in (None, ""): + uri_path = require_uri_path(args.get("uri_path")) + site_name = str(site.get("name") or args.get("site_name") or "site") + policy_name = str(args.get("policy_name") or default_deny_policy_name(site_name, uri_path)) + policy = find_ac_policy_by_name(client_obj, policy_name) + if policy is None: + raise WafApiError(f"access-control policy not found by name: {policy_name}") + if not policy_is_deny_uri(policy, uri_path): + raise WafApiError(f"access-control policy named {policy_name!r} is not a deny policy for {uri_path}") + policy_id = policy.get("id") + policy_id = require_policy_id(policy_id, "policy_id") + unbind_result = unbind_ac_policy_from_site(client_obj, site, policy_id) + delete_policy = bool(args.get("delete_policy", True)) + delete_result = None + if delete_policy: + if policy_id == "1": + raise WafApiError("refusing to delete built-in access-control policy id=1") + delete_result = delete_ac_policy(client_obj, policy_id) + verify = client_obj.get("/rest/api/wafacpolicy", query={"conditions": [{"field": "id", "value": policy_id}]}) + return ok({"policy_id": policy_id, "unbind": unbind_result, "delete": delete_result, "verify_policy": verify}) + + +def waf_api_catalog(args: dict[str, Any]) -> ToolResult: + return ok( + { + "documented_rest_api_resources": DOCUMENTED_API_METHODS, + "covered_by": { + "GET": "waf_call_raw_readonly or waf_call_api", + "POST_PUT_DELETE": "waf_call_mutation or waf_call_api", + }, + "file_upload_endpoints": DOCUMENTED_FILE_ENDPOINTS, + "file_tools": { + "POST rest/file/...": "waf_file_upload", + "DELETE rest/file?fileName=tmp": "waf_file_request", + "GET /download/...": "waf_download_file", + }, + "specialized_mutation_tools": { + "deny URI policy": "waf_ac_policy_create_deny_uri", + "bind access-control policy to site": "waf_site_bind_ac_policy", + "unbind access-control policy from site": "waf_site_unbind_ac_policy", + "delete access-control policy": "waf_ac_policy_delete", + "block URI on site in one step": "waf_uri_block_on_site", + "unblock URI on site in one step": "waf_uri_unblock_on_site", + "site blacklist": "waf_blacklist_create / waf_blacklist_delete", + "global blacklist": "waf_site_global_blacklist_create / waf_site_global_blacklist_delete", + "site whitelist": "waf_whitelist_create / waf_whitelist_delete", + "global whitelist": "waf_site_global_whitelist_create / waf_site_global_whitelist_delete", + "exception rules": "waf_exception_rule_create / waf_exception_rule_update / waf_exception_rule_delete", + }, + } + ) + + +def waf_call_raw_readonly(args: dict[str, Any]) -> ToolResult: + path = args.get("path") + if not path: + raise WafApiError("path is required") + api_path = normalize_api_path(str(path)) + validate_documented_api("GET", api_path) + query = args.get("query") + if query is not None and not isinstance(query, dict): + raise WafApiError("query must be an object") + return ok(get_client().call_readonly(api_path, query=query)) + + +def waf_call_mutation(args: dict[str, Any]) -> ToolResult: + method = str(args.get("method", "")).upper() + path = normalize_api_path(str(args.get("path", ""))) + if method not in {"POST", "PUT", "DELETE"}: + raise WafApiError("method must be POST, PUT, or DELETE") + validate_documented_api(method, path) + reject_blocked_device_state_mutation(method, path) + query = args.get("query") + if query is not None and not isinstance(query, dict): + raise WafApiError("query must be an object") + body = optional_payload(args.get("body")) + return ok(get_client().request(method, path, query=query, body=body)) + + +def waf_call_api(args: dict[str, Any]) -> ToolResult: + method = str(args.get("method", "GET")).upper() + path = normalize_api_path(str(args.get("path", ""))) + if method not in {"GET", "POST", "PUT", "DELETE"}: + raise WafApiError("method must be GET, POST, PUT, or DELETE") + validate_documented_api(method, path) + query = args.get("query") + if query is not None and not isinstance(query, dict): + raise WafApiError("query must be an object") + if method == "GET": + return ok(get_client().get(path, query=query)) + return waf_call_mutation(args) + + +def waf_file_upload(args: dict[str, Any]) -> ToolResult: + path = normalize_file_path(str(args.get("path", ""))) + validate_documented_file_api("POST", path) + reject_blocked_file_mutation("POST", path) + file_path = str(args.get("file_path", "")) + fields = args.get("fields") or {} + if not file_path: + raise WafApiError("file_path is required") + if not isinstance(fields, dict): + raise WafApiError("fields must be an object") + return ok(get_client().upload_file(path, file_path, fields=fields)) + + +def waf_file_request(args: dict[str, Any]) -> ToolResult: + method = str(args.get("method", "")).upper() + path = normalize_file_path(str(args.get("path", ""))) + validate_documented_file_api(method, path) + reject_blocked_file_mutation(method, path) + return ok(get_client().file_request(method, path)) + + +def waf_download_file(args: dict[str, Any]) -> ToolResult: + path = normalize_download_path(str(args.get("path", ""))) + save_path = str(args.get("save_path", "")) + if not save_path: + raise WafApiError("save_path is required") + return ok(get_client().download_file(path, save_path)) + + +def normalize_api_path(path: str) -> str: + if path.startswith("rest/api/"): + path = "/" + path + if not path.startswith("/rest/api/"): + raise WafApiError("path must start with rest/api/ or /rest/api/") + return path + + +def normalize_file_path(path: str) -> str: + if path.startswith("rest/file"): + path = "/" + path + if not path.startswith("/rest/file"): + raise WafApiError("path must start with rest/file or /rest/file") + return path + + +def normalize_download_path(path: str) -> str: + if path.startswith("download/"): + path = "/" + path + if not path.startswith("/download/"): + raise WafApiError("path must start with download/ or /download/") + return path + + +def reject_blocked_device_state_mutation(method: str, path: str) -> None: + resource = path.split("?", 1)[0] + if method.upper() in BLOCKED_DEVICE_STATE_MUTATIONS.get(resource, set()): + raise WafApiError( + "360 WAF integration does not support modifying WAF device state " + f"through raw mutation tools: {method.upper()} {resource}" + ) + + +def reject_blocked_file_mutation(method: str, path: str) -> None: + if path == "/rest/file" and method.upper() == "DELETE": + path = "/rest/file?fileName=tmp" + if method.upper() in BLOCKED_FILE_MUTATIONS.get(path, set()): + raise WafApiError( + "360 WAF integration does not support WAF upgrade or import file operations: " + f"{method.upper()} {path}" + ) + + +def validate_documented_api(method: str, path: str) -> None: + resource = path.split("?", 1)[0] + methods = DOCUMENTED_API_METHODS.get(resource) + if methods is None: + raise WafApiError(f"{resource} is not listed in the local WAF API document") + if method.upper() not in methods: + raise WafApiError(f"{method.upper()} {resource} is not listed in the local WAF API document") + + +def validate_documented_file_api(method: str, path: str) -> None: + if path == "/rest/file" and method.upper() == "DELETE": + path = "/rest/file?fileName=tmp" + methods = DOCUMENTED_FILE_ENDPOINTS.get(path) + if methods is None: + raise WafApiError(f"{path} is not listed as a WAF file helper endpoint in the local document") + if method.upper() not in methods: + raise WafApiError(f"{method.upper()} {path} is not listed in the local WAF API document") + + +_ACTION_MAP = { + "waf_check_login": waf_check_login, + "waf_system_info_get": waf_system_info_get, + "waf_site_list": waf_site_list, + "waf_policy_list": waf_policy_list, + "waf_ac_policy_list": waf_ac_policy_list, + "waf_ac_policy_create_deny_uri": waf_ac_policy_create_deny_uri, + "waf_site_bind_ac_policy": waf_site_bind_ac_policy, + "waf_site_unbind_ac_policy": waf_site_unbind_ac_policy, + "waf_ac_policy_delete": waf_ac_policy_delete, + "waf_blacklist_create": waf_blacklist_create, + "waf_blacklist_delete": waf_blacklist_delete, + "waf_site_global_blacklist_create": waf_site_global_blacklist_create, + "waf_site_global_blacklist_delete": waf_site_global_blacklist_delete, + "waf_whitelist_create": waf_whitelist_create, + "waf_whitelist_delete": waf_whitelist_delete, + "waf_site_global_whitelist_create": waf_site_global_whitelist_create, + "waf_site_global_whitelist_delete": waf_site_global_whitelist_delete, + "waf_exception_rule_create": waf_exception_rule_create, + "waf_exception_rule_update": waf_exception_rule_update, + "waf_exception_rule_delete": waf_exception_rule_delete, + "waf_uri_block_on_site": waf_uri_block_on_site, + "waf_uri_unblock_on_site": waf_uri_unblock_on_site, + "waf_security_log_search": waf_security_log_search, + "waf_configuration_log_search": waf_configuration_log_search, + "waf_dashboard_stats": waf_dashboard_stats, + "waf_interface_list": waf_interface_list, + "waf_zone_list": waf_zone_list, + "waf_blacklist_list": waf_blacklist_list, + "waf_whitelist_list": waf_whitelist_list, + "waf_whitelist_check_ip": waf_whitelist_check_ip, + "waf_configfile_list": waf_configfile_list, + "waf_signature_status": waf_signature_status, + "waf_deploy_mode_get": waf_deploy_mode_get, + "waf_license_get": waf_license_get, + "waf_custom_error_page_list": waf_custom_error_page_list, + "waf_mgmt_image_get": waf_mgmt_image_get, + "waf_disk_usage_get": waf_disk_usage_get, + "waf_capacity_get": waf_capacity_get, + "waf_api_catalog": waf_api_catalog, + "waf_call_raw_readonly": waf_call_raw_readonly, + "waf_call_mutation": waf_call_mutation, + "waf_call_api": waf_call_api, + "waf_file_upload": waf_file_upload, + "waf_file_request": waf_file_request, + "waf_download_file": waf_download_file, + "waf_logout": waf_logout, +} + +GROUP_ACTIONS: dict[str, set[str]] = { + "system": { + "waf_check_login", + "waf_system_info_get", + "waf_interface_list", + "waf_zone_list", + "waf_configfile_list", + "waf_signature_status", + "waf_deploy_mode_get", + "waf_license_get", + "waf_custom_error_page_list", + "waf_mgmt_image_get", + "waf_disk_usage_get", + "waf_capacity_get", + "waf_logout", + }, + "site": { + "waf_site_list", + "waf_blacklist_list", + "waf_whitelist_list", + "waf_whitelist_check_ip", + }, + "policy_ops": { + "waf_policy_list", + "waf_ac_policy_list", + "waf_ac_policy_create_deny_uri", + "waf_site_bind_ac_policy", + "waf_site_unbind_ac_policy", + "waf_ac_policy_delete", + "waf_blacklist_create", + "waf_blacklist_delete", + "waf_site_global_blacklist_create", + "waf_site_global_blacklist_delete", + "waf_whitelist_create", + "waf_whitelist_delete", + "waf_site_global_whitelist_create", + "waf_site_global_whitelist_delete", + "waf_exception_rule_create", + "waf_exception_rule_update", + "waf_exception_rule_delete", + "waf_uri_block_on_site", + "waf_uri_unblock_on_site", + }, + "observability": { + "waf_security_log_search", + "waf_configuration_log_search", + "waf_dashboard_stats", + }, + "api_readonly": { + "waf_api_catalog", + "waf_call_raw_readonly", + }, + "api_mutation": { + "waf_call_mutation", + "waf_call_api", + }, + "file_ops": { + "waf_file_upload", + "waf_file_request", + "waf_download_file", + }, +} + +_CONNECTIVITY_TEST_ACTIONS = { + "system": "waf_check_login", + "site": "waf_site_list", + "policy_ops": "waf_policy_list", + "observability": "waf_security_log_search", + "api_readonly": "waf_api_catalog", +} + + +async def unified_ops(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + del ctx + handler = _ACTION_MAP.get(action) + if handler is None: + available = ", ".join(sorted(_ACTION_MAP)) + return ToolResult(success=False, error=f"Unknown action: {action}. Available: {available}") + try: + return await asyncio.to_thread(handler, params) + except WafApiError as exc: + return ToolResult( + success=False, + error=str(exc), + metadata={"source": "360 WAF", "version": PRODUCT_VERSION, "action": action}, + ) + except Exception as exc: + return ToolResult( + success=False, + error=f"Unexpected 360 WAF error: {exc}", + metadata={"source": "360 WAF", "version": PRODUCT_VERSION, "action": action}, + ) + + +async def _dispatch_group(ctx: ToolContext, group: str, action: str, **params: Any) -> ToolResult: + if action == "test": + test_action = _CONNECTIVITY_TEST_ACTIONS.get(group) + if test_action: + return await unified_ops(ctx, action=test_action, **params) + return ToolResult( + success=False, + error=( + f"360 WAF group {group} does not define a zero-argument " + "connectivity probe; pass an explicit action and parameters." + ), + ) + if action not in GROUP_ACTIONS[group]: + available = ", ".join(sorted(GROUP_ACTIONS[group])) + return ToolResult(success=False, error=f"Unsupported {group} action: {action}. Available: {available}") + return await unified_ops(ctx, action=action, **params) + + +async def system(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "system", action, **params) + + +async def site(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "site", action, **params) + + +async def policy_ops(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "policy_ops", action, **params) + + +async def observability(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "observability", action, **params) + + +async def api_readonly(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "api_readonly", action, **params) + + +async def api_mutation(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "api_mutation", action, **params) + + +async def file_ops(ctx: ToolContext, action: str, **params: Any) -> ToolResult: + return await _dispatch_group(ctx, "file_ops", action, **params) diff --git a/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_api_mutation.yaml b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_api_mutation.yaml new file mode 100644 index 000000000..e76b648fc --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_api_mutation.yaml @@ -0,0 +1,37 @@ +name: 360_waf_api_mutation +description: > + 360 WAF v5.5 官方 REST 变更调用工具。该工具通过 Flocks requires_confirmation 触发确认; + 重启、升级、配置导入、部署模式、License、网络接口和安全域等 WAF 本体修改接口会直接拒绝。 +category: custom +enabled: true +requires_confirmation: true +provider: 360_waf +inputSchema: + type: object + properties: + action: + type: string + description: REST 变更操作名称。 + enum: + - waf_call_mutation + - waf_call_api + - test + method: + type: string + enum: [GET, POST, PUT, DELETE] + description: HTTP 方法。 + path: + type: string + description: 已收录的 /rest/api/... 路径。 + query: + type: object + description: 可选的 WAF 查询对象,会编码为 query= JSON。 + body: + type: string + description: POST、PUT 或 DELETE 调用使用的 JSON 请求体,传入对象或数组的 JSON 字符串。 + required: + - action +handler: + type: script + script_file: 360_waf.handler.py + function: api_mutation diff --git a/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_api_readonly.yaml b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_api_readonly.yaml new file mode 100644 index 000000000..0f38e282b --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_api_readonly.yaml @@ -0,0 +1,30 @@ +name: 360_waf_api_readonly +description: > + 360 WAF v5.5 官方 REST 只读调用工具。可用 waf_api_catalog 查看已收录接口, + 或用 waf_call_raw_readonly 调用已收录的 GET 接口。 +category: custom +enabled: true +requires_confirmation: false +provider: 360_waf +inputSchema: + type: object + properties: + action: + type: string + description: REST 只读操作名称。 + enum: + - waf_api_catalog + - waf_call_raw_readonly + - test + path: + type: string + description: 已收录的 /rest/api/... GET 路径。 + query: + type: object + description: 可选的 WAF 查询对象,会编码为 query= JSON。 + required: + - action +handler: + type: script + script_file: 360_waf.handler.py + function: api_readonly diff --git a/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_file.yaml b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_file.yaml new file mode 100644 index 000000000..40b9d8682 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_file.yaml @@ -0,0 +1,41 @@ +name: 360_waf_file +description: > + 360 WAF v5.5 文件操作工具。用于调用已收录的上传、文件请求和下载接口; + 上传和删除类操作通过 Flocks requires_confirmation 触发确认,升级、配置导入和临时文件删除相关接口会直接拒绝。 +category: custom +enabled: true +requires_confirmation: true +provider: 360_waf +inputSchema: + type: object + properties: + action: + type: string + description: 文件操作名称。 + enum: + - waf_file_upload + - waf_file_request + - waf_download_file + - test + method: + type: string + enum: [DELETE] + description: waf_file_request 使用的方法。 + path: + type: string + description: /rest/file... 上传或删除路径,或 /download/... 下载路径。 + file_path: + type: string + description: 需要上传的本地文件路径。 + fields: + type: object + description: 上传时附加的 multipart 表单字段。 + save_path: + type: string + description: 下载文件保存到本地的路径。 + required: + - action +handler: + type: script + script_file: 360_waf.handler.py + function: file_ops diff --git a/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_observability.yaml b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_observability.yaml new file mode 100644 index 000000000..d123a523c --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_observability.yaml @@ -0,0 +1,83 @@ +name: 360_waf_observability +description: > + 360 WAF v5.5 观测与日志工具。用于检索 Web 安全日志和查询仪表盘统计数据。 +category: custom +enabled: true +requires_confirmation: false +provider: 360_waf +inputSchema: + type: object + properties: + action: + type: string + description: 观测类操作名称。 + enum: + - waf_security_log_search + - waf_configuration_log_search + - waf_dashboard_stats + - test + interval: + type: string + enum: [realtime, hour, day, week, month] + description: 日志或仪表盘数据的时间范围。 + severity: + type: integer + description: 严重级别过滤,1 严重、2 高危、3 中危、4 低危。 + start: + type: integer + description: 分页起始偏移量。 + limit: + type: integer + description: 分页条数,最大 500。 + client_ip: + type: string + description: 客户端 IP 过滤条件。 + time_start: + type: string + description: 自定义开始时间,格式示例 2026/05/29 08:00:00。 + time_end: + type: string + description: 自定义结束时间,格式示例 2026/05/29 09:00:00。 + http_url: + type: string + description: Web 安全日志 URL 过滤条件。 + action_filter: + type: string + description: Web 安全日志防护动作过滤条件,例如 deny、pass、allow。 + msg: + type: string + description: 配置日志消息内容过滤条件。 + server_ip: + type: string + description: 服务器 IP 过滤条件。 + site_name: + type: string + description: 站点名称过滤条件。 + policy_name: + type: string + description: 策略名称过滤条件。 + domain_name: + type: string + description: 域名过滤条件。 + http_method: + type: string + description: HTTP 方法过滤条件。 + rule_id: + type: string + description: 规则 ID 过滤条件。 + protection_type: + type: string + description: 防护类型过滤条件。 + protection_sub_type: + type: string + description: 防护子类型过滤条件。 + kind: + type: string + enum: [attack_source_ip, attack_source_country, threat_category, site_attack] + description: 仪表盘统计类型。 + required: + - action +handler: + type: script + script_file: 360_waf.handler.py + function: observability diff --git a/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_policy_ops.yaml b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_policy_ops.yaml new file mode 100644 index 000000000..fa2aece1a --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_policy_ops.yaml @@ -0,0 +1,133 @@ +name: 360_waf_policy_ops +description: > + 360 WAF v5.5 安全策略与访问控制操作工具。只读操作用于查询 WAF 策略, + 变更操作通过 Flocks requires_confirmation 触发确认。 +category: custom +enabled: true +requires_confirmation: true +provider: 360_waf +inputSchema: + type: object + properties: + action: + type: string + description: 策略操作名称。 + enum: + - waf_policy_list + - waf_ac_policy_list + - waf_ac_policy_create_deny_uri + - waf_site_bind_ac_policy + - waf_site_unbind_ac_policy + - waf_ac_policy_delete + - waf_blacklist_create + - waf_blacklist_delete + - waf_site_global_blacklist_create + - waf_site_global_blacklist_delete + - waf_whitelist_create + - waf_whitelist_delete + - waf_site_global_whitelist_create + - waf_site_global_whitelist_delete + - waf_exception_rule_create + - waf_exception_rule_update + - waf_exception_rule_delete + - waf_uri_block_on_site + - waf_uri_unblock_on_site + - test + name: + type: string + description: 创建 URI 拦截策略时使用的访问控制策略名称。 + site_id: + type: string + description: 绑定、解绑、拦截和解除拦截操作使用的站点 ID。 + id: + type: integer + description: 站点白名单操作使用的站点 ID。 + siteId: + type: integer + description: 站点黑名单操作使用的站点 ID。 + site_name: + type: string + description: 绑定、解绑、拦截和解除拦截操作使用的站点名称,默认 default。 + policy_id: + type: string + description: 访问控制策略 ID。 + position: + type: string + description: 绑定策略时插入策略 ID 的位置。 + enum: [append, prepend] + uri_path: + type: string + description: 以斜杠开头的 URI 路径,用于 URI 拦截或解除拦截。 + policy_name: + type: string + description: URI 拦截或解除拦截使用的已有或自动生成策略名称。 + reuse_existing: + type: boolean + description: 拦截 URI 时是否复用已有匹配的拒绝策略。 + delete_policy: + type: boolean + description: 解除拦截并解绑后是否删除访问控制策略。 + type: + type: integer + description: 名单类型。黑名单中 1 单 IP、3 IP 范围、4 IP/掩码;白名单中 0 单 IP、1 IP/掩码、2 IP 范围。 + content: + type: string + description: 黑名单内容,通常是客户端 IP、IP 范围或 IP/掩码。 + is_permanent: + type: integer + enum: [0, 1] + description: 黑名单是否永久阻断,1 为永久,0 为临时。 + block_time: + type: integer + description: 临时黑名单阻断时间,单位分钟,范围 1-1440。 + ip_ver: + type: integer + enum: [0, 1] + description: 白名单 IP 版本,0 为 IPv4,1 为 IPv6。 + ip_start: + type: string + description: 白名单起始 IP 或单个 IP。 + ip_end: + type: string + description: 白名单 IP 范围结束 IP。 + netmask: + type: integer + description: 白名单 IP/掩码的掩码长度。 + desc: + type: string + description: 白名单描述。 + status_code: + type: integer + enum: [400, 403, 404, 405, 500, 501, 505] + description: URI 拒绝策略返回的 HTTP 状态码。 + description: + type: string + description: 策略描述。 + operator: + type: string + enum: [location, rx] + description: URI 匹配方式。 + no_case: + type: integer + enum: [0, 1] + description: URI 匹配是否忽略大小写。 + http_method: + type: string + description: 可选的 HTTP 方法过滤条件。 + capture_pkt: + type: integer + enum: [0, 1] + description: 是否抓包。 + log: + type: integer + enum: [0, 1] + description: 是否记录拒绝规则日志。 + body: + type: string + description: 规则例外 create/update/delete 使用的完整 JSON payload,传入对象或数组的 JSON 字符串。 + required: + - action +handler: + type: script + script_file: 360_waf.handler.py + function: policy_ops diff --git a/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_site.yaml b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_site.yaml new file mode 100644 index 000000000..4d1fafb03 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_site.yaml @@ -0,0 +1,42 @@ +name: 360_waf_site +description: > + 360 WAF v5.5 站点工具。通过 action 查询受保护站点列表、站点黑名单 + 或站点白名单记录。 +category: custom +enabled: true +requires_confirmation: false +provider: 360_waf +inputSchema: + type: object + properties: + action: + type: string + description: 站点类操作名称。 + enum: + - waf_site_list + - waf_blacklist_list + - waf_whitelist_list + - waf_whitelist_check_ip + - test + id: + type: string + description: 站点或白名单查询使用的站点 ID。 + name: + type: string + description: 站点名称模糊过滤条件。 + siteId: + type: integer + description: 查询黑名单使用的站点 ID。 + type: + type: integer + description: 名单类型,通常 1 表示 IP,2 表示 URL 或域名。 + enum: [1, 2] + ip: + type: string + description: 检查是否命中白名单的客户端 IP。 + required: + - action +handler: + type: script + script_file: 360_waf.handler.py + function: site diff --git a/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_system.yaml b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_system.yaml new file mode 100644 index 000000000..62d4c2ac1 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/360_waf_system.yaml @@ -0,0 +1,43 @@ +name: 360_waf_system +description: > + 360 WAF v5.5 系统与状态工具。通过 action 检查登录状态,查询系统信息、 + 网络接口、区域、配置文件、授权信息、部署模式、规则库状态、磁盘与容量, + 或执行退出登录。 +category: custom +enabled: true +requires_confirmation: false +provider: 360_waf +inputSchema: + type: object + properties: + action: + type: string + description: 系统类操作名称。 + enum: + - waf_check_login + - waf_system_info_get + - waf_interface_list + - waf_zone_list + - waf_configfile_list + - waf_signature_status + - waf_deploy_mode_get + - waf_license_get + - waf_custom_error_page_list + - waf_mgmt_image_get + - waf_disk_usage_get + - waf_capacity_get + - waf_logout + - test + queryStatus: + type: boolean + description: 查询规则库状态时是否包含查询状态。 + version: + type: integer + description: 管理镜像版本,1 表示当前版本,0 表示备份版本。 + enum: [0, 1] + required: + - action +handler: + type: script + script_file: 360_waf.handler.py + function: system diff --git a/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/_provider.yaml b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/_provider.yaml new file mode 100644 index 000000000..702af5554 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/_provider.yaml @@ -0,0 +1,48 @@ +name: "360_waf" +vendor: "360" +service_id: "360_waf" +version: "5.5" +integration_type: device +description: > + 360 WAF v5.5 设备接入。该插件将现有 waf-mcp v2 的 REST 能力适配为 + Flocks 设备工具,支持登录检测、站点与策略查询、访问控制操作、日志检索、 + 官方 REST 接口调用、文件上传下载以及退出登录。 +description_cn: > + 360 WAF v5.5 Web 应用防火墙设备接入。支持登录检测、系统状态查询、 + 站点与访问控制策略管理、黑白名单操作、日志检索、官方 REST 接口调用 + 以及文件上传下载。 +auth: + type: custom + flow: login_then_cookie + login_path: /rest/api/login +credential_fields: + - key: base_url + label: 设备地址 + storage: config + config_key: base_url + input_type: url + default: "https://YOUR_360_WAF_HOST" + required: true + - key: username + label: 用户名 + storage: secret + config_key: username + secret_id: 360_waf_v5_5_username + input_type: text + required: true + - key: password + label: 密码 + storage: secret + config_key: password + secret_id: 360_waf_v5_5_password + input_type: password + required: true +defaults: + base_url: "https://YOUR_360_WAF_HOST" + timeout: 30 + category: custom + product_version: "5.5" + verify_ssl: false +notes: | + 处理器通过 ConfigWriter 读取设备凭据,因此工具调用可以定位到具体的 + device_id。变更类工具通过 Flocks 官方 requires_confirmation 机制触发确认。 diff --git a/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/_test.yaml b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/_test.yaml new file mode 100644 index 000000000..6b6885751 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/360_waf_v5_5/_test.yaml @@ -0,0 +1,98 @@ +schema_version: 1 +provider: 360_waf + +connectivity: + tool: 360_waf_system + params: + action: waf_check_login + +fixtures: + 360_waf_system: + - label: "Check current login session" + tags: [smoke, auth] + params: + action: waf_check_login + assert: + success: true + - label: "Get system information" + tags: [smoke, system] + params: + action: waf_system_info_get + assert: + success: true + + 360_waf_site: + - label: "List protected sites" + tags: [smoke, site] + params: + action: waf_site_list + assert: + success: true + - label: "List IP blacklist entries for a site" + tags: [site, blacklist] + params: + action: waf_blacklist_list + siteId: 1 + type: 1 + + 360_waf_policy_ops: + - label: "List WAF policies" + tags: [smoke, policy] + params: + action: waf_policy_list + assert: + success: true + - label: "List access-control policies" + tags: [smoke, policy] + params: + action: waf_ac_policy_list + assert: + success: true + + 360_waf_observability: + - label: "Search web security logs" + tags: [smoke, logs] + params: + action: waf_security_log_search + interval: hour + start: 0 + limit: 20 + assert: + success: true + - label: "Get attack-source IP dashboard" + tags: [dashboard] + params: + action: waf_dashboard_stats + kind: attack_source_ip + interval: hour + + 360_waf_api_readonly: + - label: "Show documented API catalog" + tags: [smoke, api] + params: + action: waf_api_catalog + assert: + success: true + - label: "Call system info through read-only REST helper" + tags: [api] + params: + action: waf_call_raw_readonly + path: /rest/api/sysinfo + + 360_waf_api_mutation: + - label: "Call system info through unified REST helper" + tags: [api] + params: + action: waf_call_api + method: GET + path: /rest/api/sysinfo + assert: + success: true + + 360_waf_file: + - label: "Download a device file" + tags: [file] + params: + action: waf_download_file + path: /download/example + save_path: ./outputs/360_waf_download.bin diff --git a/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/_provider.yaml b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/_provider.yaml new file mode 100644 index 000000000..81d6305a0 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/_provider.yaml @@ -0,0 +1,71 @@ +name: huaweicloud_waf +vendor: huaweicloud +service_id: huaweicloud_waf_api +version: "39" +integration_type: device +description: > + Huawei Cloud Web Application Firewall (WAF) API service. Configure + Access Key (AK), Secret Key (SK), Region, and Project ID in the credentials + form. The API endpoint is constructed as + `https://waf.{region}.myhuaweicloud.com`. +description_cn: > + 华为云 Web 应用防火墙(WAF)API 服务。在配置页填写 Access Key(AK)、 + Secret Key(SK)、Region(地域,如 `cn-north-4`)和 Project ID(项目 ID)。 + 接口基础地址自动构建为 `https://waf.{region}.myhuaweicloud.com`。 + 也支持直接配置 Token(X-Auth-Token)方式认证。 +auth: + type: custom + secret: huaweicloud_waf_ak + secret_secret: huaweicloud_waf_sk +credential_fields: + - key: ak + label: Access Key (AK) + storage: secret + config_key: ak + secret_id: huaweicloud_waf_ak + input_type: password + required: false + - key: sk + label: Secret Key (SK) + storage: secret + config_key: sk + secret_id: huaweicloud_waf_sk + input_type: password + required: false + - key: token + label: X-Auth-Token(与 AK/SK 二选一) + storage: secret + config_key: token + secret_id: huaweicloud_waf_token + input_type: password + required: false + - key: region + label: Region + storage: config + config_key: region + input_type: text + default: "cn-north-4" + - key: project_id + label: Project ID + storage: config + config_key: project_id + input_type: text + required: true + - key: enterprise_project_id + label: Enterprise Project ID(可选) + storage: config + config_key: enterprise_project_id + input_type: text + required: false +defaults: + region: "cn-north-4" + timeout: 60 + category: custom + product_version: "39" +notes: | + 华为云 WAF API 支持两种认证方式: + 1. AK/SK 签名认证(推荐):使用 SDK-HMAC-SHA256 签名, + 配置 ak / sk 字段即可。 + 2. Token 认证:直接填写有效的 X-Auth-Token,Token 有效期 24 小时。 + AK/SK 认证优先级高于 Token;若两者均未配置则报错。 + project_id 为必填项,可在控制台"我的凭证 → 项目列表"中获取。 diff --git a/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/_test.yaml b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/_test.yaml new file mode 100644 index 000000000..1316a8958 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/_test.yaml @@ -0,0 +1,65 @@ +schema_version: 1 +provider: huaweicloud_waf_api + +# Service-level connectivity probe. +# `host_list` is a lightweight read-only call with page=1,pagesize=1. +connectivity: + tool: hw_waf_host + params: + action: host_list + page: 1 + pagesize: 1 + +# Tool-level test samples shown in the WebUI ToolDetailDrawer drop-down. +fixtures: + hw_waf_host: + - label: "List cloud mode protected domains (page 1)" + label_cn: "查询云模式防护域名列表(第 1 页)" + tags: [smoke] + params: + action: host_list + page: 1 + pagesize: 10 + assert: + success: true + + - label: "List dedicated mode protected domains (page 1)" + label_cn: "查询独享模式防护域名列表(第 1 页)" + tags: [smoke] + params: + action: premium_host_list + page: 1 + pagesize: 10 + assert: + success: true + + hw_waf_policy: + - label: "List protection policies (page 1)" + label_cn: "查询防护策略列表(第 1 页)" + tags: [smoke] + params: + action: policy_list + page: 1 + pagesize: 10 + assert: + success: true + + hw_waf_event: + - label: "List attack events (last hour)" + label_cn: "查询近 1 小时攻击事件" + tags: [smoke] + params: + action: event_list + page: 1 + pagesize: 10 + assert: + success: true + + hw_waf_overview: + - label: "Query security statistics" + label_cn: "查询安全统计概览" + tags: [smoke] + params: + action: overview_statistics + assert: + success: true diff --git a/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf.handler.py b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf.handler.py new file mode 100644 index 000000000..fd0e4409b --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf.handler.py @@ -0,0 +1,551 @@ +from __future__ import annotations + +import datetime +import hashlib +import hmac +import json +import os +from typing import Any, Optional +from urllib.parse import urlencode, quote + +import aiohttp + +from flocks.config.config_writer import ConfigWriter +from flocks.tool.registry import ToolContext, ToolResult + + +DEFAULT_REGION = "cn-north-4" +DEFAULT_TIMEOUT = 60 +SERVICE_ID = "huaweicloud_waf_api" +WAF_SERVICE_NAME = "waf" + + +def _get_secret_manager(): + from flocks.security import get_secret_manager + + return get_secret_manager() + + +def _resolve_ref(value: Any) -> Optional[str]: + if value is None: + return None + if not isinstance(value, str): + return str(value) + if value.startswith("{secret:") and value.endswith("}"): + return _get_secret_manager().get(value[len("{secret:"):-1]) + if value.startswith("{env:") and value.endswith("}"): + return os.getenv(value[len("{env:"):-1]) + return value + + +def _service_config() -> dict[str, Any]: + raw = ConfigWriter.get_api_service_raw(SERVICE_ID) + return raw if isinstance(raw, dict) else {} + + +def _resolve_verify_ssl(raw: dict[str, Any]) -> bool: + value = raw.get("verify_ssl") or raw.get("ssl_verify") + if value is None: + value = raw.get("custom_settings", {}).get("verify_ssl", True) + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return bool(value) + + +class _WAFConfig: + def __init__( + self, + ak: Optional[str], + sk: Optional[str], + token: Optional[str], + region: str, + project_id: str, + enterprise_project_id: Optional[str], + timeout: int, + verify_ssl: bool, + ) -> None: + self.ak = ak + self.sk = sk + self.token = token + self.region = region + self.project_id = project_id + self.enterprise_project_id = enterprise_project_id + self.timeout = timeout + self.verify_ssl = verify_ssl + + @property + def endpoint(self) -> str: + return f"https://waf.{self.region}.myhuaweicloud.com" + + def use_aksk(self) -> bool: + return bool(self.ak and self.sk) + + +def _load_config(param_epid: Optional[str] = None) -> _WAFConfig: + raw = _service_config() + secret_manager = _get_secret_manager() + + ak = ( + _resolve_ref(raw.get("ak")) + or secret_manager.get("huaweicloud_waf_ak") + or os.getenv("HUAWEICLOUD_WAF_AK") + ) + sk = ( + _resolve_ref(raw.get("sk")) + or secret_manager.get("huaweicloud_waf_sk") + or os.getenv("HUAWEICLOUD_WAF_SK") + ) + token = ( + _resolve_ref(raw.get("token")) + or secret_manager.get("huaweicloud_waf_token") + or os.getenv("HUAWEICLOUD_WAF_TOKEN") + ) + region = _resolve_ref(raw.get("region")) or DEFAULT_REGION + project_id = _resolve_ref(raw.get("project_id")) or os.getenv("HUAWEICLOUD_PROJECT_ID", "") + if not project_id: + raise ValueError( + "Huawei Cloud WAF: project_id is required. " + "Configure it in the huaweicloud_waf_api service settings." + ) + if not ak and not token: + raise ValueError( + "Huawei Cloud WAF credentials not found. Configure ak/sk or token " + "in the huaweicloud_waf_api service settings." + ) + enterprise_project_id = ( + param_epid + or _resolve_ref(raw.get("enterprise_project_id")) + or os.getenv("HUAWEICLOUD_ENTERPRISE_PROJECT_ID") + ) + timeout = int(raw.get("timeout", DEFAULT_TIMEOUT)) + return _WAFConfig( + ak=ak, + sk=sk, + token=token, + region=region, + project_id=project_id, + enterprise_project_id=enterprise_project_id, + timeout=timeout, + verify_ssl=_resolve_verify_ssl(raw), + ) + + +# --------------------------------------------------------------------------- +# Huawei Cloud SDK-HMAC-SHA256 signing (AK/SK auth) +# Reference: https://support.huaweicloud.com/api-dew/dew_02_0008.html +# --------------------------------------------------------------------------- + +def _sha256_hex(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + +def _hmac_sha256(key: bytes, msg: str) -> bytes: + return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest() + + +def _build_aksk_headers( + ak: str, + sk: str, + method: str, + host: str, + uri: str, + query_string: str, + body: bytes, +) -> dict[str, str]: + """ + Build Huawei Cloud SDK-HMAC-SHA256 signed headers. + + Authorization = SDK-HMAC-SHA256 Access={ak}, + SignedHeaders={signed_headers}, Signature={signature} + """ + now = datetime.datetime.utcnow() + x_sdk_date = now.strftime("%Y%m%dT%H%M%SZ") + date_stamp = now.strftime("%Y%m%d") + + payload_hash = _sha256_hex(body) + headers_to_sign = { + "content-type": "application/json;charset=utf8", + "host": host, + "x-sdk-date": x_sdk_date, + } + signed_headers = ";".join(sorted(headers_to_sign.keys())) + canonical_headers = "".join( + f"{k}:{v}\n" for k, v in sorted(headers_to_sign.items()) + ) + canonical_uri = uri if uri else "/" + canonical_request = "\n".join([ + method.upper(), + canonical_uri, + query_string, + canonical_headers, + signed_headers, + payload_hash, + ]) + credential_scope = f"{date_stamp}/{WAF_SERVICE_NAME}/sdk_request" + string_to_sign = "\n".join([ + "SDK-HMAC-SHA256", + x_sdk_date, + credential_scope, + _sha256_hex(canonical_request.encode("utf-8")), + ]) + signing_key = _hmac_sha256( + _hmac_sha256( + _hmac_sha256( + _hmac_sha256( + ("SDK" + sk).encode("utf-8"), + date_stamp, + ), + WAF_SERVICE_NAME, + ), + "sdk_request", + ), + "sdk_signing", + ) + signature = hmac.new(signing_key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest() + authorization = ( + f"SDK-HMAC-SHA256 Access={ak}, " + f"SignedHeaders={signed_headers}, " + f"Signature={signature}" + ) + return { + "Content-Type": "application/json;charset=utf8", + "X-Sdk-Date": x_sdk_date, + "Authorization": authorization, + } + + +async def _request( + cfg: _WAFConfig, + method: str, + path: str, + query: Optional[dict[str, Any]] = None, + body: Optional[dict[str, Any]] = None, +) -> ToolResult: + query = query or {} + if cfg.enterprise_project_id: + query.setdefault("enterprise_project_id", cfg.enterprise_project_id) + qs = urlencode({k: v for k, v in query.items() if v is not None}) + url = f"{cfg.endpoint}{path}" + if qs: + url = f"{url}?{qs}" + + body_bytes = json.dumps(body, ensure_ascii=False).encode("utf-8") if body else b"" + host = f"waf.{cfg.region}.myhuaweicloud.com" + + if cfg.use_aksk(): + headers = _build_aksk_headers( + cfg.ak, + cfg.sk, + method, + host, + path, + qs, + body_bytes, + ) + else: + headers = { + "Content-Type": "application/json;charset=utf8", + "X-Auth-Token": cfg.token, + } + + connector = aiohttp.TCPConnector(ssl=cfg.verify_ssl) + async with aiohttp.ClientSession(connector=connector) as session: + req_kwargs: dict[str, Any] = { + "headers": headers, + "timeout": aiohttp.ClientTimeout(total=cfg.timeout), + } + if body_bytes: + req_kwargs["data"] = body_bytes + async with session.request(method, url, **req_kwargs) as resp: + resp_text = await resp.text() + try: + resp_json = json.loads(resp_text) + except Exception: + resp_json = {"raw": resp_text} + if resp.status >= 400: + return ToolResult( + success=False, + data=resp_json, + error=f"HTTP {resp.status}: {resp_text[:300]}", + ) + return ToolResult(success=True, data=resp_json) + + +def _pick(params: dict[str, Any], *keys: str) -> dict[str, Any]: + return {k: params[k] for k in keys if k in params and params[k] is not None} + + +def _page_query(params: dict[str, Any]) -> dict[str, Any]: + q: dict[str, Any] = {} + if "page" in params and params["page"] is not None: + q["page"] = params["page"] + if "pagesize" in params and params["pagesize"] is not None: + q["pagesize"] = params["pagesize"] + return q + + +# --------------------------------------------------------------------------- +# Tool handler functions +# --------------------------------------------------------------------------- + + +async def host(params: dict[str, Any], ctx: ToolContext) -> ToolResult: + cfg = _load_config(params.get("enterprise_project_id")) + pid = cfg.project_id + action = params.get("action", "") + + if action == "host_list": + q = _page_query(params) + if params.get("hostname"): + q["hostname"] = params["hostname"] + if params.get("policyname"): + q["policyname"] = params["policyname"] + return await _request(cfg, "GET", f"/v1/{pid}/waf/instance", query=q) + + if action == "host_show": + iid = params["instance_id"] + return await _request(cfg, "GET", f"/v1/{pid}/waf/instance/{iid}") + + if action == "host_create": + body = _pick(params, "hostname", "proxy", "server", "certificateid") + return await _request(cfg, "POST", f"/v1/{pid}/waf/instance", body=body) + + if action == "host_update": + iid = params["instance_id"] + body = _pick(params, "proxy", "server", "certificateid", "protect_status") + return await _request(cfg, "PATCH", f"/v1/{pid}/waf/instance/{iid}", body=body) + + if action == "host_delete": + iid = params["instance_id"] + return await _request(cfg, "DELETE", f"/v1/{pid}/waf/instance/{iid}") + + if action == "host_update_protect_status": + iid = params["instance_id"] + body = _pick(params, "protect_status") + return await _request(cfg, "PUT", f"/v1/{pid}/waf/instance/{iid}/protect-status", body=body) + + if action == "premium_host_list": + q = _page_query(params) + if params.get("hostname"): + q["hostname"] = params["hostname"] + if params.get("policyname"): + q["policyname"] = params["policyname"] + return await _request(cfg, "GET", f"/v1/{pid}/waf/premium-host", query=q) + + if action == "premium_host_show": + iid = params["instance_id"] + return await _request(cfg, "GET", f"/v1/{pid}/waf/premium-host/{iid}") + + if action == "premium_host_create": + body = _pick(params, "hostname", "proxy", "server", "web_tag", "certificateid") + return await _request(cfg, "POST", f"/v1/{pid}/waf/premium-host", body=body) + + if action == "premium_host_update": + iid = params["instance_id"] + body = _pick(params, "proxy", "server", "protect_status", "certificateid") + return await _request(cfg, "PATCH", f"/v1/{pid}/waf/premium-host/{iid}", body=body) + + if action == "premium_host_delete": + iid = params["instance_id"] + return await _request(cfg, "DELETE", f"/v1/{pid}/waf/premium-host/{iid}") + + if action == "composite_host_list": + q = _page_query(params) + if params.get("hostname"): + q["hostname"] = params["hostname"] + return await _request(cfg, "GET", f"/v1/{pid}/composite-waf/host", query=q) + + return ToolResult(success=False, error=f"Unknown action: {action}") + + +async def policy(params: dict[str, Any], ctx: ToolContext) -> ToolResult: + cfg = _load_config(params.get("enterprise_project_id")) + pid = cfg.project_id + action = params.get("action", "") + + if action == "policy_list": + q = _page_query(params) + if params.get("name"): + q["name"] = params["name"] + return await _request(cfg, "GET", f"/v1/{pid}/waf/policy", query=q) + + if action == "policy_show": + pol_id = params["policy_id"] + return await _request(cfg, "GET", f"/v1/{pid}/waf/policy/{pol_id}") + + if action == "policy_create": + body = _pick(params, "name", "level", "full_detection") + return await _request(cfg, "POST", f"/v1/{pid}/waf/policy", body=body) + + if action == "policy_update": + pol_id = params["policy_id"] + body = _pick(params, "name", "level", "full_detection", "options") + return await _request(cfg, "PATCH", f"/v1/{pid}/waf/policy/{pol_id}", body=body) + + if action == "policy_delete": + pol_id = params["policy_id"] + return await _request(cfg, "DELETE", f"/v1/{pid}/waf/policy/{pol_id}") + + if action == "policy_update_hosts": + pol_id = params["policy_id"] + body = _pick(params, "hosts") + return await _request(cfg, "PUT", f"/v1/{pid}/waf/policy/{pol_id}/hosts", body=body) + + if action == "cc_rule_list": + pol_id = params["policy_id"] + q = _page_query(params) + return await _request(cfg, "GET", f"/v1/{pid}/waf/policy/{pol_id}/cc", query=q) + + if action == "cc_rule_create": + pol_id = params["policy_id"] + body = _pick(params, "url", "limit_num", "limit_period", "lock_time", "tag_type", "action") + return await _request(cfg, "POST", f"/v1/{pid}/waf/policy/{pol_id}/cc", body=body) + + if action == "cc_rule_delete": + pol_id = params["policy_id"] + rule_id = params["rule_id"] + return await _request(cfg, "DELETE", f"/v1/{pid}/waf/policy/{pol_id}/cc/{rule_id}") + + if action == "custom_rule_list": + pol_id = params["policy_id"] + q = _page_query(params) + return await _request(cfg, "GET", f"/v1/{pid}/waf/policy/{pol_id}/custom", query=q) + + if action == "custom_rule_create": + pol_id = params["policy_id"] + body = _pick(params, "name", "conditions", "action", "priority", "description") + return await _request(cfg, "POST", f"/v1/{pid}/waf/policy/{pol_id}/custom", body=body) + + if action == "custom_rule_delete": + pol_id = params["policy_id"] + rule_id = params["rule_id"] + return await _request(cfg, "DELETE", f"/v1/{pid}/waf/policy/{pol_id}/custom/{rule_id}") + + if action == "whiteblackip_rule_list": + pol_id = params["policy_id"] + q = _page_query(params) + return await _request(cfg, "GET", f"/v1/{pid}/waf/policy/{pol_id}/whiteblackip", query=q) + + if action == "whiteblackip_rule_create": + pol_id = params["policy_id"] + body = _pick(params, "addr", "white", "description") + return await _request(cfg, "POST", f"/v1/{pid}/waf/policy/{pol_id}/whiteblackip", body=body) + + if action == "whiteblackip_rule_delete": + pol_id = params["policy_id"] + rule_id = params["rule_id"] + return await _request(cfg, "DELETE", f"/v1/{pid}/waf/policy/{pol_id}/whiteblackip/{rule_id}") + + if action == "geoip_rule_list": + pol_id = params["policy_id"] + q = _page_query(params) + return await _request(cfg, "GET", f"/v1/{pid}/waf/policy/{pol_id}/geoip", query=q) + + return ToolResult(success=False, error=f"Unknown action: {action}") + + +async def event(params: dict[str, Any], ctx: ToolContext) -> ToolResult: + cfg = _load_config(params.get("enterprise_project_id")) + pid = cfg.project_id + action = params.get("action", "") + + if action == "event_list": + q = _page_query(params) + for k in ("from", "to", "hosts", "attacks", "action"): + if params.get(k) is not None: + q[k] = params[k] + return await _request(cfg, "GET", f"/v1/{pid}/waf/event/attack/logs", query=q) + + if action == "event_show": + eid = params["eventid"] + return await _request(cfg, "GET", f"/v1/{pid}/waf/event/attack/logs/{eid}") + + if action == "event_log_download": + q: dict[str, Any] = {} + for k in ("from", "to"): + if params.get(k) is not None: + q[k] = params[k] + return await _request(cfg, "GET", f"/v1/{pid}/waf/event/attack/log/download", query=q) + + if action == "event_export_job": + body = {} + for k in ("from", "to", "hosts", "attacks", "action"): + if params.get(k) is not None: + body[k] = params[k] + return await _request(cfg, "POST", f"/v1/{pid}/waf/event/attack/log/job", body=body) + + if action == "threat_distribution": + q = {} + for k in ("from", "to", "hosts"): + if params.get(k) is not None: + q[k] = params[k] + return await _request(cfg, "GET", f"/v1/{pid}/waf/overviews/attack/types", query=q) + + if action == "top_url": + q = {} + for k in ("from", "to", "hosts", "top"): + if params.get(k) is not None: + q[k] = params[k] + return await _request(cfg, "GET", f"/v1/{pid}/waf/overviews/attack/top/url", query=q) + + if action == "top_source_ip": + q = {} + for k in ("from", "to", "hosts", "top"): + if params.get(k) is not None: + q[k] = params[k] + return await _request(cfg, "GET", f"/v1/{pid}/waf/overviews/attack/top/source", query=q) + + return ToolResult(success=False, error=f"Unknown action: {action}") + + +async def overview(params: dict[str, Any], ctx: ToolContext) -> ToolResult: + cfg = _load_config(params.get("enterprise_project_id")) + pid = cfg.project_id + action = params.get("action", "") + + def _time_query() -> dict[str, Any]: + q: dict[str, Any] = {} + for k in ("from", "to", "hosts"): + if params.get(k) is not None: + q[k] = params[k] + return q + + if action == "overview_statistics": + return await _request(cfg, "GET", f"/v1/{pid}/waf/overviews/statistics", query=_time_query()) + + if action == "overview_qps": + return await _request(cfg, "GET", f"/v1/{pid}/waf/overviews/statistics/qps", query=_time_query()) + + if action == "overview_bandwidth": + return await _request(cfg, "GET", f"/v1/{pid}/waf/overviews/statistics/bandwidth", query=_time_query()) + + if action == "overview_top_domains": + q = _time_query() + if params.get("top"): + q["top"] = params["top"] + return await _request(cfg, "GET", f"/v1/{pid}/waf/overviews/attack/top/host", query=q) + + if action == "overview_attack_types": + return await _request(cfg, "GET", f"/v1/{pid}/waf/overviews/attack/types", query=_time_query()) + + if action == "overview_top_ip": + q = _time_query() + if params.get("top"): + q["top"] = params["top"] + return await _request(cfg, "GET", f"/v1/{pid}/waf/overviews/attack/top/source", query=q) + + if action == "overview_top_url": + q = _time_query() + if params.get("top"): + q["top"] = params["top"] + return await _request(cfg, "GET", f"/v1/{pid}/waf/overviews/attack/top/url", query=q) + + if action == "overview_response_code": + return await _request(cfg, "GET", f"/v1/{pid}/waf/overviews/statistics/response_code", query=_time_query()) + + if action == "console_config": + return await _request(cfg, "GET", f"/v1/{pid}/waf/config/console") + + return ToolResult(success=False, error=f"Unknown action: {action}") diff --git a/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_event.yaml b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_event.yaml new file mode 100644 index 000000000..0c6c99713 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_event.yaml @@ -0,0 +1,115 @@ +name: hw_waf_event +description: > + Huawei Cloud WAF protection event management tool. Use the `action` + parameter to query attack events, export event logs, and analyze threat + distributions. +description_cn: > + 华为云 WAF 防护事件管理工具。通过 `action` 参数查询攻击事件列表、事件详情、 + 攻击分布分析和事件日志导出等接口。 +category: custom +enabled: true +requires_confirmation: true +provider: huaweicloud_waf_api +inputSchema: + type: object + properties: + action: + type: string + description: | + 防护事件管理动作名,可选值: + - event_list + 用途: 查询攻击事件列表(分页) + 必填: 无 + 常用: `from`、`to`、`hosts`、`attacks`、`action`、`page`、`pagesize` + 风险提示: 只读查询接口;建议传 from/to 缩小范围 + 是否任务型: 否 + - event_show + 用途: 查询指定攻击事件详情 + 必填: `eventid` + 常用: `eventid` + 风险提示: 只读查询接口 + 是否任务型: 否 + - event_log_download + 用途: 查询事件日志下载链接(按日期) + 必填: 无 + 常用: `from`、`to` + 风险提示: 只读查询接口;返回 URL 有效期较短 + 是否任务型: 否 + - event_export_job + 用途: 下发自定义导出攻击事件的异步任务 + 必填: `from`、`to` + 常用: `from`、`to`、`hosts`、`attacks`、`action` + 风险提示: 写操作(下发异步任务);任务完成后可通过 `event_log_download` 获取结果 + 是否任务型: 是 + - threat_distribution + 用途: 查询攻击事件分布类型(Top 攻击类型) + 必填: `from`、`to` + 常用: `from`、`to`、`hosts` + 风险提示: 只读查询接口 + 是否任务型: 否 + - top_url + 用途: 查询事件日志中的 Top 被攻击 URL + 必填: `from`、`to` + 常用: `from`、`to`、`hosts`、`top` + 风险提示: 只读查询接口 + 是否任务型: 否 + - top_source_ip + 用途: 查询 Top 攻击源 IP + 必填: `from`、`to` + 常用: `from`、`to`、`hosts`、`top` + 风险提示: 只读查询接口 + 是否任务型: 否 + enum: + - event_list + - event_show + - event_log_download + - event_export_job + - threat_distribution + - top_url + - top_source_ip + eventid: + type: string + description: 攻击事件 ID + from: + type: integer + description: > + 查询开始时间,Unix 毫秒级时间戳。 + 可使用 Python datetime 动态计算,例如: + int(datetime.now().timestamp() * 1000) - 3600000(最近 1 小时) + to: + type: integer + description: 查询结束时间,Unix 毫秒级时间戳,必须大于 from + hosts: + type: array + items: + type: string + description: 防护域名 ID 列表(过滤用) + attacks: + type: array + items: + type: string + description: 攻击类型列表(如 `sqli`、`xss`、`cmdi`、`cc` 等) + action: + type: string + description: 防护动作筛选,`block`(拦截)或 `log`(仅记录) + top: + type: integer + description: 返回 Top N 条数,默认 5 + default: 5 + page: + type: integer + description: 页码,从 1 开始 + default: 1 + pagesize: + type: integer + description: 每页数量,最大 100 + default: 10 + enterprise_project_id: + type: string + description: 企业项目 ID(可选) + required: + - action +handler: + type: script + script_file: hw_waf.handler.py + function: event diff --git a/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_host.yaml b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_host.yaml new file mode 100644 index 000000000..620fa7a36 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_host.yaml @@ -0,0 +1,148 @@ +name: hw_waf_host +description: > + Huawei Cloud WAF protected domain management tool. Use the `action` + parameter to query, create, update, or delete cloud mode and dedicated + mode protected domains. +description_cn: > + 华为云 WAF 防护域名管理工具。通过 `action` 参数调用云模式和独享模式防护域名 + 的查询、创建、更新和删除等接口。 + 请在 WAF 服务配置中填写 AK/SK(或 Token)、Region 和 Project ID。 +category: custom +enabled: true +requires_confirmation: true +provider: huaweicloud_waf_api +inputSchema: + type: object + properties: + action: + type: string + description: | + 防护域名管理动作名,可选值: + - host_list + 用途: 查询云模式防护域名列表 + 必填: 无 + 常用: `page`、`pagesize`、`hostname`、`policyname` + 风险提示: 只读查询接口 + 是否任务型: 否 + - host_show + 用途: 根据域名 ID 查询云模式防护域名详情 + 必填: `instance_id` + 常用: `instance_id` + 风险提示: 只读查询接口 + 是否任务型: 否 + - host_create + 用途: 创建云模式防护域名 + 必填: `hostname`、`proxy`、`server` + 常用: `hostname`、`proxy`、`server`、`certificateid` + 风险提示: 写操作;会新增防护域名接入 + 是否任务型: 否 + - host_update + 用途: 更新云模式防护域名配置 + 必填: `instance_id` + 常用: `instance_id`、`proxy`、`server`、`certificateid`、`protect_status` + 风险提示: 写操作;会修改域名防护配置 + 是否任务型: 否 + - host_delete + 用途: 删除云模式防护域名 + 必填: `instance_id` + 常用: `instance_id` + 风险提示: 高风险写操作;删除后防护立即失效 + 是否任务型: 否 + - host_update_protect_status + 用途: 修改域名防护状态(开启/关闭/观察) + 必填: `instance_id`、`protect_status` + 常用: `instance_id`、`protect_status` + 风险提示: 写操作;会影响域名流量防护状态 + 是否任务型: 否 + - premium_host_list + 用途: 查询独享模式防护域名列表 + 必填: 无 + 常用: `page`、`pagesize`、`hostname`、`policyname` + 风险提示: 只读查询接口 + 是否任务型: 否 + - premium_host_show + 用途: 查看独享模式域名配置详情 + 必填: `instance_id` + 常用: `instance_id` + 风险提示: 只读查询接口 + 是否任务型: 否 + - premium_host_create + 用途: 创建独享模式防护域名 + 必填: `hostname`、`proxy`、`server`、`web_tag` + 常用: `hostname`、`proxy`、`server`、`web_tag`、`certificateid` + 风险提示: 写操作;会新增独享模式域名接入 + 是否任务型: 否 + - premium_host_update + 用途: 修改独享模式域名配置 + 必填: `instance_id` + 常用: `instance_id`、`proxy`、`server`、`protect_status` + 风险提示: 写操作;会修改独享域名防护配置 + 是否任务型: 否 + - premium_host_delete + 用途: 删除独享模式防护域名 + 必填: `instance_id` + 常用: `instance_id` + 风险提示: 高风险写操作;删除后防护立即失效 + 是否任务型: 否 + - composite_host_list + 用途: 查询全部(云模式+独享模式)防护域名列表 + 必填: 无 + 常用: `page`、`pagesize`、`hostname` + 风险提示: 只读查询接口 + 是否任务型: 否 + enum: + - host_list + - host_show + - host_create + - host_update + - host_delete + - host_update_protect_status + - premium_host_list + - premium_host_show + - premium_host_create + - premium_host_update + - premium_host_delete + - composite_host_list + instance_id: + type: string + description: 防护域名 ID + hostname: + type: string + description: 域名(查询时为模糊匹配关键词,创建时为精确域名) + policyname: + type: string + description: 防护策略名称(查询过滤用) + proxy: + type: boolean + description: 是否使用代理(true/false) + server: + type: array + items: + type: object + description: 源站服务器配置列表,每项含 `address`、`port`、`type`、`weight` 等字段 + certificateid: + type: string + description: 证书 ID(HTTPS 域名必填) + protect_status: + type: integer + description: 防护状态,`0` 关闭防护、`1` 开启防护、`-1` 观察模式 + web_tag: + type: string + description: 独享模式域名标签(创建独享域名时必填) + page: + type: integer + description: 页码,从 1 开始 + default: 1 + pagesize: + type: integer + description: 每页数量 + default: 10 + enterprise_project_id: + type: string + description: 企业项目 ID(可选,覆盖配置中的默认值) + required: + - action +handler: + type: script + script_file: hw_waf.handler.py + function: host diff --git a/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_overview.yaml b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_overview.yaml new file mode 100644 index 000000000..ba43e8b20 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_overview.yaml @@ -0,0 +1,110 @@ +name: hw_waf_overview +description: > + Huawei Cloud WAF security overview and statistics tool. Use the `action` + parameter to query QPS data, bandwidth statistics, attack counts, top + attacked domains, and security report subscriptions. +description_cn: > + 华为云 WAF 安全总览与统计工具。通过 `action` 参数查询 QPS 数据、带宽统计、 + 攻击防护次数、Top 被攻击域名,以及安全报告订阅管理等接口。 +category: custom +enabled: true +requires_confirmation: true +provider: huaweicloud_waf_api +inputSchema: + type: object + properties: + action: + type: string + description: | + 安全总览动作名,可选值: + - overview_statistics + 用途: 查询安全总览请求与攻击数量汇总 + 必填: `from`、`to` + 常用: `from`、`to`、`hosts` + 风险提示: 只读查询接口 + 是否任务型: 否 + - overview_qps + 用途: 查询 QPS 时序数据(每 5 分钟一个数据点) + 必填: `from`、`to` + 常用: `from`、`to`、`hosts` + 风险提示: 只读查询接口 + 是否任务型: 否 + - overview_bandwidth + 用途: 查询带宽时序数据 + 必填: `from`、`to` + 常用: `from`、`to`、`hosts` + 风险提示: 只读查询接口 + 是否任务型: 否 + - overview_top_domains + 用途: 查询 Top 受攻击域名排行 + 必填: `from`、`to` + 常用: `from`、`to`、`top` + 风险提示: 只读查询接口 + 是否任务型: 否 + - overview_attack_types + 用途: 查询攻击类型分布 + 必填: `from`、`to` + 常用: `from`、`to`、`hosts` + 风险提示: 只读查询接口 + 是否任务型: 否 + - overview_top_ip + 用途: 查询 Top 攻击源 IP + 必填: `from`、`to` + 常用: `from`、`to`、`hosts`、`top` + 风险提示: 只读查询接口 + 是否任务型: 否 + - overview_top_url + 用途: 查询 Top 被攻击 URL + 必填: `from`、`to` + 常用: `from`、`to`、`hosts`、`top` + 风险提示: 只读查询接口 + 是否任务型: 否 + - overview_response_code + 用途: 查询响应码时序统计数据 + 必填: `from`、`to` + 常用: `from`、`to`、`hosts` + 风险提示: 只读查询接口 + 是否任务型: 否 + - console_config + 用途: 查询当前局点支持的 WAF 特性配置 + 必填: 无 + 常用: 无 + 风险提示: 只读查询接口 + 是否任务型: 否 + enum: + - overview_statistics + - overview_qps + - overview_bandwidth + - overview_top_domains + - overview_attack_types + - overview_top_ip + - overview_top_url + - overview_response_code + - console_config + from: + type: integer + description: > + 查询开始时间,Unix 毫秒级时间戳。 + 可使用 Python datetime 动态计算: + int(datetime.now().timestamp() * 1000) - 86400000(最近 24 小时) + to: + type: integer + description: 查询结束时间,Unix 毫秒级时间戳,必须大于 from + hosts: + type: array + items: + type: string + description: 防护域名 ID 列表(过滤用,不传则汇总所有域名) + top: + type: integer + description: 返回 Top N 条数,默认 5 + default: 5 + enterprise_project_id: + type: string + description: 企业项目 ID(可选) + required: + - action +handler: + type: script + script_file: hw_waf.handler.py + function: overview diff --git a/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_policy.yaml b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_policy.yaml new file mode 100644 index 000000000..fb04016f7 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huaweicloud_waf_v39/hw_waf_policy.yaml @@ -0,0 +1,207 @@ +name: hw_waf_policy +description: > + Huawei Cloud WAF protection policy management tool. Use the `action` + parameter to query, create, update, or delete protection policies, and + manage policy rules (CC rules, custom rules, blacklist/whitelist, etc.). +description_cn: > + 华为云 WAF 防护策略管理工具。通过 `action` 参数调用防护策略的查询、创建、 + 更新和删除接口,以及 CC 规则、精准防护规则、黑白名单、地理位置控制等规则管理接口。 +category: custom +enabled: true +requires_confirmation: true +provider: huaweicloud_waf_api +inputSchema: + type: object + properties: + action: + type: string + description: | + 防护策略管理动作名,可选值: + - policy_list + 用途: 查询防护策略列表 + 必填: 无 + 常用: `page`、`pagesize`、`name` + 风险提示: 只读查询接口 + 是否任务型: 否 + - policy_show + 用途: 根据 ID 查询防护策略详情 + 必填: `policy_id` + 常用: `policy_id` + 风险提示: 只读查询接口 + 是否任务型: 否 + - policy_create + 用途: 创建防护策略 + 必填: `name` + 常用: `name`、`level`、`full_detection` + 风险提示: 写操作;会新增防护策略 + 是否任务型: 否 + - policy_update + 用途: 更新防护策略 + 必填: `policy_id` + 常用: `policy_id`、`name`、`level`、`full_detection`、`options` + 风险提示: 写操作;会修改防护策略参数 + 是否任务型: 否 + - policy_delete + 用途: 删除防护策略 + 必填: `policy_id` + 常用: `policy_id` + 风险提示: 高风险写操作;删除后绑定该策略的域名需重新绑定 + 是否任务型: 否 + - policy_update_hosts + 用途: 更新防护策略绑定的域名列表 + 必填: `policy_id`、`hosts` + 常用: `policy_id`、`hosts` + 风险提示: 写操作;会变更策略与域名的绑定关系 + 是否任务型: 否 + - cc_rule_list + 用途: 查询 CC 防护规则列表 + 必填: `policy_id` + 常用: `policy_id`、`page`、`pagesize` + 风险提示: 只读查询接口 + 是否任务型: 否 + - cc_rule_create + 用途: 创建 CC 防护规则 + 必填: `policy_id`、`url`、`limit_num`、`limit_period`、`lock_time`、`tag_type`、`action` + 常用: `policy_id`、`url`、`limit_num`、`limit_period`、`lock_time`、`tag_type`、`action` + 风险提示: 写操作;会新增 CC 规则 + 是否任务型: 否 + - cc_rule_delete + 用途: 删除 CC 防护规则 + 必填: `policy_id`、`rule_id` + 常用: `policy_id`、`rule_id` + 风险提示: 写操作;会删除 CC 规则 + 是否任务型: 否 + - custom_rule_list + 用途: 查询精准防护规则列表 + 必填: `policy_id` + 常用: `policy_id`、`page`、`pagesize` + 风险提示: 只读查询接口 + 是否任务型: 否 + - custom_rule_create + 用途: 创建精准防护规则 + 必填: `policy_id`、`name`、`conditions`、`action` + 常用: `policy_id`、`name`、`conditions`、`action`、`priority` + 风险提示: 写操作;会新增精准防护规则 + 是否任务型: 否 + - custom_rule_delete + 用途: 删除精准防护规则 + 必填: `policy_id`、`rule_id` + 常用: `policy_id`、`rule_id` + 风险提示: 写操作;会删除精准防护规则 + 是否任务型: 否 + - whiteblackip_rule_list + 用途: 查询黑白名单规则列表 + 必填: `policy_id` + 常用: `policy_id`、`page`、`pagesize` + 风险提示: 只读查询接口 + 是否任务型: 否 + - whiteblackip_rule_create + 用途: 创建黑白名单规则(IP/IP 段 允许/拦截) + 必填: `policy_id`、`addr`、`white` + 常用: `policy_id`、`addr`、`white`、`description` + 风险提示: 写操作;会新增 IP 黑白名单规则,影响访问控制 + 是否任务型: 否 + - whiteblackip_rule_delete + 用途: 删除黑白名单规则 + 必填: `policy_id`、`rule_id` + 常用: `policy_id`、`rule_id` + 风险提示: 写操作;会删除 IP 黑白名单规则 + 是否任务型: 否 + - geoip_rule_list + 用途: 查询地理位置访问控制规则列表 + 必填: `policy_id` + 常用: `policy_id`、`page`、`pagesize` + 风险提示: 只读查询接口 + 是否任务型: 否 + enum: + - policy_list + - policy_show + - policy_create + - policy_update + - policy_delete + - policy_update_hosts + - cc_rule_list + - cc_rule_create + - cc_rule_delete + - custom_rule_list + - custom_rule_create + - custom_rule_delete + - whiteblackip_rule_list + - whiteblackip_rule_create + - whiteblackip_rule_delete + - geoip_rule_list + policy_id: + type: string + description: 防护策略 ID + rule_id: + type: string + description: 规则 ID + name: + type: string + description: 策略名称(查询时为模糊匹配关键词,创建时为精确名称) + level: + type: integer + description: 防护等级,`1` 宽松、`2` 中等(默认)、`3` 严格 + full_detection: + type: boolean + description: 是否开启全检测模式 + options: + type: object + description: 防护开关选项对象(如 webattack、cc、custom 等开关) + hosts: + type: array + items: + type: object + description: 域名列表,每项含 `id`(域名 ID)字段 + url: + type: string + description: CC 规则匹配的 URL 路径 + limit_num: + type: integer + description: CC 规则限制请求数 + limit_period: + type: integer + description: CC 规则统计时间窗(秒) + lock_time: + type: integer + description: CC 规则封禁时长(秒) + tag_type: + type: string + description: CC 规则标签类型,如 `ip`、`cookie`、`header` 等 + action: + type: object + description: 规则匹配后的动作,含 `category`(`block`/`pass`/`log`)等字段 + conditions: + type: array + items: + type: object + description: 精准防护规则条件列表 + priority: + type: integer + description: 精准防护规则优先级(数字越小优先级越高) + addr: + type: string + description: 黑白名单 IP 地址或 CIDR 段 + white: + type: integer + description: 白名单标识,`0` 黑名单(拦截)、`1` 白名单(放行) + description: + type: string + description: 规则备注描述 + page: + type: integer + description: 页码,从 1 开始 + default: 1 + pagesize: + type: integer + description: 每页数量 + default: 10 + enterprise_project_id: + type: string + description: 企业项目 ID(可选) + required: + - action +handler: + type: script + script_file: hw_waf.handler.py + function: policy diff --git a/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/_provider.yaml b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/_provider.yaml new file mode 100644 index 000000000..0a49d2bde --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/_provider.yaml @@ -0,0 +1,50 @@ +name: huorong_edr +vendor: huorong +service_id: huorong_api +version: "1.0" +integration_type: device +description: > + Huorong EDR (Endpoint Detection & Response) API service. Configure the + Access Key ID, Secret Key, and Base URL in the credentials form. + Base URL should point to your Huorong console (e.g., `http://192.168.1.100:801`). +description_cn: > + 火绒终端安全管理系统 API 服务。在配置页分别填写 Access Key ID(secret_id)、 + Secret Key(secret_key)和 Base URL(控制台地址,例如 `http://192.168.1.100:801`)。 +auth: + type: custom + secret: huorong_secret_id + secret_secret: huorong_secret_key +credential_fields: + - key: secret_id + label: Access Key ID + storage: secret + config_key: secretId + secret_id: huorong_secret_id + input_type: password + required: true + - key: secret_key + label: Secret Key + storage: secret + config_key: secretKey + secret_id: huorong_secret_key + input_type: password + required: true + - key: base_url + label: Base URL + storage: config + config_key: base_url + input_type: url + default: "http://localhost:801" +defaults: + base_url: "http://localhost:801" + timeout: 30 + category: custom + product_version: "1.0" +notes: | + 火绒 API 使用 HMAC-SHA1 请求签名认证。 + Authorization 请求头格式:HRESS{secret_id}:{expires}:{url_encoded_sign} + 签名字符串:{secret_id}\n{expires}\n{METHOD}\n{content_md5}\n{canonical_resource} + - expires:Unix 秒级时间戳 + 3600 + - content_md5:base64(md5(request_body)) + - canonical_resource:请求路径去掉开头的 "/" + 也支持 URL 查询参数方式:?ak={secret_id}&expires={expires}&sign={sign} diff --git a/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/_test.yaml b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/_test.yaml new file mode 100644 index 000000000..1747864f9 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/_test.yaml @@ -0,0 +1,49 @@ +schema_version: 1 +provider: huorong_api + +# Service-level connectivity probe. +# `group_list` is a lightweight read-only call with no required params. +connectivity: + tool: huorong_group + params: + action: group_list + +# Tool-level test samples shown in the WebUI ToolDetailDrawer drop-down. +fixtures: + huorong_group: + - label: "List all groups" + label_cn: "获取全部分组列表" + tags: [smoke] + params: + action: group_list + assert: + success: true + + huorong_clnts: + - label: "List online endpoints" + label_cn: "查询上线终端列表" + tags: [smoke] + params: + action: clnts_online + offset: 0 + assert: + success: true + + - label: "List endpoint details (page 1)" + label_cn: "查询终端详情(第 1 页)" + tags: [smoke] + params: + action: clnts_list + offset: 0 + assert: + success: true + + huorong_task: + - label: "Create virus scan task" + label_cn: "创建查杀扫描任务" + tags: [smoke] + params: + action: task_create + offset: 0 + assert: + success: true diff --git a/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong.handler.py b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong.handler.py new file mode 100644 index 000000000..3168dfdbc --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong.handler.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +import base64 +import hashlib +import hmac +import json +import os +import time +import urllib.parse as up +from typing import Any, Optional + +import aiohttp + +from flocks.config.config_writer import ConfigWriter +from flocks.tool.registry import ToolContext, ToolResult + + +DEFAULT_BASE_URL = "http://localhost:801" +DEFAULT_TIMEOUT = 30 +SERVICE_ID = "huorong_api" + + +def _get_secret_manager(): + from flocks.security import get_secret_manager + + return get_secret_manager() + + +def _resolve_ref(value: Any) -> Optional[str]: + if value is None: + return None + if not isinstance(value, str): + return str(value) + if value.startswith("{secret:") and value.endswith("}"): + return _get_secret_manager().get(value[len("{secret:"):-1]) + if value.startswith("{env:") and value.endswith("}"): + return os.getenv(value[len("{env:"):-1]) + return value + + +def _service_config() -> dict[str, Any]: + raw = ConfigWriter.get_api_service_raw(SERVICE_ID) + return raw if isinstance(raw, dict) else {} + + +def _resolve_verify_ssl(raw: dict[str, Any]) -> bool: + value = raw.get("verify_ssl") + if value is None: + value = raw.get("ssl_verify") + if value is None: + custom_settings = raw.get("custom_settings", {}) + if isinstance(custom_settings, dict): + value = custom_settings.get("verify_ssl", False) + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.strip().lower() in {"1", "true", "yes", "on"} + return bool(value) + + +def _resolve_runtime_config() -> tuple[str, int, str, str, bool]: + raw = _service_config() + base_url = ( + _resolve_ref(raw.get("base_url")) + or _resolve_ref(raw.get("baseUrl")) + or DEFAULT_BASE_URL + ).rstrip("/") + timeout = raw.get("timeout", DEFAULT_TIMEOUT) + try: + timeout = int(timeout) + except (TypeError, ValueError): + timeout = DEFAULT_TIMEOUT + + secret_manager = _get_secret_manager() + + secret_id = ( + _resolve_ref(raw.get("secretId")) + or _resolve_ref(raw.get("secret_id")) + or secret_manager.get("huorong_secret_id") + or os.getenv("HUORONG_SECRET_ID") + ) + secret_key = ( + _resolve_ref(raw.get("secretKey")) + or _resolve_ref(raw.get("secret_key")) + or secret_manager.get("huorong_secret_key") + or os.getenv("HUORONG_SECRET_KEY") + ) + if not secret_id or not secret_key: + raise ValueError( + "Huorong API credentials not found. Configure secretId and secretKey " + "in the huorong_api service settings." + ) + return base_url, timeout, secret_id, secret_key, _resolve_verify_ssl(raw) + + +def _build_auth_header( + secret_id: str, + secret_key: str, + method: str, + path: str, + body: str, +) -> str: + """ + Build Huorong HMAC-SHA1 Authorization header. + + Authorization = HRESS{secret_id}:{expires}:{url_encoded_sign} + StringToSign = {secret_id}\\n{expires}\\n{method}\\n{content_md5}\\n{canonical_resource} + content_md5 = base64(md5(body)) + canonical_resource = path without leading "/" + sign = url_encode(base64(hmac-sha1(secret_key, string_to_sign))) + """ + expires = int(time.time()) + 3600 + body_bytes = body.encode("utf-8") if isinstance(body, str) else body + content_md5 = base64.b64encode( + hashlib.md5(body_bytes).digest() + ).decode("utf-8") + canonical_resource = path.lstrip("/") + string_to_sign = f"{secret_id}\n{expires}\n{method}\n{content_md5}\n{canonical_resource}" + sign_bytes = hmac.new( + secret_key.encode("utf-8"), + string_to_sign.encode("utf-8"), + "sha1", + ).digest() + sign = up.quote(base64.b64encode(sign_bytes).decode("utf-8")) + return f"HRESS{secret_id}:{expires}:{sign}" + + +async def _post( + base_url: str, + path: str, + body: dict[str, Any], + secret_id: str, + secret_key: str, + timeout: int, + verify_ssl: bool, +) -> ToolResult: + body_str = json.dumps(body, ensure_ascii=False) + auth_header = _build_auth_header(secret_id, secret_key, "POST", path, body_str) + url = f"{base_url}{path}" + connector = aiohttp.TCPConnector(ssl=verify_ssl) + async with aiohttp.ClientSession(connector=connector) as session: + async with session.post( + url, + data=body_str.encode("utf-8"), + headers={ + "Content-Type": "application/json;charset=UTF-8", + "Authorization": auth_header, + }, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as resp: + resp_text = await resp.text() + try: + resp_json = json.loads(resp_text) + except Exception: + resp_json = {"raw": resp_text} + if resp.status >= 400: + return ToolResult( + success=False, + data=resp_json, + error=f"HTTP {resp.status}: {resp_text[:200]}", + ) + return ToolResult(success=True, data=resp_json) + + +def _pick(params: dict[str, Any], *keys: str) -> dict[str, Any]: + return {k: params[k] for k in keys if k in params and params[k] is not None} + + +# --------------------------------------------------------------------------- +# Public handler functions (one per tool YAML's `function` field) +# --------------------------------------------------------------------------- + + +async def group(params: dict[str, Any], ctx: ToolContext) -> ToolResult: + base_url, timeout, secret_id, secret_key, verify_ssl = _resolve_runtime_config() + action = params.get("action", "") + + if action == "group_list": + return await _post(base_url, "/api/group/_list", {}, secret_id, secret_key, timeout, verify_ssl) + + if action == "group_create": + body = _pick(params, "group_name", "parent_group") + body.setdefault("parent_group", 0) + return await _post(base_url, "/api/group/_create", body, secret_id, secret_key, timeout, verify_ssl) + + if action == "group_rename": + body = _pick(params, "group_id", "group_name") + return await _post(base_url, "/api/group/_rename", body, secret_id, secret_key, timeout, verify_ssl) + + if action == "group_delete": + body = _pick(params, "group_id") + return await _post(base_url, "/api/group/_delete", body, secret_id, secret_key, timeout, verify_ssl) + + return ToolResult(success=False, error=f"Unknown action: {action}") + + +async def clnts(params: dict[str, Any], ctx: ToolContext) -> ToolResult: + base_url, timeout, secret_id, secret_key, verify_ssl = _resolve_runtime_config() + action = params.get("action", "") + + if action == "clnts_online": + body = _pick(params, "offset") + body.setdefault("offset", 0) + return await _post(base_url, "/api/clnts/_online", body, secret_id, secret_key, timeout, verify_ssl) + + if action == "clnts_list": + body = _pick(params, "offset") + body.setdefault("offset", 0) + return await _post(base_url, "/api/clnts/_list", body, secret_id, secret_key, timeout, verify_ssl) + + if action == "clnts_info": + body = _pick(params, "clients") + return await _post(base_url, "/api/clnts/_info", body, secret_id, secret_key, timeout, verify_ssl) + + if action == "clnts_info2": + body = _pick(params, "clients") + return await _post(base_url, "/api/clnts/_info2", body, secret_id, secret_key, timeout, verify_ssl) + + if action == "clnts_rename": + body = _pick(params, "client_id", "client_name") + return await _post(base_url, "/api/clnts/_rename", body, secret_id, secret_key, timeout, verify_ssl) + + if action == "clnts_group": + body = _pick(params, "group_id", "clients") + return await _post(base_url, "/api/clnts/_group", body, secret_id, secret_key, timeout, verify_ssl) + + if action == "clnts_leak": + body = _pick(params, "clients") + return await _post(base_url, "/api/clnts/_leak", body, secret_id, secret_key, timeout, verify_ssl) + + return ToolResult(success=False, error=f"Unknown action: {action}") + + +async def task(params: dict[str, Any], ctx: ToolContext) -> ToolResult: + base_url, timeout, secret_id, secret_key, verify_ssl = _resolve_runtime_config() + action = params.get("action", "") + + if action == "task_create": + body = _pick(params, "offset") + body.setdefault("offset", 0) + return await _post(base_url, "/api/task/_create", body, secret_id, secret_key, timeout, verify_ssl) + + return ToolResult(success=False, error=f"Unknown action: {action}") diff --git a/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong_clnts.yaml b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong_clnts.yaml new file mode 100644 index 000000000..8b8838531 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong_clnts.yaml @@ -0,0 +1,94 @@ +name: huorong_clnts +description: > + Huorong EDR endpoint (client) management tool. Use the `action` parameter + to query online status, list details, rename, move groups, or check for + vulnerabilities. +description_cn: > + 火绒终端(客户端)管理工具。通过 `action` 参数调用终端上线查询、详情列表、 + 重命名、分组变更和高危漏洞查询等接口。 + 请在火绒服务配置中分别填写 Access Key ID、Secret Key 和 Base URL。 +category: custom +enabled: true +requires_confirmation: true +provider: huorong_api +inputSchema: + type: object + properties: + action: + type: string + description: | + 终端管理动作名,可选值: + - clnts_online + 用途: 查询当前在线终端列表(分页) + 必填: 无 + 常用: `offset` + 风险提示: 只读查询接口 + 是否任务型: 否 + - clnts_list + 用途: 查询终端详情列表(分页) + 必填: 无 + 常用: `offset` + 风险提示: 只读查询接口 + 是否任务型: 否 + - clnts_info + 用途: 查询指定终端详细信息(v1) + 必填: `clients` + 常用: `clients` + 风险提示: 只读查询接口;`clients` 为终端 client_id 列表 + 是否任务型: 否 + - clnts_info2 + 用途: 查询指定终端详细信息(v2,字段更丰富) + 必填: `clients` + 常用: `clients` + 风险提示: 只读查询接口;`clients` 为终端 client_id 列表 + 是否任务型: 否 + - clnts_rename + 用途: 修改终端名称 + 必填: `client_id`、`client_name` + 常用: `client_id`、`client_name` + 风险提示: 写操作;会直接修改终端显示名称 + 是否任务型: 否 + - clnts_group + 用途: 修改终端所属分组 + 必填: `group_id`、`clients` + 常用: `group_id`、`clients` + 风险提示: 写操作;会将终端移入指定分组 + 是否任务型: 否 + - clnts_leak + 用途: 查询存在高危漏洞未修复的终端 + 必填: `clients` + 常用: `clients` + 风险提示: 只读查询接口;返回存在高危漏洞的终端信息 + 是否任务型: 否 + enum: + - clnts_online + - clnts_list + - clnts_info + - clnts_info2 + - clnts_rename + - clnts_group + - clnts_leak + offset: + type: integer + description: 分页偏移量,从 0 开始 + default: 0 + clients: + type: array + items: + type: string + description: 终端 client_id 列表(40 字符十六进制字符串) + client_id: + type: string + description: 单个终端 client_id(40 字符十六进制字符串) + client_name: + type: string + description: 终端新名称 + group_id: + type: integer + description: 目标分组 ID + required: + - action +handler: + type: script + script_file: huorong.handler.py + function: clnts diff --git a/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong_group.yaml b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong_group.yaml new file mode 100644 index 000000000..6c0c2f726 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong_group.yaml @@ -0,0 +1,63 @@ +name: huorong_group +description: > + Huorong EDR group management tool. Use the `action` parameter to list, + create, rename, or delete endpoint groups. +description_cn: > + 火绒终端分组管理工具。通过 `action` 参数调用分组查询、创建、重命名和删除接口。 + 请在火绒服务配置中分别填写 Access Key ID、Secret Key 和 Base URL。 +category: custom +enabled: true +requires_confirmation: true +provider: huorong_api +inputSchema: + type: object + properties: + action: + type: string + description: | + 分组管理动作名,可选值: + - group_list + 用途: 获取全部分组列表 + 必填: 无 + 常用: 无 + 风险提示: 只读查询接口 + 是否任务型: 否 + - group_create + 用途: 创建分组 + 必填: `group_name` + 常用: `parent_group` + 风险提示: 写操作;会新增分组 + 是否任务型: 否 + - group_rename + 用途: 修改分组名称 + 必填: `group_id`、`group_name` + 常用: `group_id`、`group_name` + 风险提示: 写操作;会修改已有分组名称 + 是否任务型: 否 + - group_delete + 用途: 删除分组 + 必填: `group_id` + 常用: `group_id` + 风险提示: 高风险写操作;删除后分组及归属终端均受影响 + 是否任务型: 否 + enum: + - group_list + - group_create + - group_rename + - group_delete + group_id: + type: integer + description: 分组 ID + group_name: + type: string + description: 分组名称 + parent_group: + type: integer + description: 父分组 ID,默认为 0(根分组) + default: 0 + required: + - action +handler: + type: script + script_file: huorong.handler.py + function: group diff --git a/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong_task.yaml b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong_task.yaml new file mode 100644 index 000000000..941449302 --- /dev/null +++ b/.flocks/flockshub/plugins/tools/device/huorong_edr_v1_0/huorong_task.yaml @@ -0,0 +1,36 @@ +name: huorong_task +description: > + Huorong EDR task management tool. Use the `action` parameter to create + virus scan tasks and query task results. +description_cn: > + 火绒任务管理工具。通过 `action` 参数调用查杀扫描任务创建和任务结果查询接口。 + 请在火绒服务配置中分别填写 Access Key ID、Secret Key 和 Base URL。 +category: custom +enabled: true +requires_confirmation: true +provider: huorong_api +inputSchema: + type: object + properties: + action: + type: string + description: | + 任务管理动作名,可选值: + - task_create + 用途: 创建查杀扫描任务 + 必填: 无(可传 offset 控制分页) + 常用: `offset` + 风险提示: 写操作;会向指定终端下发扫描任务 + 是否任务型: 是 + enum: + - task_create + offset: + type: integer + description: 分页偏移量,从 0 开始 + default: 0 + required: + - action +handler: + type: script + script_file: huorong.handler.py + function: task diff --git a/.flocks/plugins/agents/device-inspector/agent.yaml b/.flocks/plugins/agents/device-inspector/agent.yaml new file mode 100644 index 000000000..52dff1030 --- /dev/null +++ b/.flocks/plugins/agents/device-inspector/agent.yaml @@ -0,0 +1,42 @@ +name: device-inspector +description: >- + Generic device inspection agent for connected security devices. Discovers enabled + devices through device_context, dynamically finds the right device tools, and + performs read-only health and status inspections with precise metrics when available. +description_cn: >- + 通用设备巡检 Agent:先通过 device_context 识别已接入且已启用的安全设备,再动态发现对应工具, + 执行只读状态/健康巡检并输出尽可能精确的指标。 +mode: subagent +strategy: react +delegatable: true +hidden: false +tags: + - security + - device-inspection +prompt_file: prompt.md +color: "#3498DB" +temperature: 0.2 +tools: + - read + - glob + - grep + - bash + - todoread + - todowrite + - tool_search + - device_context + - run_slash_command + - skill_load +prompt_metadata: + category: security + cost: medium + triggers: + - domain: device-inspection + trigger: "需要巡检已接入设备的运行状态、资源使用率、节点健康、接口状态或版本信息时" + use_when: + - 用户要求查看已接入安全设备的系统状态、健康状态或资源指标 + - 用户要求对单台、多台或整组设备做巡检、体检、状态汇总 + - 需要先确认设备名称与 device_id,再调用对应设备工具 + avoid_when: + - 需要修改设备配置、重启设备或执行其他高风险写操作 + - 需要深入分析安全告警、事件、样本或主机取证,应使用更专门的 agent diff --git a/.flocks/plugins/agents/device-inspector/prompt.md b/.flocks/plugins/agents/device-inspector/prompt.md new file mode 100644 index 000000000..8ef2603c9 --- /dev/null +++ b/.flocks/plugins/agents/device-inspector/prompt.md @@ -0,0 +1,179 @@ +You are a **Generic Device Inspection Agent**. + +## Mission + +你负责对当前**已接入且已启用**的安全设备执行通用巡检,输出设备状态、资源使用率、节点健康、接口状态、版本信息等巡检结果。 + +你的目标不是绑定某个固定客户环境,而是: + +1. 先识别当前有哪些设备可巡检 +2. 再锁定目标设备和 `device_id` +3. 动态发现该设备可调用的工具 +4. 仅执行**只读巡检** +5. 用**具体数值**和明确证据输出结果 + +## Working Language + +- 默认使用与用户相同的语言回复 +- 用户用中文时,全部输出中文 + +## Mandatory Workflow + +### Step 1: 设备发现(强制) + +在任何巡检任务开始前,**先调用 `device_context`**,不要猜设备名、机房、`device_id` 或工具前缀。 + +你必须从 `device_context` 结果中确认: + +- 机房名称 +- 设备名称 +- `device_id` +- `tool_set_id` +- 厂商 / 产品线 +- 设备是否启用 + +执行规则: + +- 如果没有任何已启用设备,直接说明“当前没有可巡检的已接入设备” +- 如果用户指定了设备名、机房名或产品名,先在 `device_context` 结果中做匹配 +- 如果同类设备有多个候选而用户没有明确指定,**不要猜测**,必须列出候选并让用户选择 +- 如果用户要求“巡检全部设备”或“巡检某机房所有设备”,基于 `device_context` 的结果批量处理 + +### Step 2: 工具发现(强制) + +锁定目标设备后,再调用 `tool_search` 动态发现该设备对应的巡检工具。 + +优先组合以下关键词检索: + +- `tool_set_id` +- 设备名 +- 厂商名 +- 产品名 +- 巡检意图关键词:`status`、`system`、`dashboard`、`monitor`、`health`、`resource`、`cpu`、`memory`、`disk`、`interface`、`version` + +工具选择原则: + +- 优先选择只读的状态、概览、监控、系统运行类工具 +- 优先选择能直接返回结构化指标的工具 +- 如果存在 grouped tool,优先选择最贴近用户意图的 action +- 调用前必须以**当前 callable schema** 为准,参数名必须逐字匹配 schema +- 只调用已经进入当前 callable schema 的工具 +- **每次设备工具调用都必须显式传入 `device_id`** + +### Step 3: Skill 闸门(强制) + +涉及下列产品时,必须先读取并遵循对应 skill,再调用相关工具: + +| 产品 / 关键词 | 必须先加载的 skill | +|---|---| +| TDP / 微步 NDR | `tdp-use` | +| OneSEC / OneDNS / 微步 EDR | `onesec-use` | +| OneSIG / SIG / 安全互联网网关 | `onesig-use` | +| SkyEye / 天眼 / 网神分析平台 | `skyeye-use` | +| 青藤 / 青藤云安全 | `qingteng-use` | +| 深信服 EDR | `sangfor-edr-use` | +| 深信服 XDR | `sangfor-xdr-use` | + +硬性要求: +- 如果设备不在上述 skill 列表,可以运行 `run_slash_command(command="skills")` 查看可用 skills +- 在未阅读对应 skill 并完成模式判断前,**不要直接调用**该产品的专用工具 +- 若 skill 要求优先 API,则先走 API +- 若 skill 说明该场景必须走浏览器或人工登录,但当前能力不足,明确告知调用方需要补充的条件,不要编造结果 + +### Step 4: 巡检执行 + +默认以**最小必要集**完成巡检,避免无关查询。 + +优先采集以下信息,能拿到多少就采多少: + +- CPU 使用率 +- 内存使用率 / 容量 +- 磁盘使用率 / 容量 +- 网络吞吐 / 上下行流量 +- 节点状态 / 集群状态 / 服务状态 +- 接口状态 +- 系统运行时长 +- 版本信息 +- 告警中的设备健康摘要(如果用户明确要求) + +执行原则: + +- 用户只指定单台设备时,只巡检该设备 +- 用户要求整批巡检时,按设备逐台执行 +- 默认**串行**巡检,不并行猜测,避免共享 session 或上下文串台 +- 某项指标缺失时,明确写“未提供”或“不适用” +- 某次调用失败时,保留失败原因并继续执行后续可行检查 + +### Step 5: 时间参数规则 + +若目标工具涉及时间参数,必须按对应 skill 或工具 schema 的要求处理: + +- 需要动态时间范围时,可用 `bash` 执行只读的 `uv run python` 仅计算时间参数 +- 禁止手写、猜测或硬编码时间戳 +- 秒级 / 毫秒级 / 日期区间字符串,均以实际 schema 为准 + +## Output Requirements + +输出必须**可审计、可复核、不可含糊**。 + +### 建议输出结构 + +#### 1. 巡检范围 + +- 用户请求的范围 +- 实际命中的设备 +- 未命中的设备或需要用户确认的候选 + +#### 2. 逐设备结果 + +每台设备至少包含: + +- 设备名称 +- 机房 +- `device_id` +- `tool_set_id` +- 巡检结论 + +并优先给出如下表格: + +| 指标 | 数值 | 来源工具 | 备注 | +|---|---|---|---| + +#### 3. 异常项 + +列出: + +- 超阈值资源指标 +- 关键服务异常 +- 节点异常 / 接口异常 +- 因权限、会话或工具限制导致的未完成项 + +#### 4. 总体结论 + +必须说明: + +- 本次巡检覆盖了哪些设备 +- 哪些设备状态正常 +- 哪些设备需要继续关注 +- 哪些结果因工具能力或权限限制无法取得 + +## Constraints + +- **禁止编造**设备、指标、版本、状态或结论 +- **禁止跳过 `device_context` 直接猜 `device_id`** +- **禁止在未过 skill 闸门时直接调用对应产品工具** +- 除非用户明确要求,否则**禁止写操作** +- `bash` 仅可用于读取环境信息或计算时间参数,禁止借此写文件、改配置、重启服务或执行其他副作用命令 +- 除非用户明确要求,否则不要改配置、不要重启、不要下发任务 +- 不要把“无数据”解释成“正常” +- 如果工具只返回定性描述,按原文说明,不要私自转成精确数值 +- 如果需要写出报告文件,输出路径必须放到 `~/.flocks/workspace/outputs//` + +## Success Criteria + +满足以下条件才算完成: + +1. 已通过 `device_context` 确认目标设备 +2. 已通过 `tool_search` 或当前 schema 确认合适工具 +3. 对目标设备完成了只读巡检 +4. 输出包含明确的设备范围、指标结果、异常项和总体结论 diff --git a/README.md b/README.md index a2fff0cd0..60541dde7 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ flocks stop The default service URLs are: - Backend API: `http://127.0.0.1:8000` by default - WebUI: `http://127.0.0.1:5173` by default -- Remote access configurable via `flocks start --server-host --webui-host ` +- Remote access configurable via `flocks start --webui-host ` Flocks CLI usage: `flocks --help` @@ -203,7 +203,9 @@ sudo chown -R : ~/.flocks ### 4.3 Remote Access to Flocks Service ```bash __VITE_ADDITIONAL_SERVER_ALLOWED_HOSTS= \ -flocks start --server-host 127.0.0.1 --webui-host 0.0.0.0 +flocks start --webui-host 0.0.0.0 +# windows powershell +# $env:__VITE_ADDITIONAL_SERVER_ALLOWED_HOSTS="your_domain"; flocks start --webui-host 0.0.0.0 ``` If remote access from a virtual machine fails, please specify the host as the virtual machine's IP. @@ -215,7 +217,7 @@ Only enable direct browser-to-backend URLs when you explicitly need them: ```bash FLOCKS_WEBUI_DIRECT_BACKEND_URLS=1 \ -flocks start --server-host 10.0.0.8 --webui-host 0.0.0.0 +flocks start --server-host 0.0.0.0 --webui-host 0.0.0.0 ``` ### 4.4 Authentication & API Token diff --git a/README_zh.md b/README_zh.md index fd0a43e5f..65566a394 100644 --- a/README_zh.md +++ b/README_zh.md @@ -116,7 +116,7 @@ flocks stop 默认服务地址: - 后端 API:默认 `http://127.0.0.1:8000` - WebUI:默认 `http://127.0.0.1:5173` -- 远程访问修改 `flocks start --server-host --webui-host ` +- 远程访问可通过 `flocks start --webui-host ` 配置 更多 CLI 命令使用 `flocks --help` @@ -159,9 +159,7 @@ docker run -d ` ghcr.io/agentflocks/flocks:latest ``` -默认服务地址: -- 后端 API:默认 `http://127.0.0.1:8000` -- WebUI:默认 `http://127.0.0.1:5173` +镜像中的 `EXPOSE` 仅用于声明容器端口;要从宿主机浏览器访问服务,仍需使用 `-p 8000:8000 -p 5173:5173` 映射端口。 ## 4. 常见问题 @@ -206,9 +204,20 @@ sudo chown -R : ~/.flocks ### 4.3 远程访问 Flocks 服务 ```bash __VITE_ADDITIONAL_SERVER_ALLOWED_HOSTS= \ -flocks start --server-host 127.0.0.1 --webui-host 0.0.0.0 +flocks start --webui-host 0.0.0.0 +# Windows PowerShell +# $env:__VITE_ADDITIONAL_SERVER_ALLOWED_HOSTS="your_domain"; flocks start --webui-host 0.0.0.0 +``` +若从虚拟机远程访问失败,请将 host 指定为虚拟机的 IP。 + +WebUI 在后端绑定到非回环 IP 时,默认仍使用同源 `/api` 代理模式。这样浏览器 Cookie 与 SSE 保持在同一源,是局域网访问与反向代理场景下更安全的选择。 + +仅在确实需要浏览器直连后端 URL 时,再显式启用: + +```bash +FLOCKS_WEBUI_DIRECT_BACKEND_URLS=1 \ +flocks start --server-host 0.0.0.0 --webui-host 0.0.0.0 ``` -虚拟机远程访问失败请指定 host 为虚拟机 IP。 ### 4.4 鉴权与 API Token @@ -250,23 +259,25 @@ flocks start --server-host 127.0.0.1 --webui-host 0.0.0.0 反向代理部署: -- 反代必须主动注入 `X-Forwarded-For`。若缺失,凡是直连本机回环的请求都会被自动放行为 `admin`;中间件依靠该头来区分"真本机"与"经由反代的外部请求"。 -- 若反代终止 HTTPS,请同时透传 `X-Forwarded-Proto: https`,以便服务端正确给 Cookie 加 `Secure` 标志。 +- 反代必须主动注入 `X-Forwarded-For`。若缺失该头,且前方存在代理,中间件会拒绝信任回环地址,避免任何直连本机的请求被自动提升为 `admin`。 +- 若反代终止 HTTPS,请同时透传 `X-Forwarded-Proto: https`,以便服务端正确设置安全 Cookie 标志。 +- 浏览器流量优先使用同源反代:WebUI 保持在 `/`,后端流量经 `/api`(必要时还有 `/event`)转发。除非有意让浏览器绕过代理直连后端源站,否则不要在反向代理部署中设置 `VITE_API_BASE_URL`。 +- 对于 SSE 端点,请关闭代理缓冲并保持 HTTP/1.1 启用。 忘记密码 / 应急恢复: -- 在服务器上执行 `flocks admin generate-one-time-password`,账号会被强制置为 `must_reset_password=true`;下次 WebUI 登录会跳转到改密页。**这种状态下所有非浏览器接口都会返回 403**,请勿在不通知调用方的情况下对依赖自动化的账号执行该命令。 +- 在宿主机上执行 `flocks admin generate-one-time-password`。`admin` 账号会被强制置为 `must_reset_password=true`;下次 WebUI 登录会跳转到改密页。**此状态下所有非浏览器端点均返回 403**,若该账号被自动化依赖,请先协调后再执行。 无主 session(CLI / 后台任务 / inbound 渠道): -- 没有 auth 上下文创建出的 session(CLI 子命令、后台任务、IM 渠道入站 dispatcher)`owner_user_id` 字段为空。bootstrap admin 仍可看到,但**之后新增的 member 账号将完全看不到**。可通过下列命令把这类 session 批量赋给指定 admin: +- 在无鉴权上下文中创建的 session(CLI 命令、后台任务、IM 渠道入站 dispatcher)`owner_user_id` 为空。bootstrap admin 仍可见,但之后新增的 member 账号不可见。可用以下命令回填归属: ```bash flocks admin reassign-orphan-sessions --username admin --dry-run # 预览 flocks admin reassign-orphan-sessions --username admin # 实际写入 ``` - 命令会输出 `scanned / orphaned / reassigned / failed` 四个计数;只要 `failed` 非零就以 exit code 2 退出,方便 CI / 脚本捕获"部分写入"情况、修复底层故障(一般是临时存储错误)后再次运行。 + 命令会汇总 `scanned / orphaned / reassigned / failed` 计数;`failed` 非零时以退出码 2 结束,便于 CI / 脚本发现部分写入并在修复底层原因(通常为临时存储错误)后重试。 ## 5. 加入社区 diff --git a/flocks/agent/agent_factory.py b/flocks/agent/agent_factory.py index a27649731..63981e62b 100644 --- a/flocks/agent/agent_factory.py +++ b/flocks/agent/agent_factory.py @@ -45,6 +45,11 @@ except ImportError: _PLUGIN_AGENTS_DIR = Path.home() / ".flocks" / "plugins" / "agents" + +def _project_plugin_agents_dir() -> Path: + """Return the current project's plugin agent directory.""" + return Path.cwd() / ".flocks" / "plugins" / "agents" + # --------------------------------------------------------------------------- # Prompt metadata parsing # --------------------------------------------------------------------------- @@ -311,16 +316,29 @@ def inject_dynamic_prompts( # YAML CRUD helpers (for plugin agents via API routes) # --------------------------------------------------------------------------- -def _find_yaml_file(name: str) -> Optional[Path]: - """Find the YAML source file for a plugin agent by name.""" - for suffix in (".yaml", ".yml"): - candidate = _PLUGIN_AGENTS_DIR / name / f"agent{suffix}" - if candidate.is_file(): - return candidate - # Legacy: flat file layout (name.yaml) - flat = _PLUGIN_AGENTS_DIR / f"{name}{suffix}" - if flat.is_file(): - return flat +def _find_yaml_file(name: str, *, include_project: bool = True, include_user: bool = True) -> Optional[Path]: + """Find the YAML source file for a plugin agent by name. + + Search order follows API edit precedence, not full runtime scan order: + user-level plugins first, then project-level plugins. + """ + search_roots: List[Path] = [] + if include_user: + search_roots.append(_PLUGIN_AGENTS_DIR) + if include_project: + project_dir = _project_plugin_agents_dir() + if project_dir not in search_roots: + search_roots.append(project_dir) + + for root in search_roots: + for suffix in (".yaml", ".yml"): + candidate = root / name / f"agent{suffix}" + if candidate.is_file(): + return candidate + # Legacy: flat file layout (name.yaml) + flat = root / f"{name}{suffix}" + if flat.is_file(): + return flat return None @@ -468,7 +486,8 @@ def delete_yaml_agent(name: str) -> bool: Returns True on success, False if not found. """ - path = _find_yaml_file(name) + # Deletion remains limited to user-managed plugin agents. + path = _find_yaml_file(name, include_project=False, include_user=True) if path is None: return False diff --git a/flocks/agent/delegatable_settings.py b/flocks/agent/delegatable_settings.py new file mode 100644 index 000000000..8ce0b0339 --- /dev/null +++ b/flocks/agent/delegatable_settings.py @@ -0,0 +1,152 @@ +""" +Delegatable override settings for agent footer toggles. + +Keeps runtime UI overrides in a sidecar JSON file instead of rewriting +agent YAML sources inside the repo. +""" + +from __future__ import annotations + +import json +import os +import sys +import tempfile +import threading +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, Iterator, Optional + +from flocks.utils.log import Log + +log = Log.create(service="agent.delegatable_settings") + +_SETTINGS_LOCK = threading.RLock() +_LOCK_FILENAME = "agent_delegatable_settings.json.lock" + + +def settings_path() -> Path: + return Path.home() / ".flocks" / "config" / "agent_delegatable_settings.json" + + +def _platform_file_lock(fd: int) -> None: + if sys.platform == "win32": # pragma: no cover - exercised on Windows only + import msvcrt + + msvcrt.locking(fd, msvcrt.LK_LOCK, 1) + else: + import fcntl + + fcntl.flock(fd, fcntl.LOCK_EX) + + +def _platform_file_unlock(fd: int) -> None: + if sys.platform == "win32": # pragma: no cover + import msvcrt + + try: + msvcrt.locking(fd, msvcrt.LK_UNLCK, 1) + except OSError: + pass + else: + import fcntl + + try: + fcntl.flock(fd, fcntl.LOCK_UN) + except OSError: + pass + + +@contextmanager +def _settings_cross_process_lock(directory: Path) -> Iterator[None]: + directory.mkdir(parents=True, exist_ok=True) + lock_path = directory / _LOCK_FILENAME + fd: Optional[int] = None + locked = False + try: + fd = os.open(str(lock_path), os.O_RDWR | os.O_CREAT, 0o600) + try: + _platform_file_lock(fd) + locked = True + except OSError as exc: + log.warn( + "agent.delegatable_settings.flock_failed", + {"path": str(lock_path), "error": str(exc)}, + ) + yield + finally: + if fd is not None: + if locked: + _platform_file_unlock(fd) + try: + os.close(fd) + except OSError: + pass + + +@contextmanager +def _locked_rmw() -> Iterator[None]: + with _SETTINGS_LOCK: + with _settings_cross_process_lock(settings_path().parent): + yield + + +def load_overrides() -> Dict[str, bool]: + path = settings_path() + with _SETTINGS_LOCK: + try: + if not path.exists(): + return {} + data = json.loads(path.read_text(encoding="utf-8")) + overrides = data.get("delegatable_overrides", {}) + if isinstance(overrides, dict): + return { + str(name): value + for name, value in overrides.items() + if isinstance(name, str) and isinstance(value, bool) + } + except Exception as exc: + log.warn("agent.delegatable_settings.load_failed", {"error": str(exc)}) + return {} + + +def save_overrides(overrides: Dict[str, bool]) -> None: + path = settings_path() + with _SETTINGS_LOCK: + path.parent.mkdir(parents=True, exist_ok=True) + payload = {"delegatable_overrides": {name: overrides[name] for name in sorted(overrides)}} + fd, tmp_path = tempfile.mkstemp( + dir=str(path.parent), + prefix=".agent_delegatable_settings_", + suffix=".tmp", + ) + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + f.write("\n") + os.replace(tmp_path, str(path)) + except Exception: + try: + os.unlink(tmp_path) + except OSError: + pass + raise + + +def get_override(name: str) -> Optional[bool]: + return load_overrides().get(name) + + +def set_override(name: str, delegatable: bool) -> bool: + with _locked_rmw(): + current = load_overrides() + current[name] = delegatable + save_overrides(current) + return delegatable + + +def forget_override(name: str) -> None: + with _locked_rmw(): + current = load_overrides() + if name in current: + current.pop(name, None) + save_overrides(current) diff --git a/flocks/agent/registry.py b/flocks/agent/registry.py index 6ee81d18f..fade71dc4 100644 --- a/flocks/agent/registry.py +++ b/flocks/agent/registry.py @@ -39,6 +39,7 @@ AvailableWorkflow, DelegationTrigger, ) +import flocks.agent.delegatable_settings as delegatable_settings from flocks.agent.toolset import agent_declares_tool from flocks.agent.prompt_utils import categorize_tools from flocks.agent.agent_factory import ( @@ -168,6 +169,7 @@ def _storage_custom_agent_to_info(agent_data: Dict[str, Any]) -> Optional[AgentI model=model, native=False, hidden=agent_data.get("hidden", False), + delegatable=agent_data.get("delegatable"), tools=agent_data.get("tools", []), tags=agent_data.get("tags", []), ) @@ -199,6 +201,14 @@ async def _load_storage_custom_agents(existing_names: Set[str]) -> Dict[str, Age return loaded +def _apply_delegatable_overrides(agents: Dict[str, AgentInfo]) -> None: + overrides = delegatable_settings.load_overrides() + for name, delegatable in overrides.items(): + agent = agents.get(name) + if agent is not None: + agent.delegatable = delegatable + + # --------------------------------------------------------------------------- # Agent registry # --------------------------------------------------------------------------- @@ -403,6 +413,7 @@ def _permission_dict_to_tools(permission_cfg: Dict[str, Any]) -> List[str]: storage_custom_agents = await _load_storage_custom_agents(set(result.keys())) result.update(storage_custom_agents) + _apply_delegatable_overrides(result) # enabled_agents whitelist filter if cfg.enabled_agents is not None: @@ -439,6 +450,7 @@ def _create_task(): # to detect changes made by other workers. Stored as a class variable so it # persists across async calls in the same process. _skill_settings_mtime: float = 0.0 + _delegatable_settings_mtime: float = 0.0 @classmethod def _sync_skill_settings_cache(cls) -> None: @@ -468,10 +480,25 @@ def _sync_skill_settings_cache(cls) -> None: except Exception: pass + @classmethod + def _sync_delegatable_settings_cache(cls) -> None: + """Invalidate in-process agent cache when delegatable overrides change on disk.""" + try: + sentinel = delegatable_settings.settings_path() + if not sentinel.exists(): + return + current_mtime = sentinel.stat().st_mtime + if current_mtime > cls._delegatable_settings_mtime: + cls._delegatable_settings_mtime = current_mtime + cls._state_accessor.invalidate() # type: ignore[attr-defined] + except Exception: + pass + @staticmethod async def state() -> Dict[str, AgentInfo]: # Detect cross-worker skill-settings changes before serving cached state. Agent._sync_skill_settings_cache() + Agent._sync_delegatable_settings_cache() return await Agent._state_accessor() @classmethod diff --git a/flocks/channel/builtin/feishu/monitor.py b/flocks/channel/builtin/feishu/monitor.py index a1d9a8ed9..c85804925 100644 --- a/flocks/channel/builtin/feishu/monitor.py +++ b/flocks/channel/builtin/feishu/monitor.py @@ -40,6 +40,78 @@ # _chat_locks LRU cap: evict oldest unlocked entries when exceeded _CHAT_LOCKS_MAX = 2000 +_WS_ACCOUNT_RECONNECT_DELAY_S = 1.0 +_WS_ACCOUNT_RECONNECT_MAX_DELAY_S = 30.0 + + +class _ObservedWSClient: + """Expose disconnect observation for SDK clients that only provide start/stop.""" + + def __init__(self, client: Any, *, app_id: str) -> None: + self._client = client + self._app_id = app_id + self._thread: Optional[threading.Thread] = None + self._started = threading.Event() + self._disconnect_event = threading.Event() + self._start_error: Optional[BaseException] = None + self._disconnect_error: Optional[BaseException] = None + self._stop_requested = False + + def start(self) -> None: + self._started.clear() + self._disconnect_event.clear() + self._start_error = None + self._disconnect_error = None + self._stop_requested = False + + def _run() -> None: + self._started.set() + try: + self._client.start() + except BaseException as exc: + if self._start_error is None: + self._start_error = exc + self._notify_disconnected(exc) + else: + self._notify_disconnected(RuntimeError("Feishu websocket client exited")) + + self._thread = threading.Thread( + target=_run, + name=f"feishu-ws-{self._app_id}", + daemon=True, + ) + self._thread.start() + self._started.wait(timeout=0.2) + + def stop(self) -> None: + self._stop_requested = True + self._disconnect_event.set() + stop = getattr(self._client, "stop", None) + if callable(stop): + with contextlib.suppress(Exception): + stop() + if self._thread: + self._thread.join(timeout=5) + self._thread = None + + @property + def start_error(self) -> Optional[BaseException]: + return self._start_error + + @property + def disconnected_error(self) -> Optional[BaseException]: + return self._disconnect_error + + async def wait_disconnected(self) -> Optional[BaseException]: + await asyncio.to_thread(self._disconnect_event.wait) + return self._disconnect_error + + def _notify_disconnected(self, error: BaseException) -> None: + if self._stop_requested: + return + if self._disconnect_error is None: + self._disconnect_error = error + self._disconnect_event.set() def _extract_ws_close_code(exc: BaseException | None) -> int | None: @@ -82,15 +154,16 @@ def _build_ws_client( ): """Build a websocket client compatible with both old and new lark-oapi SDKs.""" try: - import lark_oapi as lark - from lark_oapi.adapter.websocket import WSClient + lark = importlib.import_module("lark_oapi") + ws_adapter = importlib.import_module("lark_oapi.adapter.websocket") + ws_client_cls = ws_adapter.WSClient - return WSClient( + return _ObservedWSClient(ws_client_cls( app_id=app_id, app_secret=app_secret, event_handler=event_handler, log_level=lark.LogLevel.WARNING, - ) + ), app_id=app_id) except ImportError: lark = importlib.import_module("lark_oapi") ws_module = importlib.import_module("lark_oapi.ws.client") @@ -114,13 +187,17 @@ def __init__(self) -> None: self._receive_task: Optional[asyncio.Task] = None self._ping_task: Optional[asyncio.Task] = None self._start_error: Optional[BaseException] = None + self._disconnect_error: Optional[BaseException] = None self._stop_requested = False + self._disconnect_event = threading.Event() self._finished = threading.Event() def start(self) -> None: self._finished.clear() self._start_error = None + self._disconnect_error = None self._stop_requested = False + self._disconnect_event.clear() def _run() -> None: self._loop = asyncio.new_event_loop() @@ -146,6 +223,13 @@ async def _tracked_ping_loop() -> None: self._client._ping_loop = _tracked_ping_loop + def _notify_disconnected(error: BaseException) -> None: + if self._stop_requested: + return + if self._disconnect_error is None: + self._disconnect_error = error + self._disconnect_event.set() + async def _receive_message_loop() -> None: self._receive_task = asyncio.current_task() try: @@ -170,11 +254,22 @@ async def _receive_message_loop() -> None: "app_id": app_id, "error": str(e), }) - await self._client._disconnect() + with contextlib.suppress(Exception): + await self._client._disconnect() if self._client._auto_reconnect: - await self._client._reconnect() - else: - raise + try: + await self._client._reconnect() + return + except Exception as reconnect_error: + log.error("feishu.ws.reconnect_error", { + "app_id": app_id, + "error": str(reconnect_error), + }) + e = reconnect_error + _notify_disconnected(e) + running_loop = asyncio.get_running_loop() + running_loop.call_soon(running_loop.stop) + return finally: self._receive_task = None @@ -186,8 +281,10 @@ async def _receive_message_loop() -> None: except RuntimeError as e: if "Event loop stopped before Future completed" not in str(e): self._start_error = e + _notify_disconnected(e) except BaseException as e: # pragma: no cover - defensive self._start_error = e + _notify_disconnected(e) finally: self._finished.set() @@ -208,6 +305,7 @@ async def _receive_message_loop() -> None: def stop(self) -> None: self._stop_requested = True + self._disconnect_event.set() if self._client is None and self._thread and self._thread.is_alive(): deadline = time.monotonic() + 0.5 while self._client is None and not self._finished.is_set(): @@ -267,6 +365,14 @@ async def _drain_task(task: Optional[asyncio.Task], timeout: float) -> None: def start_error(self) -> Optional[BaseException]: return self._start_error + @property + def disconnected_error(self) -> Optional[BaseException]: + return self._disconnect_error + + async def wait_disconnected(self) -> Optional[BaseException]: + await asyncio.to_thread(self._disconnect_event.wait) + return self._disconnect_error + return _CompatWSClient() @@ -299,27 +405,64 @@ async def start_websocket( await _start_single_websocket(accounts[0], on_message, abort_event) return + reconnect_delay_s = float(config.get("websocketReconnectDelaySeconds", _WS_ACCOUNT_RECONNECT_DELAY_S)) + reconnect_max_delay_s = float( + config.get("websocketReconnectMaxDelaySeconds", _WS_ACCOUNT_RECONNECT_MAX_DELAY_S) + ) + + async def _wait_before_restart(delay_s: float) -> None: + if delay_s <= 0: + await asyncio.sleep(0) + return + if abort_event is None: + await asyncio.sleep(delay_s) + return + try: + await asyncio.wait_for(abort_event.wait(), timeout=delay_s) + except asyncio.TimeoutError: + return + + async def _run_account(account: dict) -> None: + account_id = account["_account_id"] + attempt = 0 + while not (abort_event and abort_event.is_set()): + try: + await _start_single_websocket(account, on_message, abort_event) + if abort_event and abort_event.is_set(): + return + raise RuntimeError("Feishu websocket account stopped") + except asyncio.CancelledError: + raise + except Exception as exc: + attempt += 1 + delay_s = min( + reconnect_delay_s * (2 ** min(attempt - 1, 5)), + reconnect_max_delay_s, + ) + log.error("feishu.ws.account_failed", { + "account_id": account_id, + "attempt": attempt, + "error": str(exc), + }) + log.info("feishu.ws.account_restart_scheduled", { + "account_id": account_id, + "delay_s": delay_s, + }) + await _wait_before_restart(delay_s) + tasks = [ asyncio.create_task( - _start_single_websocket(acc, on_message, abort_event), + _run_account(acc), name=f"feishu-ws-{acc['_account_id']}", ) for acc in accounts ] try: - # return_exceptions=True: one account failing does not affect others - results = await asyncio.gather(*tasks, return_exceptions=True) - for acc, result in zip(accounts, results): - if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError): - log.error("feishu.ws.account_failed", { - "account_id": acc["_account_id"], - "error": str(result), - }) + await asyncio.gather(*tasks) except asyncio.CancelledError: for t in tasks: if not t.done(): t.cancel() - # Wait for all tasks to finish cancellation to avoid "Task destroyed but pending" warnings await asyncio.gather(*tasks, return_exceptions=True) raise @@ -517,16 +660,56 @@ def _event_handler(data: dict) -> None: ws_client.start() # Launch background dedup flush task (only when dedup is enabled) - flush_task = await dedup.start_background_flush() if dedup_enabled else asyncio.create_task(asyncio.sleep(0)) + flush_task = ( + await dedup.start_background_flush() + if dedup_enabled + else asyncio.create_task(asyncio.sleep(0)) + ) + disconnect_waiter: asyncio.Task | None = None + abort_waiter: asyncio.Task | None = None try: + wait_disconnected = getattr(ws_client, "wait_disconnected", None) + if callable(wait_disconnected): + disconnect_waiter = asyncio.create_task( + wait_disconnected(), + name=f"feishu-ws-disconnect-{account_id}", + ) if abort_event: - await abort_event.wait() + abort_waiter = asyncio.create_task( + abort_event.wait(), + name=f"feishu-ws-abort-{account_id}", + ) else: - while True: - await asyncio.sleep(3600) + abort_waiter = asyncio.create_task( + asyncio.Event().wait(), + name=f"feishu-ws-abort-{account_id}", + ) + + waiters = {abort_waiter} + if disconnect_waiter is not None: + waiters.add(disconnect_waiter) + + done, pending = await asyncio.wait( + waiters, + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + if disconnect_waiter is not None and disconnect_waiter in done: + disconnect_error = disconnect_waiter.result() + if disconnect_error is None: + disconnect_error = getattr(ws_client, "start_error", None) + if disconnect_error is None: + disconnect_error = RuntimeError("Feishu websocket disconnected") + raise RuntimeError( + f"Feishu websocket disconnected for account '{account_id}'" + ) from disconnect_error finally: flush_task.cancel() + await asyncio.gather(flush_task, return_exceptions=True) if dedup_enabled: await dedup.flush() # final flush before exit ws_client.stop() diff --git a/flocks/channel/inbound/dispatcher.py b/flocks/channel/inbound/dispatcher.py index d5b8ae145..95ab2d7ff 100644 --- a/flocks/channel/inbound/dispatcher.py +++ b/flocks/channel/inbound/dispatcher.py @@ -673,6 +673,24 @@ async def _run_llm(_event, prompt_text: str, display_text: Optional[str] = None) f"命令 `{display_text or prompt_text}` 暂不支持在当前渠道中以 slash 形式执行。" ) + async def _clear_history() -> None: + try: + from flocks.server.routes.session import _clear_session_history + + deleted_count = await _clear_session_history(binding.session_id) + except Exception as exc: + log.error("dispatcher.clear_command_failed", { + "session_id": binding.session_id, + "channel_id": msg.channel_id, + "error": str(exc), + }) + await callbacks.deliver_text(f"清空当前会话失败:{type(exc).__name__}") + return + + await callbacks.deliver_text( + f"已清空当前会话历史,共删除 {deleted_count} 条消息。" + ) + async def _run_session_control(_event, parsed) -> bool: if parsed.canonical_name == "status": await self._handle_status_command(binding, msg, callbacks) @@ -718,6 +736,7 @@ async def _run_session_control(_event, parsed) -> bool: direct_response=_publish_direct_response, run_llm=_run_llm, session_control=_run_session_control, + clear_history=_clear_history, ), ) return True @@ -848,6 +867,26 @@ async def _handle_session_command( agent_id=new_session.agent, scope_override=scope_override, ) + # Archive the previous session so it no longer shows up as an *active* + # IM session. The binding has already moved to ``new_session`` via + # ``rebind`` above, but the old session retains ``status="active"`` and + # the same ``[Feishu]/[Wecom]/[Dingtalk]`` title prefix. Without this, + # repeated ``/new`` leaves multiple active sessions for the same + # conversation, and unattended scheduled tasks (which resolve the IM + # target via ``session_list(status="active")`` + title prefix) can no + # longer tell which one is current — sending to the wrong session or + # failing outright. Best-effort: archiving failure must not abort /new. + try: + await Session.update( + session.project_id, + session.id, + status="archived", + ) + except Exception as exc: + log.warning("dispatcher.archive_previous_session_failed", { + "session_id": session.id, + "error": str(exc), + }) await self._trigger_command_hook( "new", session.id, diff --git a/flocks/cli/service_manager.py b/flocks/cli/service_manager.py index b560eb3bb..d0de2066a 100644 --- a/flocks/cli/service_manager.py +++ b/flocks/cli/service_manager.py @@ -26,6 +26,7 @@ import httpx from flocks.browser.admin import stop_all_daemons as stop_all_browser_daemons +from flocks.utils.log import rotate_log_file try: import fcntl @@ -1681,6 +1682,7 @@ def _spawn_process( kwargs["start_new_session"] = True log_path.parent.mkdir(parents=True, exist_ok=True) + rotate_log_file(log_path) handle = log_path.open("a", encoding="utf-8") try: return subprocess.Popen( diff --git a/flocks/cli/session_runner.py b/flocks/cli/session_runner.py index 69e3698c6..d85967189 100644 --- a/flocks/cli/session_runner.py +++ b/flocks/cli/session_runner.py @@ -339,6 +339,7 @@ async def _process_message( from flocks.input.dispatcher import dispatch_user_input from flocks.input.events import UserInputEvent from flocks.input.output import CliOutputSink + from flocks.session.message import Message event = UserInputEvent( source_type="cli", @@ -349,6 +350,12 @@ async def _process_message( model={"providerID": provider_id, "modelID": model_id}, display_text=stripped, ) + + async def _clear_history() -> None: + await Message.clear(self._session.id) + await self._clear_screen() + self.console.print("[dim]Conversation history cleared.[/dim]") + handled = await dispatch_user_input( event, CliOutputSink( @@ -365,6 +372,7 @@ async def _process_message( dispatch_commands=False, ), clear_screen=self._clear_screen, + clear_history=_clear_history, ), ) if handled.handled: @@ -774,7 +782,7 @@ def _print_help(self) -> None: /skills List skills (same as /skills list) /skills list List skills /skills refresh Refresh skills - /clear Clear screen + /clear Clear session history /exit Exit session /quit Exit session /q Exit session diff --git a/flocks/command/command.py b/flocks/command/command.py index f027fe408..000e7db0f 100644 --- a/flocks/command/command.py +++ b/flocks/command/command.py @@ -245,10 +245,12 @@ def _ensure_defaults(cls) -> None: ), CommandDef( name="clear", - description="Clear screen output", - template="Clear the current UI output only.", + description="Clear the current session history", + template="Clear all messages in the current session.", execution_kind="direct", allow_attachments=False, + visible_surfaces=ALL_SURFACES, + channel_safe=True, ), CommandDef( name="bug", diff --git a/flocks/command/direct.py b/flocks/command/direct.py index eac7666ac..47f9851a1 100644 --- a/flocks/command/direct.py +++ b/flocks/command/direct.py @@ -27,6 +27,7 @@ class DirectCommandResult: text: Optional[str] = None prompt: Optional[str] = None clear_screen: bool = False + clear_history: bool = False def is_agent_safe_direct_command(command: CommandInfo) -> bool: @@ -145,7 +146,7 @@ async def run_direct_command( return DirectCommandResult(handled=True, text=format_help(surface=surface)) if name == "clear": - return DirectCommandResult(handled=True, clear_screen=True, text="Screen cleared.") + return DirectCommandResult(handled=True, clear_history=True) if name == "tools": if not args or args == "list": diff --git a/flocks/command/handler.py b/flocks/command/handler.py index 4ba86bda4..808693bc2 100644 --- a/flocks/command/handler.py +++ b/flocks/command/handler.py @@ -11,6 +11,7 @@ SendText = Callable[[str], Awaitable[None]] SendPrompt = Callable[[str], Awaitable[None]] ClearScreen = Callable[[], Awaitable[None]] +ClearHistory = Callable[[], Awaitable[None]] async def handle_slash_command( @@ -19,6 +20,7 @@ async def handle_slash_command( send_text: SendText, send_prompt: SendPrompt, clear_screen: Optional[ClearScreen] = None, + clear_history: Optional[ClearHistory] = None, surface: Optional[CommandSurface] = None, ) -> bool: """ @@ -55,5 +57,12 @@ async def handle_slash_command( await send_text(result.text or "Screen cleared.") return True + if result.clear_history: + if clear_history: + await clear_history() + else: + await send_text(result.text or "Conversation history could not be cleared on this surface.") + return True + await send_text(result.text or "") return True diff --git a/flocks/ingest/kafka/__init__.py b/flocks/ingest/kafka/__init__.py new file mode 100644 index 000000000..4accc0436 --- /dev/null +++ b/flocks/ingest/kafka/__init__.py @@ -0,0 +1 @@ +"""Kafka ingest: consume messages from a topic and trigger workflow runs.""" diff --git a/flocks/ingest/kafka/constants.py b/flocks/ingest/kafka/constants.py new file mode 100644 index 000000000..3caab12d6 --- /dev/null +++ b/flocks/ingest/kafka/constants.py @@ -0,0 +1,3 @@ +"""Storage key prefix for per-workflow Kafka config (must match server routes).""" + +WORKFLOW_KAFKA_CONFIG_PREFIX = "workflow_kafka_config/" diff --git a/flocks/ingest/kafka/manager.py b/flocks/ingest/kafka/manager.py new file mode 100644 index 000000000..4cec529ad --- /dev/null +++ b/flocks/ingest/kafka/manager.py @@ -0,0 +1,633 @@ +"""Lifecycle manager for Kafka consumers → workflow runs. + +This mirrors :mod:`flocks.ingest.syslog.manager`: one async consumer task per +workflow id (when enabled), draining a bounded queue with a fixed worker pool so +an inbound burst cannot translate into unbounded ``asyncio.Task`` growth. + +Differences from the syslog manager: + +* The transport is a Kafka *consumer* (``aiokafka.AIOKafkaConsumer``) instead of + a UDP/TCP socket bind. "binding/listening" is replaced by + "connecting/running"; a connection failure (broker unreachable, auth error) + is surfaced the same way a bind failure is. +* Backpressure uses a *blocking* ``queue.put`` instead of ``put_nowait``+drop: + this avoids local drops while the worker pool falls behind and lets the + consumer pause naturally. Because ``aiokafka`` auto-commits fetched offsets, + the current crash semantics are still best-effort / at-most-once rather than + fully durable at-least-once delivery. +""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import time +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from flocks.storage.storage import Storage +from flocks.utils.log import Log +from flocks.workflow.execution_store import ( + DEFAULT_LARGE_LIST_KEYS, + compact_history_for_storage, + compact_outputs_for_storage, + create_execution_record, + record_execution_result, + resolve_execution_outcome, +) +from flocks.workflow.fs_store import read_workflow_from_fs +from flocks.workflow.runner import run_workflow + +from flocks.ingest.kafka.constants import WORKFLOW_KAFKA_CONFIG_PREFIX + +log = Log.create(service="kafka.manager") + + +# Maximum concurrent workflow executions per workflow to avoid FD exhaustion and +# SQLite write contention. Kafka messages can carry large JSON payloads, so keep +# this lower than syslog to avoid several full workflow histories being resident +# at the same time. +_MAX_CONCURRENT_EXECUTIONS = 2 +# Maximum number of buffered Kafka messages per workflow. Unlike syslog we do +# not drop on overflow; a full queue applies backpressure to the consumer loop. +_MAX_QUEUE_SIZE = 100 +# Maximum time we wait for the consumer to either connect successfully or fail +# during ``restart_workflow`` so the HTTP save endpoint can surface connection +# errors instead of pretending the consumer is running. +_CONNECT_WAIT_TIMEOUT_S = 8.0 +# Kafka client request timeout; kept short so an unreachable broker fails fast +# within the connect-wait window above. +_REQUEST_TIMEOUT_MS = 5000 +# Bound aiokafka's internal fetch buffers. The explicit queue below provides the +# main backpressure; these caps stop the client from prefetching a large burst +# before the workflow workers can drain it. +_FETCH_MAX_BYTES = 8 * 1024 * 1024 +_MAX_PARTITION_FETCH_BYTES = 4 * 1024 * 1024 +_MAX_POLL_RECORDS = 16 + +_KAFKA_STORAGE_LIST_KEYS = DEFAULT_LARGE_LIST_KEYS | frozenset( + { + "duplicate_alerts", + "triage_candidate_alerts", + "enriched_alerts_with_triage", + "kafka_messages", + } +) +_KAFKA_RAW_INPUT_KEYS = frozenset({"kafka_message", "kafka_value", "kafka_record"}) +_STORAGE_PREVIEW_CHARS = 512 + + +@dataclass(frozen=True) +class _QueuedKafkaMessage: + """Raw Kafka value kept in the queue until a worker is ready to process it.""" + + raw_value: Optional[bytes] + size_bytes: int + + +def _strip_execution_only_comments(value: Any) -> Any: + if isinstance(value, list): + return [_strip_execution_only_comments(item) for item in value] + if not isinstance(value, dict): + return value + return { + key: _strip_execution_only_comments(nested) + for key, nested in value.items() + if not str(key).startswith("_comment") + } + + +def _decode_message(raw: Optional[bytes]) -> Any: + """Decode a Kafka message value to a Python object. + + Tries UTF-8 + JSON first (the common case for structured events); falls back + to the raw decoded string, then to a base64-free repr for binary payloads. + """ + if raw is None: + return None + try: + text = raw.decode("utf-8") + except Exception: + return raw.hex() + try: + return json.loads(text) + except Exception: + return text + + +def _summarize_large_value(value: Any) -> Any: + """Return a bounded representation suitable for execution history storage.""" + if isinstance(value, bytes): + return { + "_type": "bytes", + "sizeBytes": len(value), + "sha256": hashlib.sha256(value).hexdigest(), + } + if isinstance(value, str): + if len(value) <= _STORAGE_PREVIEW_CHARS: + return value + return { + "_type": "string", + "chars": len(value), + "sha256": hashlib.sha256(value.encode("utf-8", errors="ignore")).hexdigest(), + "preview": value[:_STORAGE_PREVIEW_CHARS], + } + if isinstance(value, (list, tuple)): + return { + "_type": "list", + "count": len(value), + "preview": [_summarize_large_value(item) for item in list(value)[:3]], + } + if isinstance(value, dict): + compacted: Dict[str, Any] = { + "_type": "dict", + "keys": list(value.keys())[:50], + } + for key in ( + "id", + "_id", + "log_id", + "raw_log_id", + "event_id", + "message_id", + "source", + "product_type", + "hostname", + ): + if key in value: + compacted[key] = _summarize_large_value(value[key]) + if "alarmData" in value: + compacted["alarmData"] = _summarize_large_value(value["alarmData"]) + return compacted + return value + + +def _compact_for_kafka_storage(outputs: Any) -> Dict[str, Any]: + """Compact all known large workflow lists for high-frequency Kafka runs.""" + compacted = compact_outputs_for_storage( + outputs, + keys=_KAFKA_STORAGE_LIST_KEYS, + size_threshold=0, + ) + for key, value in list(compacted.items()): + if key == "kafka_output" or ( + isinstance(value, str) and len(value) > _STORAGE_PREVIEW_CHARS + ): + compacted[key] = _summarize_large_value(value) + return compacted + + +def _compact_history_for_kafka_storage(history: Any, *, input_key: str) -> List[Any]: + compacted = compact_history_for_storage( + history, + keys=_KAFKA_STORAGE_LIST_KEYS, + size_threshold=0, + ) + raw_input_keys = _KAFKA_RAW_INPUT_KEYS | frozenset({input_key}) + for step in compacted: + if not isinstance(step, dict): + continue + for field in ("inputs", "outputs"): + payload = step.get(field) + if not isinstance(payload, dict): + continue + for key, value in list(payload.items()): + if key in raw_input_keys or key == "kafka_output" or ( + isinstance(value, str) and len(value) > _STORAGE_PREVIEW_CHARS + ): + payload[key] = _summarize_large_value(value) + return compacted + + +class KafkaManager: + """One async consumer task per workflow id (when enabled).""" + + def __init__(self) -> None: + self._tasks: dict[str, asyncio.Task] = {} + self._abort_events: dict[str, asyncio.Event] = {} + # Per-workflow bounded message queue for backpressure + self._queues: dict[str, asyncio.Queue] = {} + # Per-workflow fixed worker pool draining the queue + self._worker_pools: dict[str, List[asyncio.Task]] = {} + # Per-workflow consumer runtime status for the kafka-status API. + # State values: "connecting" | "running" | "failed" | "stopped". + self._status: dict[str, Dict[str, Any]] = {} + # Per-workflow event signalled once the consumer has either connected + # successfully or failed; used by ``restart_workflow``. + self._ready: dict[str, asyncio.Event] = {} + + @staticmethod + def _config_key(workflow_id: str) -> str: + return f"{WORKFLOW_KAFKA_CONFIG_PREFIX}{workflow_id}" + + async def start_all(self) -> None: + try: + keys = await Storage.list_keys(WORKFLOW_KAFKA_CONFIG_PREFIX) + except Exception as exc: + log.warning("kafka.list_keys_failed", {"error": str(exc)}) + return + + for key in keys: + if not key.startswith(WORKFLOW_KAFKA_CONFIG_PREFIX): + continue + workflow_id = key[len(WORKFLOW_KAFKA_CONFIG_PREFIX):] + if not workflow_id: + continue + try: + data = await Storage.read(key) + except Exception as exc: + log.warning("kafka.config_read_failed", {"key": key, "error": str(exc)}) + continue + if isinstance(data, dict) and data.get("enabled"): + await self.restart_workflow(workflow_id) + + async def stop_all(self) -> None: + for workflow_id in list(self._tasks.keys()): + await self.stop_workflow(workflow_id) + + async def _cleanup_runtime_resources(self, workflow_id: str) -> None: + # Cancel all worker pool tasks; pop first so callers observing a stopped + # consumer see an empty pool immediately. + pool = self._worker_pools.pop(workflow_id, None) + if pool: + for worker in pool: + if not worker.done(): + worker.cancel() + try: + await asyncio.wait_for( + asyncio.gather(*pool, return_exceptions=True), + timeout=5.0, + ) + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + + self._queues.pop(workflow_id, None) + self._abort_events.pop(workflow_id, None) + self._ready.pop(workflow_id, None) + + def get_consumer_status(self, workflow_id: str) -> Dict[str, Any]: + """Return a snapshot of the consumer runtime state for ``workflow_id``. + + Result shape:: + + {"state": "connecting|running|failed|stopped", "error": "..." | None, + "broker": "...", "topic": "...", "groupId": "...", + "queueSize": 12, "queueCapacity": , + "workerCount": <_MAX_CONCURRENT_EXECUTIONS>} + """ + status = dict(self._status.get(workflow_id) or {"state": "stopped"}) + q = self._queues.get(workflow_id) + if q is not None: + status["queueSize"] = q.qsize() + status["queueCapacity"] = q.maxsize + pool = self._worker_pools.get(workflow_id) + if pool is not None: + status["workerCount"] = sum(1 for t in pool if not t.done()) + return status + + async def stop_workflow(self, workflow_id: str) -> None: + ev = self._abort_events.get(workflow_id) + if ev is not None: + ev.set() + task = self._tasks.pop(workflow_id, None) + if task is not None and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception: + pass + await self._cleanup_runtime_resources(workflow_id) + if workflow_id in self._status: + self._status[workflow_id] = {"state": "stopped", "error": None} + + async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: + """Restart the consumer and return its post-connect runtime status. + + Blocks until the consumer connects, the connection fails, or + ``_CONNECT_WAIT_TIMEOUT_S`` elapses, so the HTTP save endpoint can + surface connection errors to the user. + """ + await self.stop_workflow(workflow_id) + key = self._config_key(workflow_id) + try: + data = await Storage.read(key) + except Exception as exc: + log.warning("kafka.restart_read_failed", {"workflow_id": workflow_id, "error": str(exc)}) + return {"state": "failed", "error": str(exc)} + if not isinstance(data, dict) or not data.get("enabled"): + self._status[workflow_id] = {"state": "stopped", "error": None} + return {"state": "stopped", "error": None} + + input_broker = str(data.get("inputBroker") or "").strip() + input_topic = str(data.get("inputTopic") or "").strip() + if not input_broker or not input_topic: + err = "missing_input_broker_or_topic" + self._status[workflow_id] = {"state": "failed", "error": err} + log.warning("kafka.config_incomplete", {"workflow_id": workflow_id}) + return {"state": "failed", "error": err} + + # Load and cache the workflow JSON once; avoids a disk read per message. + wf_data = read_workflow_from_fs(workflow_id) + if not wf_data: + err = "workflow_not_found" + self._status[workflow_id] = {"state": "failed", "error": err} + log.warning("kafka.workflow_not_found_on_start", {"workflow_id": workflow_id}) + return {"state": "failed", "error": err} + workflow_json = wf_data.get("workflowJson") + if not workflow_json: + err = "workflow_json_missing" + self._status[workflow_id] = {"state": "failed", "error": err} + log.warning("kafka.workflow_json_missing_on_start", {"workflow_id": workflow_id}) + return {"state": "failed", "error": err} + + group_id = str(data.get("inputGroupId") or "").strip() or f"flocks-consumer-{workflow_id}" + input_key = str(data.get("inputKey") or "kafka_message") + configured_inputs = _strip_execution_only_comments( + data.get("inputs") if isinstance(data.get("inputs"), dict) else {} + ) + + queue: asyncio.Queue = asyncio.Queue(maxsize=_MAX_QUEUE_SIZE) + self._queues[workflow_id] = queue + + abort = asyncio.Event() + self._abort_events[workflow_id] = abort + + ready = asyncio.Event() + self._ready[workflow_id] = ready + + self._status[workflow_id] = { + "state": "connecting", + "error": None, + "broker": input_broker, + "topic": input_topic, + "groupId": group_id, + } + + # Fixed worker pool drains the queue (at most _MAX_CONCURRENT_EXECUTIONS + # concurrent runs). + workers: List[asyncio.Task] = [] + for i in range(_MAX_CONCURRENT_EXECUTIONS): + workers.append( + asyncio.create_task( + self._worker_loop( + workflow_id, workflow_json, input_key, configured_inputs, queue, abort, + ), + name=f"kafka-worker-{workflow_id}-{i}", + ) + ) + self._worker_pools[workflow_id] = workers + + task = asyncio.create_task( + self._consumer_loop( + workflow_id, input_broker, input_topic, group_id, + str(data.get("autoOffsetReset") or "latest"), + queue, abort, ready, + ), + name=f"kafka-{workflow_id}", + ) + self._tasks[workflow_id] = task + + try: + await asyncio.wait_for(ready.wait(), timeout=_CONNECT_WAIT_TIMEOUT_S) + except asyncio.TimeoutError: + current = self._status.get(workflow_id) or {} + if current.get("state") == "connecting": + self._status[workflow_id] = { + **current, + "state": "connecting", + "error": "connect_pending_timeout", + } + log.warning("kafka.connect_pending_timeout", {"workflow_id": workflow_id}) + + current = self._status.get(workflow_id) or {} + if current.get("state") == "failed": + task = self._tasks.get(workflow_id) + if task is not None and not task.done(): + try: + await task + except asyncio.CancelledError: + pass + except Exception: + pass + task = self._tasks.get(workflow_id) + if task is not None and task.done(): + self._tasks.pop(workflow_id, None) + + log.info("kafka.consumer_scheduled", {"workflow_id": workflow_id}) + return self.get_consumer_status(workflow_id) + + async def _consumer_loop( + self, + workflow_id: str, + broker: str, + topic: str, + group_id: str, + auto_offset_reset: str, + queue: asyncio.Queue, + abort: asyncio.Event, + ready: asyncio.Event, + ) -> None: + try: + from aiokafka import AIOKafkaConsumer + except Exception as exc: + self._status[workflow_id] = { + "state": "failed", + "error": f"aiokafka_import_failed: {exc}", + "broker": broker, + "topic": topic, + "groupId": group_id, + } + ready.set() + log.error("kafka.import_failed", {"workflow_id": workflow_id, "error": str(exc)}) + await self._cleanup_runtime_resources(workflow_id) + return + + consumer = AIOKafkaConsumer( + topic, + bootstrap_servers=broker, + group_id=group_id, + # Auto-commit advances based on fetched progress, not worker + # completion. Backpressure narrows the crash window but current + # semantics remain best-effort / at-most-once. + enable_auto_commit=True, + auto_offset_reset=auto_offset_reset if auto_offset_reset in ("latest", "earliest") else "latest", + request_timeout_ms=_REQUEST_TIMEOUT_MS, + fetch_max_bytes=_FETCH_MAX_BYTES, + max_partition_fetch_bytes=_MAX_PARTITION_FETCH_BYTES, + max_poll_records=_MAX_POLL_RECORDS, + ) + + try: + await consumer.start() + except asyncio.CancelledError: + try: + await consumer.stop() + except Exception: + pass + raise + except Exception as exc: + self._status[workflow_id] = { + "state": "failed", + "error": str(exc), + "broker": broker, + "topic": topic, + "groupId": group_id, + } + ready.set() + log.error( + "kafka.connect_failed", + {"workflow_id": workflow_id, "error": str(exc), "broker": broker, "topic": topic}, + ) + try: + await consumer.stop() + except Exception: + pass + await self._cleanup_runtime_resources(workflow_id) + return + + self._status[workflow_id] = { + "state": "running", + "error": None, + "broker": broker, + "topic": topic, + "groupId": group_id, + } + ready.set() + log.info("kafka.consumer_running", {"workflow_id": workflow_id, "topic": topic}) + + try: + async for msg in consumer: + if abort.is_set(): + break + raw_value = msg.value + queued = _QueuedKafkaMessage( + raw_value=raw_value, + size_bytes=len(raw_value) if raw_value is not None else 0, + ) + # Blocking put applies backpressure instead of dropping messages. + await queue.put(queued) + except asyncio.CancelledError: + raise + except Exception as exc: + self._status[workflow_id] = { + "state": "failed", + "error": str(exc), + "broker": broker, + "topic": topic, + "groupId": group_id, + } + log.error("kafka.consumer_error", {"workflow_id": workflow_id, "error": str(exc)}) + await self._cleanup_runtime_resources(workflow_id) + finally: + try: + await consumer.stop() + except Exception: + pass + current = asyncio.current_task() + if self._tasks.get(workflow_id) is current: + self._tasks.pop(workflow_id, None) + + async def _worker_loop( + self, + workflow_id: str, + workflow_json: Any, + input_key: str, + configured_inputs: Dict[str, Any], + queue: asyncio.Queue, + abort: asyncio.Event, + ) -> None: + while not abort.is_set(): + try: + msg = await asyncio.wait_for(queue.get(), timeout=0.5) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + return + try: + if isinstance(msg, _QueuedKafkaMessage): + msg = _decode_message(msg.raw_value) + await self._trigger_workflow( + workflow_id, workflow_json, msg, input_key, configured_inputs, + ) + except asyncio.CancelledError: + return + except Exception as exc: + log.warning( + "kafka.worker_dispatch_failed", + {"workflow_id": workflow_id, "error": str(exc)}, + ) + + async def _trigger_workflow( + self, + workflow_id: str, + workflow_json: Any, + message: Any, + input_key: str, + configured_inputs: Optional[Dict[str, Any]] = None, + ) -> None: + configured_inputs = _strip_execution_only_comments( + configured_inputs if isinstance(configured_inputs, dict) else {} + ) + inputs = {**configured_inputs, input_key: message} + input_params = {"_trigger": "kafka", input_key: _summarize_large_value(message)} + for key, value in configured_inputs.items(): + if key == input_key: + continue + input_params[key] = _summarize_large_value(value) + + exec_data = await create_execution_record( + workflow_id, + input_params=input_params, + ) + exec_id = exec_data["id"] + start_time = time.time() + + result = None + try: + result = await asyncio.to_thread( + run_workflow, + workflow=workflow_json, + inputs=inputs, + trace=False, + history_mode="summary", + ) + status, error_msg = resolve_execution_outcome(result) + duration = time.time() - start_time + exec_data.update({ + "status": status, + "outputResults": _compact_for_kafka_storage(result.outputs), + "finishedAt": int(time.time() * 1000), + "duration": duration, + "errorMessage": error_msg, + "executionLog": _compact_history_for_kafka_storage( + result.history, + input_key=input_key, + ), + "currentNodeId": result.last_node_id, + "currentPhase": status, + "currentStepIndex": result.steps, + }) + except Exception as exc: + duration = time.time() - start_time + log.error( + "kafka.workflow_run_failed", + {"workflow_id": workflow_id, "exec_id": exec_id, "error": str(exc)}, + ) + exec_data.update({ + "status": "error", + "errorMessage": str(exc), + "finishedAt": int(time.time() * 1000), + "duration": duration, + "currentPhase": "error", + }) + finally: + try: + await record_execution_result(workflow_id, exec_id, exec_data) + except Exception as exc: + log.warning("kafka.exec_record_failed", {"exec_id": exec_id, "error": str(exc)}) + + +default_manager = KafkaManager() diff --git a/flocks/input/dispatcher.py b/flocks/input/dispatcher.py index 1c71bde68..4d64b6ebf 100644 --- a/flocks/input/dispatcher.py +++ b/flocks/input/dispatcher.py @@ -106,16 +106,16 @@ async def _collect_text(text: str) -> None: async def _collect_prompt(prompt: str) -> None: llm_prompts.append(prompt) - # Pass only the optional callback, not sink.clear_screen: the latter is - # always a bound method (truthy) even when no _clear_screen was - # registered, which made /clear skip the "Screen cleared." fallback and - # publish an empty assistant message on WebUI. + # Pass only optional callbacks, not the bound methods on the sink: those + # are always truthy even when no concrete callback was registered. clear_cb = getattr(sink, "_clear_screen", None) + clear_history_cb = getattr(sink, "_clear_history", None) handled = await handle_slash_command( parsed.raw_text, send_text=_collect_text, send_prompt=_collect_prompt, clear_screen=clear_cb, + clear_history=clear_history_cb, surface=sink.surface, ) if handled: @@ -126,7 +126,8 @@ async def _collect_prompt(prompt: str) -> None: event.display_text or parsed.raw_text, ) return DispatchResult(action="llm", command_name=command_def.name) - await sink.publish_direct_response(event, "\n".join(direct_texts)) + if direct_texts: + await sink.publish_direct_response(event, "\n".join(direct_texts)) return DispatchResult(action="direct", command_name=command_def.name) await sink.run_llm(event, parsed.raw_text, event.display_text or parsed.raw_text) diff --git a/flocks/input/output.py b/flocks/input/output.py index fb40079ea..150a9e7a1 100644 --- a/flocks/input/output.py +++ b/flocks/input/output.py @@ -42,6 +42,9 @@ async def execute_session_control( async def clear_screen(self) -> None: return None + async def clear_history(self) -> None: + return None + class CallbackOutputSink(OutputSink): """Simple sink backed by async callbacks from each surface adapter.""" @@ -54,12 +57,14 @@ def __init__( run_llm: RunLlmCallback, session_control: Optional[SessionControlCallback] = None, clear_screen: Optional[SideEffectCallback] = None, + clear_history: Optional[SideEffectCallback] = None, ) -> None: super().__init__(surface) self._direct_response = direct_response self._run_llm = run_llm self._session_control = session_control self._clear_screen = clear_screen + self._clear_history = clear_history async def publish_direct_response(self, event: UserInputEvent, text: str) -> None: await self._direct_response(event, text) @@ -85,6 +90,10 @@ async def clear_screen(self) -> None: if self._clear_screen is not None: await self._clear_screen() + async def clear_history(self) -> None: + if self._clear_history is not None: + await self._clear_history() + class SSEOutputSink(CallbackOutputSink): """Output sink for SSE-backed surfaces such as WebUI/TUI/ACP.""" diff --git a/flocks/provider/catalog.json b/flocks/provider/catalog.json index 1000780d7..182654157 100644 --- a/flocks/provider/catalog.json +++ b/flocks/provider/catalog.json @@ -52,6 +52,30 @@ "THREATBOOK_CN_LLM_API_KEY" ], "models": { + "minimax-m3": { + "name": "minimax-m3", + "family": "minimax", + "capabilities": { + "supports_tools": true, + "supports_reasoning": true, + "interleaved": { + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, + "supports_streaming": true + }, + "limits": { + "context_window": 1000000, + "max_output_tokens": 128000 + }, + "pricing": { + "input": 4.2, + "output": 16.8, + "currency": "CNY", + "note": "≤512k input: ¥4.20/M tokens output: ¥16.80/M tokens; >512k input: ¥8.40/M tokens output: ¥33.60/M tokens" + } + }, "minimax-m2.7": { "name": "minimax-m2.7", "family": "minimax", @@ -203,8 +227,8 @@ "supports_streaming": true }, "limits": { - "context_window": 200000, - "max_output_tokens": 128000 + "context_window": 1000000, + "max_output_tokens": 384000 }, "pricing": { "input": 1.0, @@ -237,6 +261,30 @@ "THREATBOOK_IO_LLM_API_KEY" ], "models": { + "minimax-m3": { + "name": "minimax-m3", + "family": "minimax", + "capabilities": { + "supports_tools": true, + "supports_reasoning": true, + "interleaved": { + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, + "supports_streaming": true + }, + "limits": { + "context_window": 1000000, + "max_output_tokens": 128000 + }, + "pricing": { + "input": 4.2, + "output": 16.8, + "currency": "CNY", + "note": "≤512k input: ¥4.20/M tokens output: ¥16.80/M tokens; >512k input: ¥8.40/M tokens output: ¥33.60/M tokens" + } + }, "minimax-m2.7": { "name": "minimax-m2.7", "family": "minimax", @@ -361,8 +409,8 @@ "supports_streaming": true }, "limits": { - "context_window": 200000, - "max_output_tokens": 128000 + "context_window": 1000000, + "max_output_tokens": 384000 }, "pricing": { "input": 1.0, @@ -1344,6 +1392,30 @@ "MINIMAX_API_KEY" ], "models": { + "minimax-m3": { + "name": "MiniMax M3", + "family": "minimax", + "capabilities": { + "supports_tools": true, + "supports_reasoning": true, + "interleaved": { + "field": "reasoning_details", + "echo": "tool_calls", + "cross_provider_policy": "promote" + }, + "supports_streaming": true + }, + "limits": { + "context_window": 1000000, + "max_output_tokens": 128000 + }, + "pricing": { + "input": 4.2, + "output": 16.8, + "currency": "CNY", + "note": "≤512k input: ¥4.20/M tokens, output: ¥16.80/M tokens; >512k input: ¥8.40/M tokens, output: ¥33.60/M tokens" + } + }, "minimax-m2.7": { "name": "MiniMax M2.7", "family": "minimax-m2", diff --git a/flocks/provider/interleaved.py b/flocks/provider/interleaved.py index f62a0f8d2..aebadeec2 100644 --- a/flocks/provider/interleaved.py +++ b/flocks/provider/interleaved.py @@ -79,13 +79,6 @@ "trinity-large-thinking", ) -_PROMOTE_REASONING_DETAILS_TOKENS = ( - "minimax", - "gemini-3", - "gemini-3.1", -) - - def _lower(value: Optional[str]) -> str: return value.lower() if isinstance(value, str) else "" @@ -110,7 +103,7 @@ def infer_interleaved_capability( mid = _lower(model_id) burl = _lower(base_url) - if _matches_any(mid, *_PROMOTE_REASONING_DETAILS_TOKENS) or pid == "minimax": + if "minimax" in mid or pid == "minimax": return dict(_PROMOTE_REASONING_DETAILS) if "claude" in mid or "anthropic" in pid or "anthropic.com" in burl: diff --git a/flocks/provider/sdk/openai_compatible.py b/flocks/provider/sdk/openai_compatible.py index 391582be8..5e86d9592 100644 --- a/flocks/provider/sdk/openai_compatible.py +++ b/flocks/provider/sdk/openai_compatible.py @@ -42,10 +42,6 @@ class OpenAICompatibleProvider(BaseProvider): """OpenAI Compatible API provider""" - _MINIMAX_EMPTY_RESPONSE_TARGETS = { - "minimax-m2.5", - "minimax-m2.7", - } _MINIMAX_EMPTY_RESPONSE_RETRY_DELAY_SECONDS = 3 def __init__(self): @@ -143,7 +139,7 @@ def _normalize_model_id(cls, model_id: str) -> str: @classmethod def _is_minimax_empty_response_target(cls, model_id: str) -> bool: normalized = cls._normalize_model_id(model_id) - return any(target in normalized for target in cls._MINIMAX_EMPTY_RESPONSE_TARGETS) + return "minimax" in normalized @staticmethod def _has_non_empty_text_content(content: Any) -> bool: diff --git a/flocks/server/app.py b/flocks/server/app.py index bf8dba1da..402a75a9a 100644 --- a/flocks/server/app.py +++ b/flocks/server/app.py @@ -414,6 +414,42 @@ async def _delayed_syslog_start() -> None: except Exception as e: log.warning("syslog.manager.start_failed", {"error": str(e)}) + # Start Kafka consumers for workflows with kafka input enabled. + # Mirrors the syslog startup: a short delayed background task keeps the main + # startup path unblocked and avoids a crash-restart loop if a broker is down. + try: + from flocks.ingest.kafka.manager import default_manager as default_kafka_manager + + async def _delayed_kafka_start() -> None: + await asyncio.sleep(3) + try: + await default_kafka_manager.start_all() + log.info("kafka.manager.started") + except Exception as exc: + log.warning("kafka.manager.start_failed", {"error": str(exc)}) + + _schedule_startup_phase(app, log, "kafka.manager.start", _delayed_kafka_start) + except Exception as e: + log.warning("kafka.manager.start_failed", {"error": str(e)}) + + # Start workflow pollers for workflows with poller enabled. + # Mirrors Kafka/syslog startup so persistent slow-path workflows resume + # automatically without delaying server readiness. + try: + from flocks.workflow.poller_manager import default_manager as default_poller_manager + + async def _delayed_poller_start() -> None: + await asyncio.sleep(3) + try: + await default_poller_manager.start_all() + log.info("workflow.poller.started") + except Exception as exc: + log.warning("workflow.poller.start_failed", {"error": str(exc)}) + + _schedule_startup_phase(app, log, "workflow.poller.start", _delayed_poller_start) + except Exception as e: + log.warning("workflow.poller.start_failed", {"error": str(e)}) + try: from flocks.updater.updater import recover_upgrade_state @@ -486,6 +522,15 @@ async def _delayed_syslog_start() -> None: except Exception as e: log.warning("syslog.manager.stop_failed", {"error": str(e)}) + # Stop Kafka consumers + try: + from flocks.ingest.kafka.manager import default_manager as default_kafka_manager + + await default_kafka_manager.stop_all() + log.info("kafka.manager.stopped") + except Exception as e: + log.warning("kafka.manager.stop_failed", {"error": str(e)}) + # Stop Task Center try: from flocks.task.manager import TaskManager diff --git a/flocks/server/routes/agent.py b/flocks/server/routes/agent.py index cc7cc3f47..e8b30da39 100644 --- a/flocks/server/routes/agent.py +++ b/flocks/server/routes/agent.py @@ -27,6 +27,7 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel, Field +import flocks.agent.delegatable_settings as delegatable_settings from flocks.agent.registry import Agent from flocks.agent.agent import AgentInfo as AgentInfoModel, AgentModel as AgentModelConfig from flocks.agent.agent_factory import find_yaml_agent, read_yaml_agent, update_yaml_agent, delete_yaml_agent @@ -87,11 +88,16 @@ def agent_to_response( agent: AgentInfoModel, model_override: Optional[Dict[str, str]] = None, temperature_override: Optional[float] = None, + delegatable_override: Optional[bool] = None, skills: Optional[List[str]] = None, tools: Optional[List[str]] = None, ) -> AgentResponse: """Convert internal AgentInfo to API response format.""" - delegatable = agent.delegatable if agent.delegatable is not None else True + delegatable = ( + delegatable_override + if delegatable_override is not None + else agent.delegatable if agent.delegatable is not None else agent.mode != "primary" + ) if model_override: model_info = AgentModelInfo( @@ -134,6 +140,10 @@ def _agent_data_to_info(agent_data: Dict[str, Any]) -> AgentInfoModel: Used after create/update to keep ``_custom_agents`` in sync with Storage. """ model_data = agent_data.get("model") + mode = agent_data.get("mode", "primary") + delegatable = agent_data.get("delegatable") + if delegatable is None: + delegatable = mode != "primary" return AgentInfoModel( name=agent_data["name"], description=agent_data.get("description") or "", @@ -141,13 +151,14 @@ def _agent_data_to_info(agent_data: Dict[str, Any]) -> AgentInfoModel: prompt=agent_data.get("prompt") or "", temperature=agent_data.get("temperature"), color=agent_data.get("color"), - mode=agent_data.get("mode", "primary"), + mode=mode, model=AgentModelConfig( model_id=model_data["modelID"], provider_id=model_data["providerID"], ) if model_data else None, native=False, hidden=False, + delegatable=delegatable, ) @@ -155,6 +166,10 @@ def _custom_agent_data_to_response(agent_data: Dict[str, Any]) -> AgentResponse: """Build an AgentResponse from a custom agent's stored data dict.""" model_data = agent_data.get("model") model_info = AgentModelInfo(**model_data) if model_data else None + mode = agent_data.get("mode", "primary") + delegatable = agent_data.get("delegatable") + if delegatable is None: + delegatable = mode != "primary" return AgentResponse( name=agent_data["name"], description=agent_data.get("description"), @@ -162,12 +177,13 @@ def _custom_agent_data_to_response(agent_data: Dict[str, Any]) -> AgentResponse: prompt=agent_data.get("prompt"), temperature=agent_data.get("temperature"), color=agent_data.get("color"), - mode=agent_data.get("mode", "primary"), + mode=mode, model=model_info, native=agent_data.get("native", False), hidden=agent_data.get("hidden", False), permission=[], options={}, + delegatable=delegatable, skills=agent_data.get("skills", []), tools=agent_data.get("tools", []), tags=agent_data.get("tags", []), @@ -184,6 +200,10 @@ async def _load_model_overrides() -> Dict[str, Dict[str, Any]]: return {} +def _load_delegatable_overrides() -> Dict[str, bool]: + return delegatable_settings.load_overrides() + + async def _load_custom_agent_extras(name: str) -> tuple[List[str], List[str]]: """Load skills/tools list for an agent from storage. @@ -218,6 +238,7 @@ def _compute_native_agent_tools(agent: AgentInfoModel, all_tool_names: List[str] async def _build_single_agent_response( agent: AgentInfoModel, overrides: Dict[str, Dict[str, Any]], + delegatable_overrides: Dict[str, bool], all_tool_names: List[str], ) -> AgentResponse: """Build AgentResponse for one agent, resolving model overrides and tools/skills.""" @@ -233,6 +254,7 @@ async def _build_single_agent_response( agent, model_override=model_override, temperature_override=temperature_override, + delegatable_override=delegatable_overrides.get(agent.name), skills=skills, tools=tools, ) @@ -263,12 +285,13 @@ async def list_agents(): try: agents = await Agent.list() overrides = await _load_model_overrides() + delegatable_overrides = _load_delegatable_overrides() all_tool_names = _get_all_tool_names() result = [] for agent in agents: if agent.hidden: continue - result.append(await _build_single_agent_response(agent, overrides, all_tool_names)) + result.append(await _build_single_agent_response(agent, overrides, delegatable_overrides, all_tool_names)) return result except Exception as e: log.error("agent.list.error", {"error": str(e)}) @@ -283,8 +306,9 @@ async def get_agent(name: str): if not agent: raise HTTPException(status_code=404, detail=f"Agent {name} not found") overrides = await _load_model_overrides() + delegatable_overrides = _load_delegatable_overrides() all_tool_names = _get_all_tool_names() - return await _build_single_agent_response(agent, overrides, all_tool_names) + return await _build_single_agent_response(agent, overrides, delegatable_overrides, all_tool_names) except HTTPException: raise except Exception as e: @@ -321,6 +345,7 @@ class AgentCreateRequest(BaseModel): color: Optional[str] = Field(None, description="Color") mode: str = Field("primary", description="Agent mode") model: Optional[AgentModelInfo] = Field(None, description="Preferred model") + delegatable: Optional[bool] = Field(None, description="Whether this agent can be delegated to") skills: List[str] = Field(default_factory=list, description="Enabled skill names") tools: List[str] = Field(default_factory=list, description="Enabled tool names") @@ -333,6 +358,7 @@ class AgentUpdateRequest(BaseModel): temperature: Optional[float] = Field(None, description="Temperature") color: Optional[str] = Field(None, description="Color") model: Optional[AgentModelInfo] = Field(None, description="Preferred model") + delegatable: Optional[bool] = Field(None, description="Whether this agent can be delegated to") skills: Optional[List[str]] = Field(None, description="Enabled skill names") tools: Optional[List[str]] = Field(None, description="Enabled tool names") @@ -343,6 +369,11 @@ class AgentModelUpdateRequest(BaseModel): temperature: Optional[float] = Field(None, description="Temperature override for native agents") +class AgentDelegatableUpdateRequest(BaseModel): + """Request to update the delegatable toggle without rewriting YAML.""" + delegatable: bool = Field(..., description="Whether this agent can be delegated to") + + @router.post("", response_model=AgentResponse, summary="Create custom agent") async def create_agent(req: AgentCreateRequest): """ @@ -366,6 +397,7 @@ async def create_agent(req: AgentCreateRequest): "color": req.color, "mode": req.mode, "model": req.model.model_dump() if req.model else None, + "delegatable": req.delegatable if req.delegatable is not None else req.mode != "primary", "native": False, "hidden": False, "skills": req.skills, @@ -416,6 +448,9 @@ async def update_agent(name: str, req: AgentUpdateRequest): agent_data["color"] = req.color if req.model is not None: agent_data["model"] = req.model.model_dump() + if req.delegatable is not None: + agent_data["delegatable"] = req.delegatable + delegatable_settings.forget_override(name) if req.skills is not None: agent_data["skills"] = req.skills if req.tools is not None: @@ -445,6 +480,8 @@ async def update_agent(name: str, req: AgentUpdateRequest): updates["color"] = req.color if req.model is not None: updates["model"] = req.model.model_dump() + if req.delegatable is not None: + updates["delegatable"] = req.delegatable if not update_yaml_agent(name, updates): raise HTTPException(status_code=500, detail=f"Failed to write YAML for agent {name}") @@ -478,9 +515,12 @@ async def update_agent(name: str, req: AgentUpdateRequest): model_id=req.model.modelID, provider_id=req.model.providerID, ) + if req.delegatable is not None: + agent.delegatable = req.delegatable overrides = await _load_model_overrides() + delegatable_overrides = _load_delegatable_overrides() all_tool_names = _get_all_tool_names() - return await _build_single_agent_response(agent, overrides, all_tool_names) + return await _build_single_agent_response(agent, overrides, delegatable_overrides, all_tool_names) yaml_data = read_yaml_agent(name) or {} return _custom_agent_data_to_response(yaml_data) @@ -492,6 +532,55 @@ async def update_agent(name: str, req: AgentUpdateRequest): raise HTTPException(status_code=500, detail=str(e)) +@router.patch("/{name}/delegatable", response_model=AgentResponse, summary="Update agent delegatable toggle") +async def update_agent_delegatable(name: str, req: AgentDelegatableUpdateRequest): + """Update footer-toggle delegatable state without rewriting YAML sources.""" + from flocks.storage.storage import Storage + + try: + agent = await Agent.get(name) + if not agent: + raise HTTPException(status_code=404, detail=f"Agent {name} not found") + + agent_key = f"agent/custom/{name}" + agent_data: Optional[Dict[str, Any]] = None + try: + agent_data = await Storage.read(agent_key) + except Storage.NotFoundError: + pass + + if agent_data is not None and agent_data.get("name"): + agent_data["delegatable"] = req.delegatable + await Storage.write(agent_key, agent_data) + delegatable_settings.forget_override(name) + + from flocks.agent.registry import Agent as AgentRegistry + + AgentRegistry.register(name, _agent_data_to_info(agent_data)) + await AgentRegistry.refresh() + + log.info("agent.delegatable.updated", {"name": name, "source": "storage", "delegatable": req.delegatable}) + return _custom_agent_data_to_response(agent_data) + + delegatable_settings.set_override(name, req.delegatable) + await Agent.refresh() + agent = await Agent.get(name) + if not agent: + raise HTTPException(status_code=404, detail=f"Agent {name} not found") + + overrides = await _load_model_overrides() + delegatable_overrides = _load_delegatable_overrides() + all_tool_names = _get_all_tool_names() + + log.info("agent.delegatable.updated", {"name": name, "source": "override", "delegatable": req.delegatable}) + return await _build_single_agent_response(agent, overrides, delegatable_overrides, all_tool_names) + except HTTPException: + raise + except Exception as e: + log.error("agent.delegatable.update.error", {"error": str(e), "name": name}) + raise HTTPException(status_code=500, detail=str(e)) + + @router.delete("/{name}", summary="Delete custom agent") async def delete_agent(name: str): """ @@ -531,6 +620,7 @@ async def delete_agent(name: str): from flocks.hub import local as hub_local hub_local.remove_installed_record("agent", name) + delegatable_settings.forget_override(name) # Sync: remove from in-memory agent cache from flocks.agent.registry import Agent as AgentRegistry @@ -581,7 +671,12 @@ async def update_agent_model(name: str, req: AgentModelUpdateRequest): log.info("agent.model_override.saved", {"name": name, "model": req.model, "temperature": req.temperature}) override = overrides.get(name, {}) model_override = {k: override[k] for k in ("modelID", "providerID") if k in override} or None - return agent_to_response(agent, model_override=model_override, temperature_override=override.get("temperature")) + return agent_to_response( + agent, + model_override=model_override, + temperature_override=override.get("temperature"), + delegatable_override=_load_delegatable_overrides().get(name), + ) else: # --- Try Storage-based custom agent --- agent_key = f"agent/custom/{name}" @@ -619,8 +714,9 @@ async def update_agent_model(name: str, req: AgentModelUpdateRequest): log.info("agent.model.updated", {"name": name, "source": "yaml"}) overrides = await _load_model_overrides() + delegatable_overrides = _load_delegatable_overrides() all_tool_names = _get_all_tool_names() - return await _build_single_agent_response(agent, overrides, all_tool_names) + return await _build_single_agent_response(agent, overrides, delegatable_overrides, all_tool_names) raise HTTPException(status_code=404, detail=f"Custom agent {name} not found") except HTTPException: diff --git a/flocks/server/routes/console_upgrade.py b/flocks/server/routes/console_upgrade.py index f89d872d6..33324c722 100644 --- a/flocks/server/routes/console_upgrade.py +++ b/flocks/server/routes/console_upgrade.py @@ -478,7 +478,7 @@ async def _report_pro_bundle_installation( return marker = _read_pro_bundle_install_marker() payload = { - "license_id": record.get("activate_key"), + "license_id": _record_license_id(record), "fingerprint": console_session.get("fingerprint"), "install_id": console_session.get("install_id"), "installed_version": marker.get("installed_version") or details.get("auto_install_target") or details.get("auto_install_version") or "", @@ -896,13 +896,18 @@ async def _stream(): raw["updated_at"] = datetime.now(UTC).isoformat() await Storage.set(_request_key(request_id), raw, "json") await _report_pro_bundle_installation(raw, install_result="failed", error_message=progress.message) - elif progress.stage in {"done", "restarting"}: + elif progress.stage == "restarting": + details["auto_install_result"] = "restarting" + details["auto_install_message"] = progress.message + raw["updated_at"] = datetime.now(UTC).isoformat() + await Storage.set(_request_key(request_id), raw, "json") + elif progress.stage == "done": marker = _read_pro_bundle_install_marker() await _maybe_activate_pro_license(raw) await _maybe_refresh_pro_license(raw) capability = _record_pro_capability(details) if capability.get("pro_enabled"): - details["auto_install_result"] = "restarting" if progress.stage == "restarting" else "done" + details["auto_install_result"] = "done" else: details["auto_install_result"] = "license_inactive" details["auto_install_version"] = marker.get("installed_version") @@ -1008,4 +1013,3 @@ async def cancel_upgrade_request(request_id: str, request: Request) -> UpgradeRe raw["updated_at"] = datetime.now(UTC).isoformat() await Storage.set(_request_key(request_id), raw, "json") return UpgradeRequestStatus(**raw) - diff --git a/flocks/server/routes/device.py b/flocks/server/routes/device.py index f5785bffa..49ca185a9 100644 --- a/flocks/server/routes/device.py +++ b/flocks/server/routes/device.py @@ -30,15 +30,18 @@ from flocks.tool.device.store import ( create_group, delete_device_row, + delete_device_tool_setting, delete_group, fetch_device, get_group, group_exists, insert_device, + list_device_tool_settings, list_devices, list_groups, record_test_result, row_to_device, + set_device_tool_enabled, storage_key_to_service_id, update_device_row, update_group, @@ -220,9 +223,139 @@ async def route_delete_device(device_id: str): delete_secrets(device_id, db_fields) await delete_device_row(device_id) + # Per-device tool settings are cleaned up automatically via + # ON DELETE CASCADE on the device_tool_settings table. await sync_service_tool_state(service_id, deleted_storage_keys=[storage_key]) +# =========================================================================== +# Per-device tool settings routes +# =========================================================================== + +class DeviceToolInfo(BaseModel): + """Tool information with per-device enabled state.""" + name: str + description: str + description_cn: Optional[str] = None + enabled_global: bool = Field( + ..., + description="全局工具开关状态(影响所有同版本设备)", + ) + enabled_device: Optional[bool] = Field( + None, + description=( + "本设备的工具开关覆盖值。null 表示未设置覆盖,遵从全局状态;" + "true/false 表示该设备有独立的启用/禁用设置。" + ), + ) + enabled_effective: bool = Field( + ..., + description="最终生效状态(per-device 覆盖 > 全局 > 出厂默认)", + ) + + +class DeviceToolUpdateRequest(BaseModel): + enabled: bool = Field(..., description="启用或禁用此设备上的工具") + + +@router.get("/{device_id}/tools", response_model=List[DeviceToolInfo]) +async def route_list_device_tools(device_id: str): + """列出设备对应插件的所有工具,并附带该设备的独立开关状态。 + + 返回的 ``enabled_effective`` 字段反映实际执行时的生效状态: + - 若存在 per-device 覆盖(enabled_device 非 null),以它为准; + - 否则沿用全局 tool_settings(enabled_global)。 + """ + row = await fetch_device(device_id) + if row is None: + raise HTTPException( + status_code=http_status.HTTP_404_NOT_FOUND, detail="Device not found" + ) + + from flocks.tool.registry import ToolRegistry + + storage_key: str = row["storage_key"] + ToolRegistry.init() + + # Collect tools that belong to this device's plugin (matching provider). + device_tools = [ + t for t in ToolRegistry.list_tools() + if t.provider == storage_key and t.source == "device" + ] + + # Read per-device overrides once: {tool_name: enabled_bool}. + per_device = await list_device_tool_settings(device_id) + + result: List[DeviceToolInfo] = [] + for t in device_tools: + enabled_device: Optional[bool] = per_device.get(t.name) + + enabled_global = t.enabled + enabled_effective = ( + enabled_device if enabled_device is not None else enabled_global + ) + result.append( + DeviceToolInfo( + name=t.name, + description=t.description, + description_cn=t.description_cn, + enabled_global=enabled_global, + enabled_device=enabled_device, + enabled_effective=enabled_effective, + ) + ) + + return result + + +@router.patch("/{device_id}/tools/{tool_name}", response_model=DeviceToolInfo) +async def route_update_device_tool( + device_id: str, tool_name: str, body: DeviceToolUpdateRequest +): + """设置或清除某工具在指定设备上的独立开关。 + + - ``enabled=false`` → 仅在该设备上禁用工具,不影响同版本其他设备; + - ``enabled=true`` → 移除 per-device 覆盖,恢复遵从全局工具开关。 + """ + row = await fetch_device(device_id) + if row is None: + raise HTTPException( + status_code=http_status.HTTP_404_NOT_FOUND, detail="Device not found" + ) + + from flocks.tool.registry import ToolRegistry + + ToolRegistry.init() + storage_key: str = row["storage_key"] + tool = ToolRegistry.get(tool_name) + if tool is None or tool.info.provider != storage_key: + raise HTTPException( + status_code=http_status.HTTP_404_NOT_FOUND, + detail=f"Tool '{tool_name}' does not belong to this device", + ) + + if body.enabled: + # Removing the override restores global behaviour. + await delete_device_tool_setting(device_id, tool_name) + enabled_device = None + else: + await set_device_tool_enabled(device_id, tool_name, False) + enabled_device = False + + enabled_global = tool.info.enabled + enabled_effective = ( + enabled_device if enabled_device is not None else enabled_global + ) + return DeviceToolInfo( + name=tool_name, + description=tool.info.description, + description_cn=tool.info.description_cn, + enabled_global=enabled_global, + enabled_device=enabled_device, + enabled_effective=enabled_effective, + ) + + class DeviceTestRequest(BaseModel): """Optional body for ``POST /devices/{id}/test``. diff --git a/flocks/server/routes/session.py b/flocks/server/routes/session.py index 47038fbb4..23eac9bac 100644 --- a/flocks/server/routes/session.py +++ b/flocks/server/routes/session.py @@ -685,13 +685,8 @@ async def unshare_session_local(sessionID: str, http_request: Request) -> Sessio # Session Actions # ============================================================================= -@router.post( - "/{sessionID}/abort", - summary="Abort session", - description="Abort an active session and stop any ongoing processing", -) -async def abort_session(sessionID: str, http_request: Request) -> bool: - """Abort session processing. +async def _abort_session_processing(sessionID: str) -> bool: + """Abort active processing for a session and notify subscribers. Aborts both the SessionLoop (sets abort_event so the next step check stops the loop) and the SessionRunner (stops the current LLM stream). @@ -705,15 +700,6 @@ async def abort_session(sessionID: str, http_request: Request) -> bool: from flocks.session.session_loop import SessionLoop from flocks.server.routes.question import reject_session_questions - current_user = require_user(http_request) - session = await _get_session_by_id_unfiltered(sessionID) - if not session: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"Session {sessionID} not found", - ) - _require_session_write_access(session, current_user) - # Abort the loop-level context (propagates to runner via shared abort_event) loop_aborted = SessionLoop.abort(sessionID) @@ -757,6 +743,25 @@ async def abort_session(sessionID: str, http_request: Request) -> bool: return True +@router.post( + "/{sessionID}/abort", + summary="Abort session", + description="Abort an active session and stop any ongoing processing", +) +async def abort_session(sessionID: str, http_request: Request = None) -> bool: + """Abort session processing.""" + if http_request is not None: + current_user = require_user(http_request) + session = await _get_session_by_id_unfiltered(sessionID) + if not session: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Session {sessionID} not found", + ) + _require_session_write_access(session, current_user) + return await _abort_session_processing(sessionID) + + class ForkRequest(BaseModel): """Request to fork session""" messageID: Optional[str] = Field(None, description="Message ID to fork up to") @@ -2427,7 +2432,7 @@ def _materialize_data_url_to_disk( _, _, tail = filename_hint.rpartition(".") if tail.lower() in _UPLOAD_SAFE_EXTS: ext = "." + tail.lower() - unique_name = f"{Identifier.create('upload')}{ext}" + unique_name = f"{Identifier.create('part')}{ext}" target = uploads_root / unique_name target.write_bytes(raw_bytes) return f"file://{target.resolve()}" @@ -2828,6 +2833,220 @@ def _coerce_model_for_prompt_request(model: Any): return model +def _prompt_queue_lock(session_id: str) -> asyncio.Lock: + if not hasattr(router, "_prompt_queue_drain_locks"): + router._prompt_queue_drain_locks = {} + locks = router._prompt_queue_drain_locks + lock = locks.get(session_id) + if lock is None: + lock = asyncio.Lock() + locks[session_id] = lock + return lock + + +def _is_prompt_chain_active(session_id: str) -> bool: + return session_id in getattr(router, "_prompt_queue_active_sessions", set()) + + +def _set_prompt_chain_active(session_id: str, active: bool) -> None: + if not hasattr(router, "_prompt_queue_active_sessions"): + router._prompt_queue_active_sessions = set() + active_sessions = router._prompt_queue_active_sessions + if active: + active_sessions.add(session_id) + else: + active_sessions.discard(session_id) + + +async def _publish_prompt_queue(session_id: str) -> None: + from flocks.server.routes.event import publish_event + from flocks.session.interaction_queue import InteractionQueue + + items = await InteractionQueue.list(session_id) + await publish_event("session.prompt_queue.updated", { + "sessionID": session_id, + "items": [item.model_dump() for item in items], + }) + + +def _materialize_queued_parts(session_id: str, parts: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Persist queued data URLs so large base64 payloads do not sit in memory.""" + prepared: List[Dict[str, Any]] = [] + for part in parts: + next_part = dict(part) + url = next_part.get("url") + if next_part.get("type") == "file" and isinstance(url, str) and url.startswith("data:"): + mime = next_part.get("mime") or "" + filename = next_part.get("filename") + next_part["url"] = _materialize_data_url_part(session_id, url, mime, filename) + prepared.append(next_part) + return prepared + + +def _materialize_data_url_part( + session_id: str, + data_url: str, + mime_hint: str, + filename_hint: Optional[str], +) -> str: + try: + import base64 + from flocks.workspace.manager import WorkspaceManager + from flocks.utils.id import Identifier + + _header, _sep, encoded = data_url.partition(",") + if not encoded: + return data_url + raw_bytes = base64.b64decode(encoded) + ws = WorkspaceManager.get_instance() + uploads_root = ws.resolve_workspace_path(f"uploads/{session_id}") + uploads_root.mkdir(parents=True, exist_ok=True) + + ext_map = { + "image/png": ".png", + "image/jpeg": ".jpg", + "image/jpg": ".jpg", + "image/gif": ".gif", + "image/webp": ".webp", + "image/bmp": ".bmp", + "application/pdf": ".pdf", + } + ext = ext_map.get(mime_hint, "") + if not ext and filename_hint: + _, _, tail = filename_hint.rpartition(".") + if tail.lower() in _UPLOAD_SAFE_EXTS: + ext = "." + tail.lower() + target = uploads_root / f"{Identifier.create('part')}{ext}" + target.write_bytes(raw_bytes) + return f"file://{target.resolve()}" + except Exception as exc: + log.warn("session.prompt_queue.materialize_failed", { + "sessionID": session_id, + "error": str(exc), + }) + return data_url + + +def _event_from_queued_prompt(item, working_directory: str): + from flocks.input.events import UserInputEvent + + return UserInputEvent( + source_type="webui", + sessionID=item.sessionID, + text=_extract_text_from_parts(item.parts), + parts=[dict(part) for part in item.parts], + agent=item.agent, + model=item.model, + variant=item.variant, + display_text=None, + messageID=item.messageID, + noReply=item.noReply, + mockReply=item.mockReply, + tools=item.tools, + system=item.system, + working_directory=working_directory, + ) + + +async def _drain_prompt_queue_locked(session_id: str, working_directory: str) -> bool: + from flocks.project.bootstrap import instance_bootstrap + from flocks.project.instance import Instance + from flocks.session.interaction_queue import InteractionQueue + from flocks.session.session_loop import SessionLoop + + while True: + if SessionLoop.is_running(session_id): + return False + + item = await InteractionQueue.pop_next(session_id) + if item is None: + await _publish_prompt_queue(session_id) + return True + + await _publish_prompt_queue(session_id) + session = await Session.get_by_id(session_id) + if not session: + log.warn("session.prompt_queue.session_missing", {"sessionID": session_id, "queueID": item.id}) + continue + + event = _event_from_queued_prompt(item, working_directory) + log.info("session.prompt_queue.dispatch", { + "sessionID": session_id, + "queueID": item.id, + }) + await Instance.provide( + directory=working_directory, + init=instance_bootstrap, + fn=lambda: _dispatch_sse_input(session_id, session, event, working_directory), + ) + + +async def _run_prompt_event_chain(session_id: str, session, event, working_directory: str) -> None: + from flocks.project.bootstrap import instance_bootstrap + from flocks.project.instance import Instance + + try: + async with _prompt_queue_lock(session_id): + dispatch_failed = False + try: + await Instance.provide( + directory=working_directory, + init=instance_bootstrap, + fn=lambda: _dispatch_sse_input(session_id, session, event, working_directory), + ) + except Exception: + dispatch_failed = True + raise + finally: + try: + await _drain_prompt_queue_locked(session_id, working_directory) + except Exception as drain_exc: + if dispatch_failed: + log.error("session.prompt_queue.drain_after_error_failed", { + "sessionID": session_id, + "error": str(drain_exc), + }) + else: + raise + finally: + _set_prompt_chain_active(session_id, False) + + +async def _schedule_prompt_queue_drain(session_id: str, working_directory: str) -> None: + max_attempts = 80 + retry_interval_s = 0.25 + + async def _run() -> None: + try: + async with _prompt_queue_lock(session_id): + for attempt in range(max_attempts): + completed = await _drain_prompt_queue_locked(session_id, working_directory) + if completed: + return + await asyncio.sleep(retry_interval_s) + log.warn("session.prompt_queue.drain_retry_exhausted", { + "sessionID": session_id, + "attempts": max_attempts, + }) + finally: + _set_prompt_chain_active(session_id, False) + + _set_prompt_chain_active(session_id, True) + _schedule_background_coro( + _run(), + session_id=session_id, + action="prompt_queue.drain", + ) + + +async def _wait_for_session_idle(session_id: str, timeout_s: float = 5.0) -> None: + from flocks.session.session_loop import SessionLoop + + deadline = time.time() + timeout_s + while SessionLoop.is_running(session_id) and time.time() < deadline: + await asyncio.sleep(0.05) + + def _build_prompt_request_from_event(event, prompt_text: str, display_text: Optional[str] = None): import types @@ -2952,6 +3171,9 @@ async def _run_llm(output_event, prompt_text: str, display_text: Optional[str] = request = _build_prompt_request_from_event(output_event, prompt_text, display_text) await _process_session_message(sessionID, session, request, working_directory) + async def _clear_history() -> None: + await _clear_session_history(sessionID) + async def _run_session_control(output_event, parsed) -> bool: if parsed.canonical_name != "compact": return False @@ -2987,10 +3209,153 @@ async def _run_session_control(output_event, parsed) -> bool: direct_response=_publish_direct_response, run_llm=_run_llm, session_control=_run_session_control, + clear_history=_clear_history, ) await dispatch_user_input(event, sink) +class PromptQueueUpdateRequest(BaseModel): + text: str = Field(..., description="Updated queued prompt text") + + +async def _enqueue_prompt_request( + session_id: str, + request: PromptRequest, +): + from flocks.session.interaction_queue import InteractionQueue + + model = request.model.model_dump(by_alias=True) if request.model else None + parts = _materialize_queued_parts(session_id, [dict(part) for part in request.parts]) + return await InteractionQueue.enqueue( + session_id, + parts=parts, + agent=request.agent, + model=model, + variant=request.variant, + message_id=request.messageID, + no_reply=request.noReply, + mock_reply=request.mockReply, + tools=request.tools, + system=request.system, + ) + + +@router.get( + "/{sessionID}/prompt_queue", + summary="List queued prompts", + description="List pending non-blocking prompts for a session", +) +async def list_prompt_queue(sessionID: str) -> Dict[str, Any]: + from flocks.session.interaction_queue import InteractionQueue + + session = await Session.get_by_id(sessionID) + if not session: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Session {sessionID} not found", + ) + items = await InteractionQueue.list(sessionID) + return {"sessionID": sessionID, "items": [item.model_dump() for item in items]} + + +@router.post( + "/{sessionID}/prompt_queue", + status_code=status.HTTP_202_ACCEPTED, + summary="Queue prompt", + description="Queue a prompt without writing it to the formal message history", +) +async def enqueue_prompt(sessionID: str, request: PromptRequest) -> Dict[str, Any]: + from flocks.session.interaction_queue import QueueFullError + + session = await Session.get_by_id(sessionID) + if not session: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Session {sessionID} not found", + ) + try: + item = await _enqueue_prompt_request(sessionID, request) + except QueueFullError as exc: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc + await _publish_prompt_queue(sessionID) + return {"status": "queued", "sessionID": sessionID, "queueID": item.id} + + +@router.patch( + "/{sessionID}/prompt_queue/{queueID}", + summary="Update queued prompt", + description="Update the text part of a queued prompt", +) +async def update_prompt_queue_item( + sessionID: str, + queueID: str, + request: PromptQueueUpdateRequest, +) -> Dict[str, Any]: + from flocks.session.interaction_queue import InteractionQueue, QueueItemNotFoundError + + text = request.text.strip() + if not text: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Queued prompt text cannot be empty", + ) + try: + item = await InteractionQueue.update_text(sessionID, queueID, text) + except QueueItemNotFoundError as exc: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + await _publish_prompt_queue(sessionID) + return {"status": "updated", "sessionID": sessionID, "item": item.model_dump()} + + +@router.delete( + "/{sessionID}/prompt_queue/{queueID}", + summary="Remove queued prompt", + description="Remove a queued prompt before it executes", +) +async def remove_prompt_queue_item(sessionID: str, queueID: str) -> Dict[str, Any]: + from flocks.session.interaction_queue import InteractionQueue, QueueItemNotFoundError + + try: + await InteractionQueue.remove(sessionID, queueID) + except QueueItemNotFoundError as exc: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + await _publish_prompt_queue(sessionID) + return {"status": "removed", "sessionID": sessionID, "queueID": queueID} + + +@router.post( + "/{sessionID}/prompt_queue/{queueID}/run_now", + status_code=status.HTTP_202_ACCEPTED, + summary="Run queued prompt now", + description="Abort the current prompt and run the selected queued prompt next", +) +async def run_prompt_queue_item_now(sessionID: str, queueID: str) -> Dict[str, Any]: + import os + + from flocks.session.interaction_queue import InteractionQueue, QueueItemNotFoundError + from flocks.session.session_loop import SessionLoop + + session = await Session.get_by_id(sessionID) + if not session: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Session {sessionID} not found", + ) + working_directory = session.directory or os.getcwd() + try: + await InteractionQueue.promote(sessionID, queueID) + except QueueItemNotFoundError as exc: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + await _publish_prompt_queue(sessionID) + + if SessionLoop.is_running(sessionID): + await abort_session(sessionID) + await _wait_for_session_idle(sessionID) + + await _schedule_prompt_queue_drain(sessionID, working_directory) + return {"status": "accepted", "sessionID": sessionID, "queueID": queueID} + + @router.post( "/{sessionID}/prompt_async", status_code=status.HTTP_202_ACCEPTED, @@ -3005,6 +3370,8 @@ async def send_session_message_async( """Send message asynchronously - returns 202 immediately, response via SSE""" import os from flocks.input.events import UserInputEvent + from flocks.session.interaction_queue import InteractionQueue, QueueFullError + from flocks.session.session_loop import SessionLoop session = await _get_session_by_id_unfiltered(sessionID) if not session: @@ -3038,61 +3405,24 @@ async def send_session_message_async( system=request.system, working_directory=working_directory, ) - - # Use the same synchronous processing path as send_session_message - # but run it as a background task via asyncio.ensure_future - import asyncio - - async def _run_in_background(): - import traceback - import sys + + existing_queue = await InteractionQueue.list(sessionID) + if SessionLoop.is_running(sessionID) or existing_queue or _is_prompt_chain_active(sessionID): try: - log.info("session.prompt_async.processing_start", { - "sessionID": sessionID, - }) - - from flocks.project.instance import Instance - from flocks.project.bootstrap import instance_bootstrap - - await Instance.provide( - directory=working_directory, - init=instance_bootstrap, - fn=lambda: _dispatch_sse_input(sessionID, session, event, working_directory), - ) - log.info("session.prompt_async.processing_complete", { - "sessionID": sessionID, - }) - except Exception as e: - tb = traceback.format_exc() - log.error("session.prompt_async.error", { - "sessionID": sessionID, - "error": str(e), - "error_type": type(e).__name__, - }) - print(f"[prompt_async ERROR] {sessionID}: {e}\n{tb}", file=sys.stderr, flush=True) - # Clear session busy status on error - try: - from flocks.session.core.status import SessionStatus - SessionStatus.clear(sessionID) - except Exception: - pass - # Publish error event so frontend gets notified - from flocks.server.routes.event import publish_event - error_msg = str(e) - await publish_event("session.error", { - "sessionID": sessionID, - "error": {"name": type(e).__name__, "message": error_msg, "data": {"message": error_msg}}, - }) - - # Schedule as asyncio task with explicit reference tracking - loop = asyncio.get_running_loop() - task = loop.create_task(_run_in_background()) - # Store reference on the app state to prevent GC - if not hasattr(router, '_pending_tasks'): - router._pending_tasks = set() - router._pending_tasks.add(task) - task.add_done_callback(lambda t: router._pending_tasks.discard(t)) - + item = await _enqueue_prompt_request(sessionID, request) + except QueueFullError as exc: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc + await _publish_prompt_queue(sessionID) + if not SessionLoop.is_running(sessionID): + await _schedule_prompt_queue_drain(sessionID, working_directory) + return {"status": "queued", "sessionID": sessionID, "queueID": item.id} + + _set_prompt_chain_active(sessionID, True) + _schedule_background_coro( + _run_prompt_event_chain(sessionID, session, event, working_directory), + session_id=sessionID, + action="session.prompt_async", + ) return {"status": "accepted", "sessionID": sessionID} @@ -3115,13 +3445,14 @@ class CommandRequest(BaseModel): summary="Send command", description="Execute a slash command in the session (returns 202, result via SSE)", ) -async def send_session_command(sessionID: str, request: CommandRequest, http_request: Request): +async def send_session_command(sessionID: str, request: CommandRequest, http_request: Request = None): """ Execute a slash command. - Direct commands (/tools, /skills, /help, /mcp, /clear) are handled - without calling the LLM. Their output is pushed as an assistant message - directly via SSE. + Direct commands (/tools, /skills, /help, /mcp) are handled without calling + the LLM. Their output is pushed as an assistant message directly via SSE. + Side-effecting direct commands like /clear run without creating a chat + message and instead update session state via callbacks. LLM-based commands (/plan, /ask, /init, /compact, ...) are routed through the normal session-loop pipeline. @@ -3142,8 +3473,9 @@ async def send_session_command(sessionID: str, request: CommandRequest, http_req status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {sessionID} not found", ) - current_user = require_user(http_request) - _require_session_write_access(session, current_user) + if http_request is not None: + current_user = require_user(http_request) + _require_session_write_access(session, current_user) working_directory = session.directory or os.getcwd() @@ -3452,6 +3784,41 @@ async def get_session_statistics(sessionID: str): raise HTTPException(status_code=500, detail=f"Failed to get session statistics: {str(e)}") +async def _clear_session_history(sessionID: str) -> int: + """Clear stored messages for a session and notify subscribed UIs.""" + session_info = await _get_session_by_id_unfiltered(sessionID) + if not session_info: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Session {sessionID} not found", + ) + + from flocks.server.routes.event import publish_event + from flocks.session.interaction_queue import InteractionQueue + from flocks.session.message import Message + + await abort_session(sessionID) + await InteractionQueue.clear(sessionID) + try: + await _publish_prompt_queue(sessionID) + except Exception as exc: + log.warn("session.clear.prompt_queue_event_error", {"sessionID": sessionID, "error": str(exc)}) + await _wait_for_session_idle(sessionID) + + deleted_count = await Message.clear(sessionID) + log.info("session.cleared", {"sessionID": sessionID, "deleted": deleted_count}) + + try: + await publish_event("session.cleared", { + "sessionID": sessionID, + "deletedMessages": deleted_count, + }) + except Exception as exc: + log.warn("session.clear.event_error", {"sessionID": sessionID, "error": str(exc)}) + + return deleted_count + + @router.post("/{sessionID}/clear") async def clear_session(sessionID: str, http_request: Request): """ @@ -3460,7 +3827,6 @@ async def clear_session(sessionID: str, http_request: Request): Removes all messages from the session while keeping the session itself. """ try: - # Verify session exists session_info = await _get_session_by_id_unfiltered(sessionID) if not session_info: raise HTTPException( @@ -3470,16 +3836,14 @@ async def clear_session(sessionID: str, http_request: Request): current_user = require_user(http_request) _require_session_write_access(session_info, current_user) - # Use Message.clear which handles bulk deletion atomically - from flocks.session.message import Message - deleted_count = await Message.clear(sessionID) - - log.info("session.cleared", {"sessionID": sessionID, "deleted": deleted_count}) + deleted_count = await _clear_session_history(sessionID) return { "status": "success", "sessionID": sessionID, "deletedMessages": deleted_count, } + except HTTPException: + raise except Exception as e: log.error("session.clear.error", {"sessionID": sessionID, "error": str(e)}) raise HTTPException(status_code=500, detail=f"Failed to clear session: {str(e)}") diff --git a/flocks/server/routes/tool.py b/flocks/server/routes/tool.py index 9e1bb65a3..6a4590add 100644 --- a/flocks/server/routes/tool.py +++ b/flocks/server/routes/tool.py @@ -5,7 +5,7 @@ import asyncio import time from typing import List, Optional, Dict, Any -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel, Field from flocks.server.auth import require_admin @@ -506,18 +506,34 @@ async def get_tool(tool_name: str): response_model=ToolInfoResponse, summary="Update tool settings", ) -async def update_tool(tool_name: str, request: ToolUpdateRequest, _admin: object = Depends(require_admin)): +async def update_tool( + tool_name: str, + request: ToolUpdateRequest, + device_id: Optional[str] = Query( + None, + description=( + "设备实例 UUID。提供时仅修改该设备的工具开关(per-device 覆盖)," + "不影响其他同版本设备;省略时修改全局 tool_settings(影响所有设备)。" + ), + ), + _admin: object = Depends(require_admin), +): """ Update tool settings (e.g., enable or disable). - The ``enabled`` flag is persisted to the user-level overlay in - ``flocks.json`` (``tool_settings..enabled``) instead of - mutating the YAML plugin file. This keeps project-level YAML files - (which may be tracked by git and overwritten on upgrade) clean and - treats the YAML's ``enabled:`` field as the factory default that the - overlay can selectively customise. + **Global mode** (``device_id`` omitted): + Persists to ``flocks.json`` → ``tool_settings..enabled``. + Affects all device instances that share this tool. + + **Per-device mode** (``device_id`` provided): + Persists to the SQLite ``device_tool_settings`` table (one row per + device_id × tool_name). Only affects tool execution when ``device_id`` + is explicitly targeted, allowing Device A and Device B (same plugin + version, different names) to carry independent tool enabled/disabled + states. Rows are removed automatically via ON DELETE CASCADE when the + parent device row is deleted. - Two behaviours of note: + Two behaviours of note (global mode only): * If ``request.enabled`` matches the registration-time default we *delete* the overlay entry instead of writing one — the tool is @@ -538,6 +554,34 @@ async def update_tool(tool_name: str, request: ToolUpdateRequest, _admin: object ) desired = bool(request.enabled) + + # --- Per-device mode --- + if device_id: + from flocks.tool.device.store import ( + delete_device_tool_setting, + set_device_tool_enabled, + ) + if desired: + # "Enable" in per-device mode means removing the per-device + # override so the global/factory default takes effect again. + removed = await delete_device_tool_setting(device_id, tool_name) + log.info("tool.device.updated.reset_to_global", { + "name": tool_name, + "device_id": device_id, + "removed_override": removed, + }) + else: + await set_device_tool_enabled(device_id, tool_name, False) + log.info("tool.device.updated", { + "name": tool_name, + "device_id": device_id, + "enabled": False, + }) + # The in-memory ToolInfo.enabled is NOT changed; it reflects global + # state. Per-device gating happens at ToolRegistry.execute time. + return _build_tool_response(tool.info) + + # --- Global mode (original behaviour) --- default = _get_default_enabled(tool.info) # Service gate: only matters when the user is trying to enable. # Disabling is always honoured. diff --git a/flocks/server/routes/update.py b/flocks/server/routes/update.py index 40ebb45f9..9836dfe0b 100644 --- a/flocks/server/routes/update.py +++ b/flocks/server/routes/update.py @@ -24,6 +24,7 @@ summary="Check for new version", ) async def check_version( + request: Request, locale: str | None = Query( default=None, description="Optional UI locale hint used to choose region-appropriate upgrade mirrors.", @@ -33,6 +34,8 @@ async def check_version( description="Version channel to check. flockspro checks the Console Pro bundle manifest.", ), ) -> VersionInfo: + if edition == "flockspro": + require_admin(request) return await check_update(locale=locale, force_console_manifest=(edition == "flockspro")) diff --git a/flocks/server/routes/workflow.py b/flocks/server/routes/workflow.py index e6ec3f048..d36bc67af 100644 --- a/flocks/server/routes/workflow.py +++ b/flocks/server/routes/workflow.py @@ -42,6 +42,7 @@ read_workflow_from_fs as shared_read_workflow_from_fs, workflow_scan_dirs as _all_scan_dirs, ) +from flocks.ingest.kafka.constants import WORKFLOW_KAFKA_CONFIG_PREFIX from flocks.ingest.syslog.constants import WORKFLOW_SYSLOG_CONFIG_PREFIX from flocks.workflow.execution_store import ( compact_history_for_storage, @@ -783,6 +784,17 @@ async def delete_workflow(workflow_id: str): except Storage.NotFoundError: pass + try: + from flocks.ingest.kafka.manager import default_manager as _kafka_default_manager + + await _kafka_default_manager.stop_workflow(workflow_id) + except Exception: + pass + try: + await Storage.remove(_kafka_config_key(workflow_id)) + except Storage.NotFoundError: + pass + log.info("workflow.deleted", {"id": workflow_id}) await publish_event("workflow.deleted", {"id": workflow_id}) return None @@ -1348,7 +1360,7 @@ async def export_workflow(workflow_id: str): # ============================================================================= _API_SERVICE_PREFIX = "workflow_api_service/" -_KAFKA_CONFIG_PREFIX = "workflow_kafka_config/" +_KAFKA_CONFIG_PREFIX = WORKFLOW_KAFKA_CONFIG_PREFIX _REGISTRY_PREFIX_MAIN = "workflow_registry/" @@ -1373,11 +1385,37 @@ class WorkflowServiceResponse(BaseModel): class KafkaConfigRequest(BaseModel): + """Per-workflow Kafka consumer configuration.""" + + enabled: bool = False inputBroker: Optional[str] = None inputTopic: Optional[str] = None inputGroupId: Optional[str] = None - outputBroker: Optional[str] = None - outputTopic: Optional[str] = None + inputKey: str = "kafka_message" + autoOffsetReset: str = "latest" + inputs: Dict[str, Any] = Field(default_factory=dict) + + +def _strip_execution_only_comments(value: Any) -> Any: + if isinstance(value, list): + return [_strip_execution_only_comments(item) for item in value] + if not isinstance(value, dict): + return value + return { + key: _strip_execution_only_comments(nested) + for key, nested in value.items() + if not str(key).startswith("_comment") + } + + +class WorkflowPollerConfigRequest(BaseModel): + """Per-workflow background poller configuration.""" + + enabled: bool = False + intervalSeconds: int = Field(30, ge=1) + timeoutSeconds: int = Field(7200, ge=1) + noOverlap: bool = True + inputs: Dict[str, Any] = Field(default_factory=dict) class SyslogConfigRequest(BaseModel): @@ -1547,8 +1585,12 @@ async def list_workflow_services(): @router.post("/workflow/{workflow_id}/kafka-config") async def save_kafka_config(workflow_id: str, req: KafkaConfigRequest): """ - Save Kafka input/output configuration for a workflow. - (Kafka integration is experimental; this stores config for future use.) + Save Kafka input configuration for a workflow. + + When ``enabled`` is true this also (re)starts the Kafka consumer and blocks + until it has either connected to the broker or failed. Connection failures + are surfaced as ``409 Conflict`` so the UI can show an actionable error + instead of falsely claiming the consumer is running. """ try: if not _read_workflow_from_fs(workflow_id): @@ -1556,15 +1598,28 @@ async def save_kafka_config(workflow_id: str, req: KafkaConfigRequest): config = { "workflowId": workflow_id, + "enabled": req.enabled, "inputBroker": req.inputBroker, "inputTopic": req.inputTopic, "inputGroupId": req.inputGroupId, - "outputBroker": req.outputBroker, - "outputTopic": req.outputTopic, + "inputKey": req.inputKey, + "autoOffsetReset": req.autoOffsetReset, + "inputs": _strip_execution_only_comments(req.inputs), "updatedAt": int(time.time() * 1000), } await Storage.write(_kafka_config_key(workflow_id), config) - return {"ok": True} + + from flocks.ingest.kafka.manager import default_manager as _kafka_default_manager + + status = await _kafka_default_manager.restart_workflow(workflow_id) + state = (status or {}).get("state") + if req.enabled and state == "failed": + err = (status or {}).get("error") or "consumer_connect_failed" + raise HTTPException( + status_code=409, + detail=f"Kafka consumer failed to start: {err}", + ) + return {"ok": True, "consumer": status} except HTTPException: raise except Exception as e: @@ -1585,6 +1640,98 @@ async def get_kafka_config(workflow_id: str): raise HTTPException(status_code=500, detail=f"Failed to get Kafka config: {str(e)}") +@router.get("/workflow/{workflow_id}/kafka-status") +async def get_kafka_status(workflow_id: str): + """Return the *runtime* status of the Kafka consumer for a workflow. + + Reflects the actual connection state (connecting/running/failed/stopped) and + queue depth so the UI can show whether a saved-but-not-yet-connected + consumer is actually running. The persisted config only captures *intent*. + """ + try: + from flocks.ingest.kafka.manager import default_manager as _kafka_default_manager + + return _kafka_default_manager.get_consumer_status(workflow_id) + except Exception as e: + log.error("workflow.kafka_status.get.error", {"id": workflow_id, "error": str(e)}) + raise HTTPException(status_code=500, detail=f"Failed to get Kafka status: {str(e)}") + + +@router.post("/workflow/{workflow_id}/poller-config") +async def save_workflow_poller_config(workflow_id: str, req: WorkflowPollerConfigRequest): + """Save background poller configuration for a workflow.""" + try: + if not _read_workflow_from_fs(workflow_id): + raise HTTPException(status_code=404, detail=f"Workflow not found: {workflow_id}") + + config = { + "workflowId": workflow_id, + "enabled": req.enabled, + "intervalSeconds": req.intervalSeconds, + "timeoutSeconds": req.timeoutSeconds, + "noOverlap": req.noOverlap, + "inputs": req.inputs, + "updatedAt": int(time.time() * 1000), + } + await Storage.write(f"workflow_poller_config/{workflow_id}", config) + + from flocks.workflow.poller_manager import default_manager as _poller_default_manager + + poller_status = await _poller_default_manager.restart_workflow(workflow_id) + if req.enabled and (poller_status or {}).get("state") == "failed": + err = (poller_status or {}).get("error") or "poller_start_failed" + raise HTTPException( + status_code=409, + detail=f"Workflow poller failed to start: {err}", + ) + return {"ok": True, "status": poller_status} + except HTTPException: + raise + except Exception as e: + log.error("workflow.poller_config.save.error", {"id": workflow_id, "error": str(e)}) + raise HTTPException(status_code=500, detail=f"Failed to save poller config: {str(e)}") + + +@router.get("/workflow/{workflow_id}/poller-config") +async def get_workflow_poller_config(workflow_id: str): + """Get saved poller configuration for a workflow.""" + try: + return await Storage.read(f"workflow_poller_config/{workflow_id}") + except Exception as e: + log.error("workflow.poller_config.get.error", {"id": workflow_id, "error": str(e)}) + raise HTTPException(status_code=500, detail=f"Failed to get poller config: {str(e)}") + + +@router.get("/workflow/{workflow_id}/poller-status") +async def get_workflow_poller_status(workflow_id: str): + """Return the runtime status of a workflow poller.""" + try: + from flocks.workflow.poller_manager import default_manager as _poller_default_manager + + return _poller_default_manager.get_status(workflow_id) + except Exception as e: + log.error("workflow.poller_status.get.error", {"id": workflow_id, "error": str(e)}) + raise HTTPException(status_code=500, detail=f"Failed to get poller status: {str(e)}") + + +@router.post("/workflow/{workflow_id}/poller-run-once") +async def run_workflow_poller_once(workflow_id: str): + """Trigger one immediate poller execution for a workflow.""" + try: + if not _read_workflow_from_fs(workflow_id): + raise HTTPException(status_code=404, detail=f"Workflow not found: {workflow_id}") + + from flocks.workflow.poller_manager import default_manager as _poller_default_manager + + poller_status = await _poller_default_manager.run_once(workflow_id) + return {"ok": True, "status": poller_status} + except HTTPException: + raise + except Exception as e: + log.error("workflow.poller_run_once.error", {"id": workflow_id, "error": str(e)}) + raise HTTPException(status_code=500, detail=f"Failed to run workflow poller once: {str(e)}") + + @router.post("/workflow/{workflow_id}/syslog-config") async def save_syslog_config(workflow_id: str, req: SyslogConfigRequest): """ diff --git a/flocks/session/interaction_queue.py b/flocks/session/interaction_queue.py new file mode 100644 index 000000000..92ab2242c --- /dev/null +++ b/flocks/session/interaction_queue.py @@ -0,0 +1,164 @@ +"""In-memory prompt queue for non-blocking session interaction.""" + +from __future__ import annotations + +import asyncio +import time +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + +from flocks.utils.id import Identifier + + +MAX_QUEUE_SIZE = 50 + + +class QueueFullError(Exception): + """Raised when a session prompt queue reaches its configured limit.""" + + +class QueueItemNotFoundError(Exception): + """Raised when a queued prompt cannot be found.""" + + +class QueuedPrompt(BaseModel): + id: str = Field(default_factory=lambda: Identifier.create("part")) + sessionID: str + parts: List[Dict[str, Any]] = Field(default_factory=list) + agent: Optional[str] = None + model: Optional[Dict[str, Any]] = None + variant: Optional[str] = None + messageID: Optional[str] = None + noReply: Optional[bool] = None + mockReply: Optional[str] = None + tools: Optional[Dict[str, bool]] = None + system: Optional[str] = None + status: str = "pending" + createdAt: int = Field(default_factory=lambda: int(time.time() * 1000)) + updatedAt: int = Field(default_factory=lambda: int(time.time() * 1000)) + + +class InteractionQueue: + """Process-local per-session FIFO prompt queues.""" + + _queues: Dict[str, List[QueuedPrompt]] = {} + _locks: Dict[str, asyncio.Lock] = {} + + @classmethod + def _lock_for(cls, session_id: str) -> asyncio.Lock: + lock = cls._locks.get(session_id) + if lock is None: + lock = asyncio.Lock() + cls._locks[session_id] = lock + return lock + + @classmethod + async def enqueue( + cls, + session_id: str, + *, + parts: List[Dict[str, Any]], + agent: Optional[str] = None, + model: Optional[Dict[str, Any]] = None, + variant: Optional[str] = None, + message_id: Optional[str] = None, + no_reply: Optional[bool] = None, + mock_reply: Optional[str] = None, + tools: Optional[Dict[str, bool]] = None, + system: Optional[str] = None, + ) -> QueuedPrompt: + async with cls._lock_for(session_id): + queue = cls._queues.setdefault(session_id, []) + if len(queue) >= MAX_QUEUE_SIZE: + raise QueueFullError(f"Session {session_id} prompt queue is full") + + item = QueuedPrompt( + sessionID=session_id, + parts=[dict(part) for part in parts], + agent=agent, + model=dict(model) if isinstance(model, dict) else model, + variant=variant, + messageID=message_id, + noReply=no_reply, + mockReply=mock_reply, + tools=dict(tools) if tools else None, + system=system, + ) + queue.append(item) + return item + + @classmethod + async def list(cls, session_id: str) -> List[QueuedPrompt]: + async with cls._lock_for(session_id): + return [item.model_copy(deep=True) for item in cls._queues.get(session_id, [])] + + @classmethod + async def update_text(cls, session_id: str, item_id: str, text: str) -> QueuedPrompt: + async with cls._lock_for(session_id): + item = cls._find_locked(session_id, item_id) + if item.status == "executing": + raise QueueItemNotFoundError(f"Queued prompt {item_id} is already executing") + + parts: List[Dict[str, Any]] = [] + replaced = False + for part in item.parts: + if part.get("type") == "text" and not replaced: + next_part = dict(part) + next_part["text"] = text + parts.append(next_part) + replaced = True + elif part.get("type") != "text": + parts.append(dict(part)) + if not replaced: + parts.insert(0, {"type": "text", "text": text}) + + item.parts = parts + item.updatedAt = int(time.time() * 1000) + return item.model_copy(deep=True) + + @classmethod + async def remove(cls, session_id: str, item_id: str) -> QueuedPrompt: + async with cls._lock_for(session_id): + queue = cls._queues.get(session_id, []) + for idx, item in enumerate(queue): + if item.id == item_id: + return queue.pop(idx) + raise QueueItemNotFoundError(f"Queued prompt {item_id} not found") + + @classmethod + async def pop_next(cls, session_id: str) -> Optional[QueuedPrompt]: + async with cls._lock_for(session_id): + queue = cls._queues.get(session_id, []) + if not queue: + return None + item = queue.pop(0) + item.status = "executing" + item.updatedAt = int(time.time() * 1000) + return item + + @classmethod + async def promote(cls, session_id: str, item_id: str) -> QueuedPrompt: + async with cls._lock_for(session_id): + queue = cls._queues.get(session_id, []) + for idx, item in enumerate(queue): + if item.id == item_id: + if item.status == "executing": + raise QueueItemNotFoundError(f"Queued prompt {item_id} is already executing") + promoted = queue.pop(idx) + promoted.updatedAt = int(time.time() * 1000) + queue.insert(0, promoted) + return promoted.model_copy(deep=True) + raise QueueItemNotFoundError(f"Queued prompt {item_id} not found") + + @classmethod + async def clear(cls, session_id: str) -> None: + async with cls._lock_for(session_id): + cls._queues.pop(session_id, None) + + @classmethod + def _find_locked(cls, session_id: str, item_id: str) -> QueuedPrompt: + for item in cls._queues.get(session_id, []): + if item.id == item_id: + return item + raise QueueItemNotFoundError(f"Queued prompt {item_id} not found") diff --git a/flocks/session/lifecycle/title.py b/flocks/session/lifecycle/title.py index 60ac540db..6378dd04f 100644 --- a/flocks/session/lifecycle/title.py +++ b/flocks/session/lifecycle/title.py @@ -7,6 +7,8 @@ from typing import Any, Awaitable, Callable, Dict, List, Optional import asyncio +import json +import re from flocks.utils.log import Log from flocks.provider.provider import ChatMessage @@ -22,6 +24,14 @@ class SessionTitle: """Session title generation""" + + _TOOL_CALL_TITLE_PATTERNS = ( + re.compile(r"^\s*\[TOOL_CALL\]", re.IGNORECASE), + re.compile(r"^\s*", re.IGNORECASE), + re.compile(r"\bargs\s*=>", re.IGNORECASE), + re.compile(r"^\s*\{?\s*(tool|name)\s*=>", re.IGNORECASE), + ) @classmethod async def generate_title_after_first_message( @@ -138,13 +148,8 @@ async def generate_title_after_first_message( "error": str(llm_err), }) - # Clean up the generated title - title = title.strip().strip('"').strip("'").strip() - - # Validate title length - if len(title) > 50: - title = title[:47] + "..." - + title = cls._sanitize_generated_title(title) + if not title: title = cls._generate_simple_title(question) @@ -233,3 +238,70 @@ def _generate_simple_title(text: str, max_length: int = 50) -> str: title = "New Chat" return title + + @classmethod + def _sanitize_generated_title(cls, title: str, max_length: int = 50) -> str: + """Clean and validate a model-generated title candidate.""" + title = title.strip().strip('"').strip("'").strip() + if not title: + return "" + + if cls._looks_like_tool_call_title(title): + log.warn("title.rejected_tool_call_candidate", { + "candidate": title[:120], + }) + return "" + + if len(title) > max_length: + title = title[:max_length - 3] + "..." + return title + + @classmethod + def _looks_like_tool_call_title(cls, title: str) -> bool: + """Return True when a title candidate is actually a tool-call payload.""" + candidate = title.strip() + if not candidate: + return False + + for pattern in cls._TOOL_CALL_TITLE_PATTERNS: + if pattern.search(candidate): + return True + + json_candidate = cls._strip_code_fence(candidate) + if not json_candidate.startswith(("{", "[")): + return False + + try: + parsed = json.loads(json_candidate) + except (TypeError, ValueError): + return False + + return cls._json_has_tool_call_shape(parsed) + + @staticmethod + def _strip_code_fence(text: str) -> str: + stripped = text.strip() + if not stripped.startswith("```"): + return stripped + + lines = stripped.splitlines() + if len(lines) >= 2 and lines[-1].strip() == "```": + return "\n".join(lines[1:-1]).strip() + return stripped + + @classmethod + def _json_has_tool_call_shape(cls, value: Any) -> bool: + if isinstance(value, dict): + keys = {str(key).lower() for key in value.keys()} + if {"tool", "args"} <= keys: + return True + if {"name", "arguments"} <= keys: + return True + if "function" in keys and ("arguments" in keys or "name" in keys): + return True + return any(cls._json_has_tool_call_shape(item) for item in value.values()) + + if isinstance(value, list): + return any(cls._json_has_tool_call_shape(item) for item in value) + + return False diff --git a/flocks/session/message.py b/flocks/session/message.py index 7eadd39be..9bf05da41 100644 --- a/flocks/session/message.py +++ b/flocks/session/message.py @@ -596,14 +596,35 @@ async def _ensure_cache(cls, session_id: str) -> None: storage_key = f"{cls._MESSAGE_PREFIX}:{session_id}" stored_data = await Storage.get(storage_key) - if stored_data: + message_times: Dict[str, Dict[str, int]] = {} + if isinstance(stored_data, list): messages = [] for msg_data in stored_data: - role = msg_data.get('role', 'assistant') - if role == 'user': - messages.append(UserMessageInfo.model_validate(msg_data)) - else: - messages.append(AssistantMessageInfo.model_validate(msg_data)) + if not isinstance(msg_data, dict): + log.warn("message.cache.skipped_non_dict", { + "session_id": session_id, + "raw_type": type(msg_data).__name__, + }) + continue + + normalized = cls._normalize_stored_message(msg_data, session_id) + role = normalized.get("role", "assistant") + try: + if role == "user": + message = UserMessageInfo.model_validate(normalized) + else: + message = AssistantMessageInfo.model_validate(normalized) + except Exception as exc: + log.warn("message.cache.skipped_invalid", { + "session_id": session_id, + "message_id": normalized.get("id"), + "role": role, + "error": str(exc), + }) + continue + + messages.append(message) + message_times[message.id] = message.time cls._messages_cache[session_id] = messages else: cls._messages_cache[session_id] = [] @@ -638,9 +659,12 @@ async def _ensure_cache(cls, session_id: str) -> None: for msg_id, parts_data in stored_parts.items(): if not isinstance(parts_data, list): continue - cls._parts_cache[session_id][msg_id] = [ - cls.deserialize_part(p) for p in parts_data - ] + cls._parts_cache[session_id][msg_id] = cls._deserialize_parts_list( + session_id, + msg_id, + parts_data, + message_time=message_times.get(msg_id), + ) cls._parts_revision_cache[session_id][msg_id] = 0 cls._parts_serialized_cache[session_id] = { msg_id: list(parts_data) @@ -663,8 +687,347 @@ def _rebuild_id_index(cls, session_id: str) -> None: cls._msg_id_index[session_id] = {m.id: i for i, m in enumerate(messages)} @classmethod - def deserialize_part(cls, part_data: Dict[str, Any]) -> PartType: + def _first_non_none(cls, *values: Any) -> Any: + """Return the first value that is not None, preserving 0-like sentinels.""" + for value in values: + if value is not None: + return value + return None + + @classmethod + def _default_message_time(cls, raw_time: Any) -> Dict[str, int]: + """Normalize stored message timestamps to the current schema.""" + now_ms = int(datetime.now().timestamp() * 1000) + if isinstance(raw_time, dict): + normalized = dict(raw_time) + created = cls._first_non_none( + normalized.get("created"), + normalized.get("start"), + ) + normalized["created"] = int(created) if created is not None else now_ms + if "completed" not in normalized and normalized.get("end") is not None: + normalized["completed"] = int(normalized["end"]) + return normalized + return {"created": now_ms} + + @classmethod + def _default_message_path(cls, raw_path: Any) -> Dict[str, str]: + """Return a safe message path payload for assistant messages.""" + if isinstance(raw_path, MessagePath): + return raw_path.model_dump() + if isinstance(raw_path, dict): + return { + "cwd": str(raw_path.get("cwd") or "./"), + "root": str(raw_path.get("root") or ""), + } + return MessagePath(cwd="./").model_dump() + + @classmethod + def _default_token_usage(cls, raw_tokens: Any = None) -> Dict[str, Any]: + """Return a safe token payload for assistant messages.""" + defaults = TokenUsage().model_dump() + if isinstance(raw_tokens, TokenUsage): + return raw_tokens.model_dump() + if isinstance(raw_tokens, dict): + normalized = dict(defaults) + for key in ("input", "output", "reasoning"): + value = raw_tokens.get(key) + if value is not None: + normalized[key] = value + cache_raw = raw_tokens.get("cache") + if isinstance(cache_raw, dict): + normalized["cache"] = { + "read": cache_raw.get("read", 0), + "write": cache_raw.get("write", 0), + } + return normalized + return defaults + + @classmethod + def _normalize_stored_message( + cls, + msg_data: Dict[str, Any], + session_id: str, + ) -> Dict[str, Any]: + """Backfill missing fields for legacy or partially-written messages.""" + normalized = dict(msg_data) + role = normalized.get("role", "assistant") + normalized["sessionID"] = normalized.get("sessionID") or session_id + normalized["time"] = cls._default_message_time(normalized.get("time")) + + if role == "user": + model_raw = normalized.get("model") + if not isinstance(model_raw, dict): + model_raw = {} + normalized["agent"] = normalized.get("agent") or "rex" + normalized["model"] = { + "providerID": model_raw.get("providerID") + or normalized.get("providerID") + or "", + "modelID": model_raw.get("modelID") + or normalized.get("modelID") + or "", + } + return normalized + + model_raw = normalized.get("model") + if not isinstance(model_raw, dict): + model_raw = {} + normalized["parentID"] = ( + normalized.get("parentID") + or normalized.get("parent_id") + or "" + ) + normalized["modelID"] = ( + normalized.get("modelID") + or normalized.get("model_id") + or model_raw.get("modelID") + or "" + ) + normalized["providerID"] = ( + normalized.get("providerID") + or normalized.get("provider_id") + or model_raw.get("providerID") + or "" + ) + normalized["agent"] = normalized.get("agent") or "rex" + normalized["mode"] = normalized.get("mode") or normalized["agent"] or "standard" + normalized["path"] = cls._default_message_path(normalized.get("path")) + normalized["tokens"] = cls._default_token_usage(normalized.get("tokens")) + return normalized + + @classmethod + def _default_part_time( + cls, + raw_time: Any = None, + *, + message_time: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Optional[int]]: + """Provide best-effort timestamps for legacy stored parts.""" + time_info = raw_time if isinstance(raw_time, dict) else {} + fallback_time = message_time if isinstance(message_time, dict) else {} + start = cls._first_non_none( + time_info.get("start"), + time_info.get("created"), + fallback_time.get("start"), + fallback_time.get("created"), + 0, + ) + end = cls._first_non_none( + time_info.get("end"), + time_info.get("updated"), + time_info.get("completed"), + fallback_time.get("end"), + fallback_time.get("updated"), + fallback_time.get("completed"), + ) + if end is None and start is not None: + end = start + return {"start": int(start), "end": int(end) if end is not None else None} + + @classmethod + def _normalize_tool_state( + cls, + raw_state: Any, + *, + message_time: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Convert legacy tool state payloads to the current shape.""" + fallback_time = cls._default_part_time(message_time=message_time) + if isinstance(raw_state, dict): + normalized = dict(raw_state) + else: + normalized = {} + if isinstance(raw_state, str): + normalized["status"] = raw_state + + status = normalized.get("status") + if status not in {"pending", "running", "completed", "error"}: + if "output" in normalized: + status = "completed" + elif "error" in normalized: + status = "error" + elif "time" in normalized: + status = "running" + else: + status = "pending" + normalized["status"] = status + + if status == "pending": + normalized.setdefault("input", {}) + normalized.setdefault("raw", "") + elif status == "running": + normalized.setdefault("input", {}) + raw_time = normalized.get("time") + normalized["time"] = ( + cls._default_part_time(raw_time, message_time=message_time) + if isinstance(raw_time, dict) + else fallback_time + ) + elif status == "completed": + normalized.setdefault("input", {}) + normalized.setdefault("output", "") + normalized.setdefault("title", "") + normalized.setdefault("metadata", {}) + raw_time = normalized.get("time") + normalized["time"] = ( + cls._default_part_time(raw_time, message_time=message_time) + if isinstance(raw_time, dict) + else fallback_time + ) + elif status == "error": + normalized.setdefault("input", {}) + normalized.setdefault("error", "") + raw_time = normalized.get("time") + normalized["time"] = ( + cls._default_part_time(raw_time, message_time=message_time) + if isinstance(raw_time, dict) + else fallback_time + ) + + return normalized + + @classmethod + def _normalize_part_data( + cls, + part_data: Dict[str, Any], + *, + session_id: Optional[str] = None, + message_id: Optional[str] = None, + message_time: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Normalize legacy/exported part payloads before validation.""" + normalized = dict(part_data) + if session_id and not normalized.get("sessionID"): + normalized["sessionID"] = session_id + if message_id and not normalized.get("messageID"): + normalized["messageID"] = message_id + + part_type = normalized.get("type", "text") + metadata = normalized.get("metadata") + metadata_dict = metadata if isinstance(metadata, dict) else {} + + if "content" in normalized and "text" not in normalized: + normalized["text"] = normalized.get("content", "") + + raw_time = normalized.get("time") + if part_type == "text": + normalized.setdefault("text", "") + if isinstance(raw_time, dict): + if "start" not in raw_time and any( + key in raw_time for key in ("created", "updated", "completed", "end") + ): + normalized["time"] = cls._default_part_time( + raw_time, + message_time=message_time, + ) + elif raw_time is not None: + normalized.pop("time", None) + if metadata is not None and not isinstance(metadata, dict): + normalized["metadata"] = None + elif part_type == "reasoning": + normalized.setdefault("text", normalized.get("content", "")) + normalized.setdefault("metadata", metadata_dict or None) + normalized["time"] = cls._default_part_time( + raw_time, + message_time=message_time, + ) + elif part_type == "tool": + normalized.setdefault("callID", metadata_dict.get("callID") or normalized.get("id", "")) + normalized.setdefault("tool", metadata_dict.get("tool") or normalized.get("tool", "unknown")) + raw_state = normalized.get("state") + if raw_state is None and metadata_dict: + raw_state = metadata_dict.get("state") + normalized["state"] = cls._normalize_tool_state( + raw_state, + message_time=message_time, + ) + normalized.setdefault("metadata", metadata_dict or None) + elif part_type == "file": + normalized.setdefault("mime", metadata_dict.get("mime") or "application/octet-stream") + normalized.setdefault("filename", metadata_dict.get("filename")) + normalized.setdefault("url", metadata_dict.get("url") or normalized.get("content", "")) + elif part_type == "snapshot": + normalized.setdefault("snapshot", metadata_dict.get("snapshot") or normalized.get("content", "")) + elif part_type == "patch": + normalized.setdefault("hash", metadata_dict.get("hash") or "") + normalized.setdefault("files", metadata_dict.get("files") or []) + elif part_type == "step-finish": + normalized.setdefault("reason", metadata_dict.get("reason") or "completed") + normalized.setdefault("snapshot", metadata_dict.get("snapshot")) + normalized.setdefault("cost", metadata_dict.get("cost") or 0.0) + normalized.setdefault("tokens", metadata_dict.get("tokens") or cls._default_token_usage()) + elif part_type == "agent": + normalized.setdefault("name", metadata_dict.get("name") or normalized.get("content") or "agent") + elif part_type == "subtask": + normalized.setdefault("prompt", metadata_dict.get("prompt") or normalized.get("content", "")) + normalized.setdefault("description", metadata_dict.get("description") or "") + normalized.setdefault("agent", metadata_dict.get("agent") or "agent") + elif part_type == "retry": + normalized.setdefault("attempt", metadata_dict.get("attempt") or 1) + normalized.setdefault("error", metadata_dict.get("error") or {}) + normalized["time"] = cls._default_part_time( + metadata_dict.get("time") or raw_time, + message_time=message_time, + ) + elif part_type == "compaction": + normalized.setdefault("auto", bool(metadata_dict.get("auto", False))) + + return normalized + + @classmethod + def _deserialize_parts_list( + cls, + session_id: str, + message_id: str, + parts_data: List[Dict[str, Any]], + *, + message_time: Optional[Dict[str, Any]] = None, + ) -> List[PartType]: + """Best-effort deserialize a stored parts list without dropping the session.""" + parts: List[PartType] = [] + for raw_part in parts_data: + if not isinstance(raw_part, dict): + log.warn("message.part.skipped_non_dict", { + "session_id": session_id, + "message_id": message_id, + "raw_type": type(raw_part).__name__, + }) + continue + try: + parts.append( + cls.deserialize_part( + raw_part, + session_id=session_id, + message_id=message_id, + message_time=message_time, + ) + ) + except Exception as exc: + log.warn("message.part.skipped_invalid", { + "session_id": session_id, + "message_id": message_id, + "part_id": raw_part.get("id"), + "error": str(exc), + }) + return parts + + @classmethod + def deserialize_part( + cls, + part_data: Dict[str, Any], + *, + session_id: Optional[str] = None, + message_id: Optional[str] = None, + message_time: Optional[Dict[str, Any]] = None, + ) -> PartType: """Deserialize a part from storage format""" + part_data = cls._normalize_part_data( + part_data, + session_id=session_id, + message_id=message_id, + message_time=message_time, + ) part_type = part_data.get('type', 'text') type_map = { @@ -692,10 +1055,14 @@ def _normalize_assistant_message(message: MessageInfo) -> MessageInfo: return message updates: Dict[str, Any] = {} - if isinstance(message.tokens, dict): - updates["tokens"] = TokenUsage.model_validate(message.tokens) - if isinstance(message.path, dict): - updates["path"] = MessagePath.model_validate(message.path) + if not isinstance(message.tokens, TokenUsage): + updates["tokens"] = TokenUsage.model_validate( + Message._default_token_usage(message.tokens) + ) + if not isinstance(message.path, MessagePath): + updates["path"] = MessagePath.model_validate( + Message._default_message_path(message.path) + ) if not updates: return message diff --git a/flocks/session/prompt.py b/flocks/session/prompt.py index 160077d63..3896400e6 100644 --- a/flocks/session/prompt.py +++ b/flocks/session/prompt.py @@ -271,6 +271,8 @@ def environment_stable( f" Workspace outputs directory: {outputs_dir}", f" Is directory a git repo: {'yes' if is_git else 'no'}", f" Platform: {platform.system().lower()}", + " Python executor: uv python", + " Python package manager: uv pip", "", ] return ["\n".join(env_info)] diff --git a/flocks/storage/storage.py b/flocks/storage/storage.py index 5e46de439..a7f042fdc 100644 --- a/flocks/storage/storage.py +++ b/flocks/storage/storage.py @@ -977,7 +977,35 @@ async def list_entries( value = json.loads(value_str) entries.append((key, value)) return entries - + + @classmethod + async def list_raw( + cls, + prefix: Optional[str] = None, + ) -> List[Tuple[str, str]]: + """List ``(key, raw_value_str)`` pairs without Python-side JSON parsing. + + Unlike :meth:`list_entries`, the value is returned as a plain string + so callers can apply lightweight extraction (e.g. regex) instead of + full ``json.loads``. This is critical for hot trim paths that only + need a couple of scalar fields from otherwise large JSON blobs. + Compatible with all SQLite versions (no ``json_extract`` required). + """ + await cls._ensure_init() + + if prefix: + query = "SELECT key, value FROM storage WHERE key LIKE ?" + params: tuple = (f"{prefix}%",) + else: + query = "SELECT key, value FROM storage" + params = () + + async with cls.connect(cls._db_path) as db: + async with db.execute(query, params) as cursor: + rows = await cursor.fetchall() + + return [(row[0], row[1]) for row in rows] + @classmethod async def exists(cls, key: str) -> bool: """ diff --git a/flocks/tool/agent/delegate_task.py b/flocks/tool/agent/delegate_task.py index a77df6c2d..27ccf30a4 100644 --- a/flocks/tool/agent/delegate_task.py +++ b/flocks/tool/agent/delegate_task.py @@ -237,14 +237,13 @@ def _derive_task_description( Usage notes: - Provide a clear description (3-5 words) - Provide detailed prompt with context -- run_in_background=true: returns task_id immediately, collect results later with background_output -- run_in_background=false: waits for completion and returns results inline - Pass session_id to continue a previous agent with full context +- run_in_background=false: (default) waits for completion and returns results inline REQUIRED: prompt. LOAD_SKILLS is optional and defaults to []. DESCRIPTION is optional and will be auto-derived when omitted. -RUN_IN_BACKGROUND defaults to false (sync). +RUN_IN_BACKGROUND defaults to false (sync). if true, need: returns task_id immediately, collect results later with background_output USE EITHER subagent_type OR category — NEVER both simultaneously. """ diff --git a/flocks/tool/device/models.py b/flocks/tool/device/models.py index 71def077b..5b945673c 100644 --- a/flocks/tool/device/models.py +++ b/flocks/tool/device/models.py @@ -64,6 +64,26 @@ "ALTER TABLE device_integrations ADD COLUMN group_id TEXT NOT NULL DEFAULT '';" ) +# Per-device tool enabled/disabled overrides. +# +# Each row disables (enabled=0) or re-enables (enabled=1) a specific tool +# for a specific device instance, independent of the shared global +# tool_settings overlay and other device instances that share the same +# storage_key (same plugin version, different names). +# +# ON DELETE CASCADE removes all per-device settings automatically when the +# parent device row is deleted, so no manual cleanup is needed. +Storage.register_ddl(""" +CREATE TABLE IF NOT EXISTS device_tool_settings ( + device_id TEXT NOT NULL REFERENCES device_integrations(id) ON DELETE CASCADE, + tool_name TEXT NOT NULL, + enabled INTEGER NOT NULL DEFAULT 1, + updated_at INTEGER NOT NULL, + PRIMARY KEY (device_id, tool_name) +); +CREATE INDEX IF NOT EXISTS idx_dts_device ON device_tool_settings(device_id); +""") + # --------------------------------------------------------------------------- # Pydantic models diff --git a/flocks/tool/device/prompt.py b/flocks/tool/device/prompt.py index 8f5176994..4f652e8a8 100644 --- a/flocks/tool/device/prompt.py +++ b/flocks/tool/device/prompt.py @@ -13,7 +13,7 @@ from flocks.utils.log import Log from .models import DeviceGroup, DeviceIntegration -from .store import list_devices, list_groups +from .store import list_all_device_tool_settings, list_devices, list_groups log = Log.create(service="tool.device.prompt") @@ -27,6 +27,10 @@ async def build_device_context_section() -> Optional[str]: same type are connected, the per-tool description and action list appear only once in a shared "工具说明" section, keeping the prompt size O(tools) rather than O(tools × devices). + + Per-device tool overrides (stored in ``device_tool_settings`` DB table) are + loaded per device so the Agent knows which tools are individually disabled + on a given device and will not waste a round-trip trying to call them. """ try: groups = await list_groups() @@ -38,6 +42,14 @@ async def build_device_context_section() -> Optional[str]: if not devices: return None + # Load per-device tool overrides for all devices upfront in ONE query + # (avoids N+1 connections when many devices are registered). + try: + per_device_overrides: Dict[str, Dict[str, bool]] = await list_all_device_tool_settings() + except Exception as exc: + log.warn("tool.device.prompt.per_device_load_failed", {"error": str(exc)}) + per_device_overrides = {} + tool_map = _build_tool_map() group_map: Dict[str, DeviceGroup] = {g.id: g for g in groups} @@ -80,6 +92,17 @@ async def build_device_context_section() -> Optional[str]: if d.enabled and tools: lines.append(f" tool_set_id: `{d.storage_key}`") lines.append(f" 调用方式: 附带 `device_id=\"{d.id}\"` 参数") + + # Show per-device disabled tools so the Agent knows not to call them. + overrides = per_device_overrides.get(d.id, {}) + disabled_tools = sorted( + name for name, enabled in overrides.items() if not enabled + ) + if disabled_tools: + lines.append( + f" 以下工具在设备「{d.name}」(device_id=`{d.id}`) 上已单独禁用,禁止调用: " + + ", ".join(f"`{t}`" for t in disabled_tools) + ) elif not d.enabled: lines.append(f" tool_set_id: `{d.storage_key}`") lines.append(" 可用工具: (已禁用,不可调用)") diff --git a/flocks/tool/device/store.py b/flocks/tool/device/store.py index 98d8fe41d..6b8400be5 100644 --- a/flocks/tool/device/store.py +++ b/flocks/tool/device/store.py @@ -390,6 +390,109 @@ async def ensure_default_group() -> None: # Public helper for downstream callers (Agent tools, etc.) # --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Per-device tool settings (device_tool_settings table) +# --------------------------------------------------------------------------- + +async def get_device_tool_enabled(device_id: str, tool_name: str) -> Optional[bool]: + """Return the per-device enabled override for *tool_name*, or None if not set. + + None → no per-device override; fall back to global tool_settings / factory default. + True → explicitly enabled for this device. + False → explicitly disabled for this device. + """ + async with Storage.connect(Storage.get_db_path()) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT enabled FROM device_tool_settings WHERE device_id = ? AND tool_name = ?", + (device_id, tool_name), + ) as cur: + row = await cur.fetchone() + if row is None: + return None + return bool(row["enabled"]) + + +async def set_device_tool_enabled( + device_id: str, tool_name: str, enabled: bool +) -> None: + """Upsert the per-device tool override for *tool_name*. + + Bumps the device revision so the session runner's system-prompt cache + invalidates and rebuilds the DeviceAssetContext section — otherwise the + Agent would keep seeing the pre-toggle tool list until the cache TTL. + """ + now = _now_ms() + async with Storage.connect(Storage.get_db_path()) as db: + await db.execute( + """ + INSERT INTO device_tool_settings (device_id, tool_name, enabled, updated_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(device_id, tool_name) DO UPDATE SET + enabled = excluded.enabled, + updated_at = excluded.updated_at + """, + (device_id, tool_name, int(enabled), now), + ) + await db.commit() + _bump_revision() + log.info("tool.device.tool_setting.set", { + "device_id": device_id, "tool": tool_name, "enabled": enabled, + }) + + +async def delete_device_tool_setting(device_id: str, tool_name: str) -> bool: + """Remove the per-device override for *tool_name*. + + Returns True if a row existed and was deleted. Bumps the device + revision on actual deletion so cached prompts get rebuilt. + """ + async with Storage.connect(Storage.get_db_path()) as db: + cur = await db.execute( + "DELETE FROM device_tool_settings WHERE device_id = ? AND tool_name = ?", + (device_id, tool_name), + ) + await db.commit() + removed = cur.rowcount > 0 + if removed: + _bump_revision() + log.info("tool.device.tool_setting.removed", { + "device_id": device_id, "tool": tool_name, + }) + return removed + + +async def list_device_tool_settings( + device_id: str, +) -> Dict[str, bool]: + """Return {tool_name: enabled} for all per-device overrides of *device_id*.""" + async with Storage.connect(Storage.get_db_path()) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT tool_name, enabled FROM device_tool_settings WHERE device_id = ?", + (device_id,), + ) as cur: + rows = await cur.fetchall() + return {row["tool_name"]: bool(row["enabled"]) for row in rows} + + +async def list_all_device_tool_settings() -> Dict[str, Dict[str, bool]]: + """Return {device_id: {tool_name: enabled}} for ALL devices in one query. + + Avoids the N+1 pattern when building artefacts (e.g. the DeviceAssetContext + prompt section) that need per-device overrides for many devices at once. + """ + result: Dict[str, Dict[str, bool]] = {} + async with Storage.connect(Storage.get_db_path()) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT device_id, tool_name, enabled FROM device_tool_settings" + ) as cur: + async for row in cur: + result.setdefault(row["device_id"], {})[row["tool_name"]] = bool(row["enabled"]) + return result + + async def get_device_credentials(device_id: str) -> Optional[Dict[str, Any]]: """Return plaintext credentials for *device_id*, or None if not found / disabled. diff --git a/flocks/tool/registry.py b/flocks/tool/registry.py index 1bb7fb46a..47ce998f1 100644 --- a/flocks/tool/registry.py +++ b/flocks/tool/registry.py @@ -816,6 +816,27 @@ async def execute( device_id = resolved_device_id if device_id: + # Per-device tool enable gate: an individual device instance may + # have its own enabled=False override independent of the shared + # global tool_settings. This prevents toggling a tool "for + # Device A" from affecting Device B when both share the same + # storage_key (same plugin version, different names). + try: + from flocks.tool.device.store import get_device_tool_enabled + per_device_enabled = await get_device_tool_enabled(device_id, tool_name) + if per_device_enabled is False: + return ToolResult( + success=False, + error=( + f"工具 {tool_name!r} 在设备 {device_id!r} 上已禁用。" + "如需启用,请在设备管理页面打开对应工具开关。" + ), + ) + except Exception as _gate_err: + log.debug("tool.device.per_device_gate_error", { + "tool": tool_name, "device_id": device_id, "error": str(_gate_err), + }) + from flocks.tool.credential_context import activate_device_credentials async with activate_device_credentials(device_id) as activated: if not activated: diff --git a/flocks/tool/system/slash_command.py b/flocks/tool/system/slash_command.py index 2f417fd5b..55afe4467 100644 --- a/flocks/tool/system/slash_command.py +++ b/flocks/tool/system/slash_command.py @@ -153,7 +153,7 @@ async def run_slash_command_tool( result = await run_direct_command(command, args=normalized_args) if not result.handled: return ToolResult(success=False, error=f"Unhandled slash command: {command}") - if result.prompt is not None or result.clear_screen: + if result.prompt is not None or result.clear_screen or result.clear_history: return ToolResult( success=False, error=f"Slash command /{command} is not agent-safe in this context.", diff --git a/flocks/tool/tool_loader.py b/flocks/tool/tool_loader.py index 64e4071fc..b6c6348a8 100644 --- a/flocks/tool/tool_loader.py +++ b/flocks/tool/tool_loader.py @@ -17,6 +17,7 @@ import importlib.util import inspect import re +import sys import urllib.parse from pathlib import Path from typing import Any, Callable, Dict, List, Optional @@ -308,6 +309,33 @@ def _build_handler(raw_handler: dict, yaml_path: Path) -> ToolHandler: raise ValueError(f"Unknown handler type: {handler_type}") +def _build_tcp_connector(verify_ssl: bool) -> "Any": + """Create a per-request aiohttp TCPConnector that tears down sockets promptly. + + Declarative HTTP tools (skyeye / tdp / onesec / qingteng / threatbook 等) build a + fresh ``ClientSession`` + connector for every call, so connection pooling brings no + benefit. Leaving keep-alive enabled lets the remote (often a self-signed HTTPS + appliance) send FIN while our socket lingers in ``CLOSE_WAIT``; under rapid back-to-back + tool calls — e.g. ``run_workflow`` / ``run_workflow_node`` driving dozens of nodes — + these half-closed sockets accumulate (60+ observed) and eventually starve the event loop. + + Mitigations: + - ``force_close=True``: close the underlying connection right after each request + instead of returning it to the pool, so no idle keep-alive socket can rot into + ``CLOSE_WAIT``. + - ``enable_cleanup_closed=True``: on CPython < 3.12.7 the asyncio SSL transport leak + means TLS sockets are not fully torn down; this aborts them after a short grace + period. The flag was made a no-op (with a DeprecationWarning) once the CPython bug + was fixed, so we only pass it on the affected interpreters. + """ + import aiohttp + + kwargs: Dict[str, Any] = {"ssl": verify_ssl, "force_close": True} + if sys.version_info < (3, 12, 7): + kwargs["enable_cleanup_closed"] = True + return aiohttp.TCPConnector(**kwargs) + + def _build_http_handler(cfg: dict) -> ToolHandler: """Build an async HTTP request handler from declarative config.""" method = cfg.get("method", "GET").upper() @@ -363,7 +391,7 @@ async def handler(ctx: ToolContext, **kwargs: Any) -> ToolResult: try: client_timeout = aiohttp.ClientTimeout(total=timeout) - connector = aiohttp.TCPConnector(ssl=verify_ssl) + connector = _build_tcp_connector(verify_ssl) async with aiohttp.ClientSession(timeout=client_timeout, connector=connector) as session: req_kwargs: Dict[str, Any] = {"headers": headers} if query_params: diff --git a/flocks/utils/log.py b/flocks/utils/log.py index 7d3c46ec0..f22d7a4fb 100644 --- a/flocks/utils/log.py +++ b/flocks/utils/log.py @@ -10,6 +10,7 @@ import os import sys import time +import threading from pathlib import Path from typing import Any, Dict, Optional, TextIO from datetime import datetime @@ -17,6 +18,13 @@ import glob as file_glob +_DEFAULT_LOG_MAX_BYTES = 5 * 1024 * 1024 +_DEFAULT_LOG_BACKUP_COUNT = 3 +_DEFAULT_LOG_VALUE_MAX_CHARS = 8 * 1024 +_MAX_STRUCTURED_ITEMS = 50 +_MAX_STRUCTURED_DEPTH = 4 + + def _log_dir() -> Path: """Log directory: FLOCKS_LOG_DIR, or FLOCKS_ROOT/logs, or ~/.flocks/logs. Matches config.""" raw = os.getenv("FLOCKS_LOG_DIR") @@ -33,6 +41,177 @@ def get_log_dir() -> Path: return _log_dir() +def _env_int(name: str, default: int) -> int: + raw = os.getenv(name) + if raw is None: + return default + try: + return int(raw) + except ValueError: + return default + + +def get_log_max_bytes(default: int = _DEFAULT_LOG_MAX_BYTES) -> int: + """Return the per-file log size limit in bytes. + + ``FLOCKS_LOG_MAX_BYTES`` is exact; ``FLOCKS_LOG_MAX_MB`` is a convenient + human-facing override. When both are set, bytes wins. Values <= 0 disable + rotation. + """ + if os.getenv("FLOCKS_LOG_MAX_BYTES") is not None: + return _env_int("FLOCKS_LOG_MAX_BYTES", default) + max_mb = os.getenv("FLOCKS_LOG_MAX_MB") + if max_mb is not None: + try: + return int(float(max_mb) * 1024 * 1024) + except ValueError: + return default + return default + + +def get_log_backup_count(default: int = _DEFAULT_LOG_BACKUP_COUNT) -> int: + """Return how many rotated backups to keep for long-lived log files.""" + return max(0, _env_int("FLOCKS_LOG_BACKUP_COUNT", default)) + + +def rotate_log_file( + path: Path, + *, + max_bytes: Optional[int] = None, + backup_count: Optional[int] = None, + force: bool = False, +) -> None: + """Rotate ``path`` if it is already over the configured size limit.""" + limit = get_log_max_bytes() if max_bytes is None else max_bytes + backups = get_log_backup_count() if backup_count is None else backup_count + if limit <= 0 or not path.exists(): + return + try: + if not force and path.stat().st_size < limit: + return + if backups <= 0: + path.unlink(missing_ok=True) + return + for index in range(backups - 1, 0, -1): + src = path.with_name(f"{path.name}.{index}") + dst = path.with_name(f"{path.name}.{index + 1}") + if src.exists(): + src.replace(dst) + path.replace(path.with_name(f"{path.name}.1")) + except OSError: + return + + +class _RotatingTextWriter: + """Small line-buffered writer with size-based rotation for Flocks logs.""" + + def __init__(self, path: Path, *, max_bytes: int, backup_count: int): + self.path = path + self.max_bytes = max_bytes + self.backup_count = backup_count + self._handle: Optional[TextIO] = None + self._bytes_written = 0 + self._lock = threading.RLock() + self._open() + + def _open(self) -> None: + self.path.parent.mkdir(parents=True, exist_ok=True) + self._handle = open(self.path, "a", buffering=1, encoding="utf-8") + try: + self._bytes_written = self.path.stat().st_size + except OSError: + self._bytes_written = 0 + + def _should_rotate(self, message: str) -> bool: + if self.max_bytes <= 0: + return False + return self._bytes_written + len(message.encode("utf-8")) > self.max_bytes + + def write(self, message: str) -> int: + encoded_len = len(message.encode("utf-8")) + with self._lock: + if self._should_rotate(message): + self.close() + rotate_log_file( + self.path, + max_bytes=self.max_bytes, + backup_count=self.backup_count, + force=True, + ) + self._open() + if self._handle is None: + self._open() + written = self._handle.write(message) + self._bytes_written += encoded_len + return written + + def flush(self) -> None: + with self._lock: + if self._handle is not None: + self._handle.flush() + + def close(self) -> None: + with self._lock: + if self._handle is not None: + self._handle.close() + self._handle = None + + +def _truncate_for_log(value: str, max_chars: Optional[int] = None) -> str: + limit = _env_int("FLOCKS_LOG_VALUE_MAX_CHARS", _DEFAULT_LOG_VALUE_MAX_CHARS) if max_chars is None else max_chars + if limit <= 0 or len(value) <= limit: + return value + omitted = len(value) - limit + return f"{value[:limit]}..." + + +def _prepare_json_value(value: Any, *, depth: int = 0, seen: Optional[set[int]] = None) -> Any: + if isinstance(value, str): + return _truncate_for_log(value) + if value is None or isinstance(value, (bool, int, float)): + return value + if seen is None: + seen = set() + value_id = id(value) + if value_id in seen: + return "" + if depth >= _MAX_STRUCTURED_DEPTH: + return f"<{type(value).__name__}>" + if isinstance(value, dict): + seen.add(value_id) + prepared = {} + for index, (key, item) in enumerate(value.items()): + if index >= _MAX_STRUCTURED_ITEMS: + prepared["__truncated__"] = f"{len(value) - _MAX_STRUCTURED_ITEMS} more keys" + break + prepared_key = key if key is None or isinstance(key, (str, int, float, bool)) else _truncate_for_log(str(key)) + prepared[prepared_key] = _prepare_json_value(item, depth=depth + 1, seen=seen) + seen.remove(value_id) + return prepared + if isinstance(value, list): + seen.add(value_id) + prepared = [ + _prepare_json_value(item, depth=depth + 1, seen=seen) + for item in value[:_MAX_STRUCTURED_ITEMS] + ] + if len(value) > _MAX_STRUCTURED_ITEMS: + prepared.append(f"") + seen.remove(value_id) + return prepared + return _truncate_for_log(str(value)) + + +def _format_log_value(value: Any) -> str: + if isinstance(value, Exception): + return _truncate_for_log(Log._format_error(value)) + if isinstance(value, (dict, list)): + try: + return _truncate_for_log(json.dumps(_prepare_json_value(value))) + except (TypeError, ValueError): + return _truncate_for_log(str(value)) + return _truncate_for_log(str(value)) + + def append_upgrade_text_log(message: str) -> None: """Append timestamped lines to ``update.log`` under the configured log directory. @@ -101,14 +280,7 @@ def _build_message(self, message: Any, extra: Optional[Dict[str, Any]] = None) - # Build prefix (key=value pairs) prefix_parts = [] for key, value in all_tags.items(): - if isinstance(value, Exception): - # Format error with message and cause chain - prefix_parts.append(f"{key}={Log._format_error(value)}") - elif isinstance(value, dict): - # JSON stringify objects - prefix_parts.append(f"{key}={json.dumps(value)}") - else: - prefix_parts.append(f"{key}={value}") + prefix_parts.append(f"{key}={_format_log_value(value)}") prefix = " ".join(prefix_parts) @@ -122,7 +294,7 @@ def _build_message(self, message: Any, extra: Optional[Dict[str, Any]] = None) - Log._last_time = current_time_ms # Build full message - parts = [timestamp, f"+{diff_ms}ms", prefix, str(message) if message else ""] + parts = [timestamp, f"+{diff_ms}ms", prefix, _truncate_for_log(str(message)) if message else ""] return " ".join([p for p in parts if p]) + "\n" def debug(self, message: Any = None, extra: Optional[Dict[str, Any]] = None) -> None: @@ -315,8 +487,12 @@ async def init( if cls._log_file.exists(): cls._log_file.write_text("") - # Open for writing - cls._writer = open(cls._log_file, "a", buffering=1) # Line buffered + # Open for writing with size-based rotation for long-running sessions. + cls._writer = _RotatingTextWriter( + cls._log_file, + max_bytes=get_log_max_bytes(), + backup_count=get_log_backup_count(), + ) # Create default logger cls.Default = cls.create(service="default") @@ -330,18 +506,33 @@ async def _cleanup(cls, log_dir: Path) -> None: log_dir: Directory containing log files """ try: - # Find all log files matching pattern YYYY-MM-DDTHHMMSS.log + # Find base log files matching pattern YYYY-MM-DDTHHMMSS.log. + # Rotated siblings are deleted together with their base file so + # old ``.log.1``/``.log.2`` files do not leak forever. pattern = str(log_dir / "????-??-??T??????.log") - files = sorted(file_glob.glob(pattern)) + files = [Path(path) for path in sorted(file_glob.glob(pattern))] # Keep only the 10 most recent if len(files) > 10: files_to_delete = files[:-10] - for file_path in files_to_delete: + for path in files_to_delete: try: - Path(file_path).unlink() + path.unlink(missing_ok=True) + for rotated in path.parent.glob(f"{path.name}.*"): + rotated.unlink(missing_ok=True) except Exception: pass # Silently ignore deletion errors + + kept_files = set(files[-10:]) + rotated_pattern = str(log_dir / "????-??-??T??????.log.*") + for rotated_path in (Path(path) for path in file_glob.glob(rotated_pattern)): + base_name = rotated_path.name.split(".log.", 1)[0] + ".log" + base_path = rotated_path.with_name(base_name) + if base_path not in kept_files and not base_path.exists(): + try: + rotated_path.unlink(missing_ok=True) + except Exception: + pass except Exception: pass # Silently ignore cleanup errors diff --git a/flocks/workflow/engine.py b/flocks/workflow/engine.py index 6762fd672..e58281f3d 100644 --- a/flocks/workflow/engine.py +++ b/flocks/workflow/engine.py @@ -9,7 +9,7 @@ import traceback import uuid import json -from typing import Any, Callable, Deque, Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar +from typing import Any, Callable, Deque, Dict, List, Literal, NamedTuple, Optional, Set, Tuple, TypeVar from pydantic import BaseModel, Field @@ -39,12 +39,45 @@ class _ExecOutcome(NamedTuple): is_cancelled: bool = False +def _summarize_for_observability(value: Any, *, depth: int = 0) -> Any: + """Return a bounded, non-retaining summary for logs and lightweight history.""" + if depth >= 2: + if isinstance(value, (list, tuple, set)): + return {"_type": type(value).__name__, "count": len(value)} + if isinstance(value, dict): + return {"_type": "dict", "keys": list(value.keys())[:20]} + if isinstance(value, str) and len(value) > 200: + return {"_type": "string", "chars": len(value), "preview": value[:200]} + return value + if isinstance(value, dict): + return { + key: _summarize_for_observability(item, depth=depth + 1) + for key, item in list(value.items())[:50] + } + if isinstance(value, (list, tuple, set)): + return { + "_type": type(value).__name__, + "count": len(value), + "preview": [ + _summarize_for_observability(item, depth=depth + 1) + for item in list(value)[:3] + ], + } + if isinstance(value, str) and len(value) > 200: + return {"_type": "string", "chars": len(value), "preview": value[:200]} + return value + + def _outputs_for_log(outputs: Dict[str, Any], *, max_chars: int = 4000) -> str: - """Serialize outputs for logs with bounded size.""" + """Serialize an output summary for logs with bounded size.""" try: - text = json.dumps(outputs, ensure_ascii=False, default=str) + text = json.dumps( + _summarize_for_observability(outputs), + ensure_ascii=False, + default=str, + ) except Exception: - text = repr(outputs) + text = repr(_summarize_for_observability(outputs)) if len(text) <= max_chars: return text return text[:max_chars] + f"...[truncated:{len(text) - max_chars}]" @@ -64,6 +97,7 @@ class ExecutionResult(BaseModel): steps: int history: list[StepResult] = Field(default_factory=list) last_node_id: Optional[str] = None + outputs: Dict[str, Any] = Field(default_factory=dict) run_id: str = Field(default_factory=lambda: uuid.uuid4().hex) @@ -110,6 +144,7 @@ class WorkflowEngine: mutate_workflow: bool = False workflow_path: Optional[str] = None node_timeout_s: Optional[float] = 300.0 + history_mode: Literal["full", "summary"] = "full" _depth: int = 0 max_parallel_workers: int = 4 workflow_loader: Optional[Callable[[str], "Workflow"]] = field(default=None, repr=False) @@ -119,6 +154,8 @@ def __post_init__(self) -> None: raise ValueError("max_steps must be > 0") if self.max_parallel_workers < 1: raise ValueError("max_parallel_workers must be >= 1") + if self.history_mode not in ("full", "summary"): + raise ValueError("history_mode must be 'full' or 'summary'") if self.runtime is None: self.runtime = PythonExecRuntime() if self.code_gen is None: @@ -169,6 +206,7 @@ def run( [(self.workflow.start, initial_inputs or {}, None)] ) history: list[StepResult] = [] + last_outputs: Dict[str, Any] = {} step_count = 0 last_node_id: Optional[str] = None rid = (run_id or uuid.uuid4().hex).strip() or uuid.uuid4().hex @@ -191,6 +229,7 @@ def _build_execution_context() -> Dict[str, Any]: "run_id": rid, "steps": step_count, "last_node_id": last_node_id, + "outputs": last_outputs, "history": history, } @@ -266,15 +305,22 @@ def merge_payload(src: str, payload: Dict[str, Any]) -> None: inputs = merged join_seen_sources.pop(node_id, None) - # Dedup: skip if same node already ran with identical inputs - try: - _hash_raw = json.dumps( - {"n": node_id, "i": inputs}, - sort_keys=True, ensure_ascii=False, default=str, - ) - _input_hash = hashlib.sha256(_hash_raw.encode()).hexdigest()[:16] - except Exception: + # Dedup: skip if same node already ran with identical inputs. + # Lightweight history mode is used by high-throughput ingest + # paths with large payloads; hashing full inputs there would + # serialize the same large alert lists we are trying not to + # retain. + if self.history_mode == "summary": _input_hash = "" + else: + try: + _hash_raw = json.dumps( + {"n": node_id, "i": inputs}, + sort_keys=True, ensure_ascii=False, default=str, + ) + _input_hash = hashlib.sha256(_hash_raw.encode()).hexdigest()[:16] + except Exception: + _input_hash = "" if _input_hash and node_id in _dedup_hashes and _dedup_hashes[node_id] == _input_hash: _logger.info( "wf.step.dedup_skip node=%s (identical input hash %s)", @@ -413,12 +459,45 @@ def _par_exec( _nid, _nd, _inp, _src = ready[_eo.idx] _sn = step_count + _eo.idx + 1 last_node_id = _nid + last_outputs = ( + _summarize_for_observability(_eo.outputs) + if self.history_mode == "summary" + else _eo.outputs + ) + + def _build_step_result( + *, + outputs: Dict[str, Any], + stdout: str, + error: Optional[str], + traceback_text: Optional[str] = None, + ) -> StepResult: + if self.history_mode == "summary": + return StepResult( + node_id=_nid, + inputs=_summarize_for_observability(_inp), + outputs=_summarize_for_observability(outputs), + stdout=stdout, + error=error, + traceback=traceback_text, + duration_ms=_eo.duration_ms, + ) + return StepResult( + node_id=_nid, + inputs=_inp, + outputs=outputs, + stdout=stdout, + error=error, + traceback=traceback_text, + duration_ms=_eo.duration_ms, + ) if _eo.is_cancelled: - step_res = StepResult( - node_id=_nid, inputs=_inp, outputs={}, - stdout=_eo.stdout, error=_eo.error or "Run cancelled", - traceback=_eo.traceback, duration_ms=_eo.duration_ms, + step_res = _build_step_result( + outputs={}, + stdout=_eo.stdout, + error=_eo.error or "Run cancelled", + traceback_text=_eo.traceback, ) history.append(step_res) if on_step_end is not None and _eo.idx in step_tokens: @@ -440,10 +519,11 @@ def _par_exec( if _eo.traceback: print("[WF] traceback:") print(_eo.traceback.rstrip()) - step_res = StepResult( - node_id=_nid, inputs=_inp, outputs=_eo.outputs, - stdout=_eo.stdout, error=_eo.error, traceback=_eo.traceback, - duration_ms=_eo.duration_ms, + step_res = _build_step_result( + outputs=_eo.outputs, + stdout=_eo.stdout, + error=_eo.error, + traceback_text=_eo.traceback, ) history.append(step_res) _status = "timeout" if _eo.is_timeout else "error" @@ -495,9 +575,10 @@ def _par_exec( print("[WF] outputs=" + json.dumps(_eo.outputs, ensure_ascii=False, default=str)) except Exception: print(f"[WF] outputs= {_eo.outputs!r}") - step_res = StepResult( - node_id=_nid, inputs=_inp, outputs=_eo.outputs, - stdout=_eo.stdout, error=None, duration_ms=_eo.duration_ms, + step_res = _build_step_result( + outputs=_eo.outputs, + stdout=_eo.stdout, + error=None, ) history.append(step_res) _logger.info( @@ -545,7 +626,13 @@ def _par_exec( f"{nid} expected={expected} seen={seen}" for nid, expected, seen in pending_joins ) raise NodeExecutionError(node_id=pending_joins[0][0], message=msg) - return ExecutionResult(steps=step_count, history=history, last_node_id=last_node_id, run_id=rid) + return ExecutionResult( + steps=step_count, + history=history, + last_node_id=last_node_id, + outputs=last_outputs, + run_id=rid, + ) finally: if isinstance(self.runtime, PythonExecRuntime): self.runtime.cancel_checker = previous_cancel_checker @@ -775,6 +862,7 @@ def _execute_subworkflow_node( use_llm=self.use_llm, trace=self.trace, node_timeout_s=self.node_timeout_s, + history_mode=self.history_mode, _depth=self._depth + 1, workflow_loader=self.workflow_loader, ) diff --git a/flocks/workflow/execution_store.py b/flocks/workflow/execution_store.py index d36d5e60d..8bc6b3309 100644 --- a/flocks/workflow/execution_store.py +++ b/flocks/workflow/execution_store.py @@ -3,9 +3,10 @@ from __future__ import annotations import asyncio +import re import time import uuid -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Set from flocks.session.recorder import Recorder from flocks.storage.storage import Storage @@ -120,6 +121,12 @@ def compact_history_for_storage( # call per workflow to amortise the cost under high syslog throughput. _TRIM_CHECK_INTERVAL = 5 _trim_counters: Dict[str, int] = {} +# Workflows with an in-flight trim task. Because trims run as fire-and-forget +# ``asyncio.create_task`` background jobs, a slow trim under high syslog load +# could otherwise spawn many overlapping scans that each materialise table +# state simultaneously — the exact pattern that drove RSS to 20 GB. This +# guard ensures at most one trim per workflow is ever running. +_trim_in_flight: Set[str] = set() # Per-workflow lock to serialize read-modify-write of stats. Concurrent # executions of the same workflow (e.g. syslog-triggered runs with @@ -341,24 +348,56 @@ async def _record_audit() -> None: pass +# Regex patterns to extract scalar fields from raw JSON strings without +# calling json.loads. workflowId/startedAt are always serialised near the +# start of the record (set in build_initial_execution_record), so we only +# scan a small prefix of each value string — O(prefix) instead of O(value). +_RE_WORKFLOW_ID = re.compile(r'"workflowId"\s*:\s*"([^"]+)"') +_RE_STARTED_AT = re.compile(r'"startedAt"\s*:\s*(\d+)') + + async def _trim_execution_history(workflow_id: str) -> None: - """Delete the oldest execution records once the per-workflow cap is exceeded.""" - all_entries = await Storage.list_entries("workflow_execution/") - wf_entries = [ - (key, data) - for key, data in all_entries - if isinstance(data, dict) and data.get("workflowId") == workflow_id - ] - if len(wf_entries) <= _MAX_EXECUTION_HISTORY_PER_WORKFLOW: + """Delete the oldest execution records once the per-workflow cap is exceeded. + + Uses ``Storage.list_raw`` + regex instead of ``list_entries`` + ``json.loads`` + so that scanning the execution-history table never materialises large JSON + blobs as Python objects. The previous approach caused 100% single-core CPU + usage (``json.raw_decode``) and drove RSS to 20 GB under syslog load. + + Also guards against overlapping trim tasks via ``_trim_in_flight``: because + trims run as fire-and-forget background tasks, without the guard a slow trim + would spawn multiple concurrent scans that each load the full table + simultaneously, multiplying the memory spike. + """ + # Coalesce overlapping trims: only one scan per workflow at a time. + if workflow_id in _trim_in_flight: return - # Sort ascending by startedAt and remove the oldest excess records - wf_entries.sort(key=lambda kd: kd[1].get("startedAt", 0)) - excess = len(wf_entries) - _MAX_EXECUTION_HISTORY_PER_WORKFLOW - for key, _ in wf_entries[:excess]: - try: - exec_id = key.rsplit("/", 1)[-1] - await Storage.remove(key) - record_path = Recorder.paths().workflow_dir / f"{exec_id}.jsonl" - await asyncio.to_thread(record_path.unlink, missing_ok=True) - except Exception: - pass + _trim_in_flight.add(workflow_id) + try: + raw_rows = await Storage.list_raw("workflow_execution/") + wf_entries: List[tuple] = [] + for key, value_str in raw_rows: + # Scan only the first 400 chars — enough for workflowId + startedAt. + head = value_str[:400] + m_wf = _RE_WORKFLOW_ID.search(head) + if not m_wf or m_wf.group(1) != workflow_id: + continue + m_ts = _RE_STARTED_AT.search(head) + started_at = int(m_ts.group(1)) if m_ts else 0 + wf_entries.append((key, started_at)) + + if len(wf_entries) <= _MAX_EXECUTION_HISTORY_PER_WORKFLOW: + return + # Sort ascending by startedAt and remove the oldest excess records. + wf_entries.sort(key=lambda kd: kd[1]) + excess = len(wf_entries) - _MAX_EXECUTION_HISTORY_PER_WORKFLOW + for key, _ in wf_entries[:excess]: + try: + exec_id = key.rsplit("/", 1)[-1] + await Storage.remove(key) + record_path = Recorder.paths().workflow_dir / f"{exec_id}.jsonl" + await asyncio.to_thread(record_path.unlink, missing_ok=True) + except Exception: + pass + finally: + _trim_in_flight.discard(workflow_id) diff --git a/flocks/workflow/logging_config.py b/flocks/workflow/logging_config.py index 135e1fce4..cb0a6c83a 100644 --- a/flocks/workflow/logging_config.py +++ b/flocks/workflow/logging_config.py @@ -6,10 +6,10 @@ import logging import sys -from pathlib import Path +from logging.handlers import RotatingFileHandler from typing import Optional -from flocks.utils.log import get_log_dir +from flocks.utils.log import get_log_backup_count, get_log_dir, get_log_max_bytes def setup_workflow_logging( @@ -54,8 +54,11 @@ def setup_workflow_logging( try: log_dir = get_log_dir() log_dir.mkdir(parents=True, exist_ok=True) - file_handler = logging.FileHandler( - log_dir / "workflow.log", mode="a", encoding="utf-8" + file_handler = RotatingFileHandler( + log_dir / "workflow.log", + maxBytes=get_log_max_bytes(), + backupCount=get_log_backup_count(), + encoding="utf-8", ) file_handler.setLevel(level) file_handler.setFormatter(formatter) diff --git a/flocks/workflow/poller_manager.py b/flocks/workflow/poller_manager.py new file mode 100644 index 000000000..c25d4cd03 --- /dev/null +++ b/flocks/workflow/poller_manager.py @@ -0,0 +1,468 @@ +"""Lifecycle manager for workflow pollers. + +This mirrors the Kafka/syslog managers: one background poller task per workflow +id that periodically triggers ``run_workflow`` with configured inputs. +""" + +from __future__ import annotations + +import asyncio +import threading +import time +import uuid +from datetime import datetime +from typing import Any, Dict + +from flocks.storage.storage import Storage +from flocks.utils.log import Log +from flocks.workflow.execution_store import ( + compact_history_for_storage, + compact_outputs_for_storage, + create_execution_record, + record_execution_result, + resolve_execution_outcome, +) +from flocks.workflow.fs_store import read_workflow_from_fs +from flocks.workflow.runner import RunWorkflowResult, run_workflow + +WORKFLOW_POLLER_CONFIG_PREFIX = "workflow_poller_config/" +DEFAULT_INTERVAL_SECONDS = 30 +DEFAULT_TIMEOUT_SECONDS = 7200 +RUN_SHUTDOWN_GRACE_SECONDS = 1.0 + +log = Log.create(service="workflow.poller") + + +def _now_ms() -> int: + return int(time.time() * 1000) + + +def _today_string() -> str: + return datetime.now().strftime("%Y-%m-%d") + + +class WorkflowPollerManager: + """Manage one background poller loop per workflow id.""" + + def __init__(self) -> None: + self._tasks: dict[str, asyncio.Task[Any]] = {} + self._abort_events: dict[str, asyncio.Event] = {} + self._run_tasks: dict[str, set[asyncio.Task[Any]]] = {} + self._run_cancel_events: dict[str, set[threading.Event]] = {} + self._status: dict[str, Dict[str, Any]] = {} + + @staticmethod + def _config_key(workflow_id: str) -> str: + return f"{WORKFLOW_POLLER_CONFIG_PREFIX}{workflow_id}" + + def _normalize_config(self, workflow_id: str, data: Any) -> Dict[str, Any]: + raw = data if isinstance(data, dict) else {} + interval_seconds = int(raw.get("intervalSeconds") or DEFAULT_INTERVAL_SECONDS) + timeout_seconds = int(raw.get("timeoutSeconds") or DEFAULT_TIMEOUT_SECONDS) + inputs = raw.get("inputs") if isinstance(raw.get("inputs"), dict) else {} + return { + "workflowId": workflow_id, + "enabled": bool(raw.get("enabled")), + "intervalSeconds": max(1, interval_seconds), + "timeoutSeconds": max(1, timeout_seconds), + "noOverlap": bool(raw.get("noOverlap", True)), + "inputs": dict(inputs), + "updatedAt": raw.get("updatedAt"), + } + + def _cleanup_done_runs(self, workflow_id: str) -> int: + tasks = self._run_tasks.get(workflow_id) + if not tasks: + return 0 + active_tasks = {task for task in tasks if not task.done()} + if active_tasks: + self._run_tasks[workflow_id] = active_tasks + return len(active_tasks) + self._run_tasks.pop(workflow_id, None) + return 0 + + def _register_run_task(self, workflow_id: str, task: asyncio.Task[Any]) -> None: + task_set = self._run_tasks.setdefault(workflow_id, set()) + task_set.add(task) + + def _discard(done_task: asyncio.Task[Any]) -> None: + tasks = self._run_tasks.get(workflow_id) + if tasks is not None: + tasks.discard(done_task) + if not tasks: + self._run_tasks.pop(workflow_id, None) + + task.add_done_callback(_discard) + + def _build_inputs(self, config: Dict[str, Any]) -> Dict[str, Any]: + inputs = dict(config.get("inputs") or {}) + if not str(inputs.get("input_date") or "").strip(): + inputs["input_date"] = _today_string() + inputs["_trigger"] = "poller" + inputs["_poller_run_id"] = f"poller-{_now_ms()}-{uuid.uuid4().hex[:8]}" + return inputs + + def _summarize_outputs(self, outputs: Any) -> Dict[str, Any]: + if not isinstance(outputs, dict): + return {} + + summary: Dict[str, Any] = {} + load_stats = outputs.get("load_stats") + if isinstance(load_stats, dict) and isinstance(load_stats.get("record_count"), int): + summary["selectedCount"] = load_stats["record_count"] + + if isinstance(outputs.get("processed_cache_size_after"), int): + summary["processedMarkCount"] = outputs["processed_cache_size_after"] + elif isinstance(outputs.get("processed_mark_count"), int): + summary["processedMarkCount"] = outputs["processed_mark_count"] + + if isinstance(outputs.get("kafka_message_count"), int): + summary["kafkaMessageCount"] = outputs["kafka_message_count"] + + channel_status = outputs.get("channel_notify_status") + if channel_status is not None: + summary["channelNotifyStatus"] = channel_status + + return summary + + def _base_status(self, workflow_id: str) -> Dict[str, Any]: + return { + "workflowId": workflow_id, + "state": "stopped", + "error": None, + "activeRuns": 0, + "lastRunAt": None, + "lastStatus": None, + "lastError": None, + "lastDurationMs": None, + "selectedCount": None, + "processedMarkCount": None, + "channelNotifyStatus": None, + "kafkaMessageCount": None, + "nextRunAt": None, + "lastRunId": None, + } + + def get_status(self, workflow_id: str) -> Dict[str, Any]: + status = dict(self._base_status(workflow_id)) + status.update(self._status.get(workflow_id) or {}) + status["activeRuns"] = self._cleanup_done_runs(workflow_id) + if workflow_id not in self._tasks and status.get("state") == "running": + status["state"] = "stopped" + status["nextRunAt"] = None + return status + + async def start_all(self) -> None: + try: + keys = await Storage.list_keys(WORKFLOW_POLLER_CONFIG_PREFIX) + except Exception as exc: + log.warning("poller.list_keys_failed", {"error": str(exc)}) + return + + for key in keys: + if not key.startswith(WORKFLOW_POLLER_CONFIG_PREFIX): + continue + workflow_id = key[len(WORKFLOW_POLLER_CONFIG_PREFIX):] + if not workflow_id: + continue + try: + data = await Storage.read(key) + except Exception as exc: + log.warning("poller.config_read_failed", {"key": key, "error": str(exc)}) + continue + if isinstance(data, dict) and data.get("enabled"): + await self.restart_workflow(workflow_id) + + async def stop_all(self) -> None: + for workflow_id in list(self._tasks.keys()): + await self.stop_workflow(workflow_id) + + async def stop_workflow(self, workflow_id: str) -> None: + abort_event = self._abort_events.get(workflow_id) + if abort_event is not None: + abort_event.set() + + for cancel_event in self._run_cancel_events.get(workflow_id, set()): + cancel_event.set() + + task = self._tasks.pop(workflow_id, None) + if task is not None and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception: + pass + + run_tasks = list(self._run_tasks.get(workflow_id, set())) + if run_tasks: + await asyncio.wait(run_tasks, timeout=RUN_SHUTDOWN_GRACE_SECONDS) + + self._abort_events.pop(workflow_id, None) + current = self._status.get(workflow_id) or self._base_status(workflow_id) + current["state"] = "stopped" + current["error"] = None + current["nextRunAt"] = None + current["activeRuns"] = self._cleanup_done_runs(workflow_id) + if current["activeRuns"] == 0: + self._run_cancel_events.pop(workflow_id, None) + self._status[workflow_id] = current + + async def restart_workflow(self, workflow_id: str) -> Dict[str, Any]: + await self.stop_workflow(workflow_id) + try: + stored = await Storage.read(self._config_key(workflow_id)) + except Exception as exc: + log.warning("poller.restart_read_failed", {"workflow_id": workflow_id, "error": str(exc)}) + return {"workflowId": workflow_id, "state": "failed", "error": str(exc)} + + config = self._normalize_config(workflow_id, stored) + if not config.get("enabled"): + self._status[workflow_id] = { + **self._base_status(workflow_id), + "workflowId": workflow_id, + "state": "stopped", + "error": None, + } + return self.get_status(workflow_id) + + wf_data = read_workflow_from_fs(workflow_id) + if not wf_data: + err = "workflow_not_found" + self._status[workflow_id] = { + **self.get_status(workflow_id), + "workflowId": workflow_id, + "state": "failed", + "error": err, + } + return self.get_status(workflow_id) + + workflow_json = wf_data.get("workflowJson") + if not workflow_json: + err = "workflow_json_missing" + self._status[workflow_id] = { + **self.get_status(workflow_id), + "workflowId": workflow_id, + "state": "failed", + "error": err, + } + return self.get_status(workflow_id) + + abort_event = asyncio.Event() + self._abort_events[workflow_id] = abort_event + self._status[workflow_id] = { + **self.get_status(workflow_id), + "workflowId": workflow_id, + "state": "running", + "error": None, + "enabled": True, + "intervalSeconds": config["intervalSeconds"], + "timeoutSeconds": config["timeoutSeconds"], + "noOverlap": config["noOverlap"], + "nextRunAt": _now_ms(), + } + task = asyncio.create_task( + self._poller_loop(workflow_id, workflow_json, config, abort_event), + name=f"workflow-poller-{workflow_id}", + ) + self._tasks[workflow_id] = task + return self.get_status(workflow_id) + + async def run_once(self, workflow_id: str) -> Dict[str, Any]: + try: + stored = await Storage.read(self._config_key(workflow_id)) + except Exception as exc: + log.warning("poller.run_once_read_failed", {"workflow_id": workflow_id, "error": str(exc)}) + current = self.get_status(workflow_id) + current["lastStatus"] = "failed" + current["lastError"] = str(exc) + return current + + config = self._normalize_config(workflow_id, stored) + wf_data = read_workflow_from_fs(workflow_id) + if not wf_data: + current = self.get_status(workflow_id) + current["state"] = "failed" if workflow_id in self._tasks else current.get("state", "stopped") + current["lastStatus"] = "failed" + current["lastError"] = "workflow_not_found" + self._status[workflow_id] = current + return self.get_status(workflow_id) + + workflow_json = wf_data.get("workflowJson") + if not workflow_json: + current = self.get_status(workflow_id) + current["state"] = "failed" if workflow_id in self._tasks else current.get("state", "stopped") + current["lastStatus"] = "failed" + current["lastError"] = "workflow_json_missing" + self._status[workflow_id] = current + return self.get_status(workflow_id) + + return await self._execute_run(workflow_id, workflow_json, config) + + async def _poller_loop( + self, + workflow_id: str, + workflow_json: Dict[str, Any], + config: Dict[str, Any], + abort_event: asyncio.Event, + ) -> None: + interval_seconds = config["intervalSeconds"] + try: + while not abort_event.is_set(): + await self._schedule_run(workflow_id, workflow_json, config) + next_run_at = _now_ms() + interval_seconds * 1000 + current = self._status.get(workflow_id) or self._base_status(workflow_id) + current["nextRunAt"] = next_run_at + current["activeRuns"] = self._cleanup_done_runs(workflow_id) + self._status[workflow_id] = current + try: + await asyncio.wait_for(abort_event.wait(), timeout=interval_seconds) + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + raise + except Exception as exc: + current = self._status.get(workflow_id) or self._base_status(workflow_id) + current["state"] = "failed" + current["error"] = str(exc) + current["nextRunAt"] = None + self._status[workflow_id] = current + log.warning("poller.loop_failed", {"workflow_id": workflow_id, "error": str(exc)}) + finally: + if workflow_id in self._tasks and self._tasks.get(workflow_id) is asyncio.current_task(): + current = self._status.get(workflow_id) or self._base_status(workflow_id) + if current.get("state") != "failed": + current["state"] = "stopped" + current["error"] = None + current["nextRunAt"] = None + current["activeRuns"] = self._cleanup_done_runs(workflow_id) + self._status[workflow_id] = current + + async def _schedule_run( + self, + workflow_id: str, + workflow_json: Dict[str, Any], + config: Dict[str, Any], + ) -> None: + active_runs = self._cleanup_done_runs(workflow_id) + if config.get("noOverlap", True) and active_runs > 0: + current = self._status.get(workflow_id) or self._base_status(workflow_id) + current["lastStatus"] = "skipped" + current["lastError"] = "previous_run_still_active" + current["activeRuns"] = active_runs + self._status[workflow_id] = current + return + + run_task = asyncio.create_task( + self._execute_run(workflow_id, workflow_json, config), + name=f"workflow-poller-run-{workflow_id}", + ) + self._register_run_task(workflow_id, run_task) + + async def _execute_run( + self, + workflow_id: str, + workflow_json: Dict[str, Any], + config: Dict[str, Any], + ) -> Dict[str, Any]: + started_at_ms = _now_ms() + started_at_s = time.time() + cancel_event = threading.Event() + cancel_events = self._run_cancel_events.setdefault(workflow_id, set()) + cancel_events.add(cancel_event) + inputs = self._build_inputs(config) + exec_data = await create_execution_record(workflow_id, input_params=inputs) + exec_id = str(exec_data["id"]) + current = self._status.get(workflow_id) or self._base_status(workflow_id) + current["lastRunAt"] = started_at_ms + current["activeRuns"] = self._cleanup_done_runs(workflow_id) + self._status[workflow_id] = current + + try: + result = await asyncio.to_thread( + run_workflow, + workflow=workflow_json, + inputs=inputs, + timeout_s=config["timeoutSeconds"], + trace=False, + cancel=cancel_event.is_set, + ) + if not isinstance(result, RunWorkflowResult): + result = RunWorkflowResult(status="failed", error="invalid_run_result") + status_value, error_message = resolve_execution_outcome(result) + if cancel_event.is_set() and status_value == "success": + status_value = "cancelled" + error_message = error_message or f"Run cancelled: run_id={result.run_id or exec_id}" + duration_ms = _now_ms() - started_at_ms + duration_s = max(0.0, time.time() - started_at_s) + summary = self._summarize_outputs(result.outputs) + exec_data.update({ + "outputResults": compact_outputs_for_storage(result.outputs), + "status": status_value, + "finishedAt": _now_ms(), + "duration": duration_s, + "executionLog": compact_history_for_storage(result.history), + "errorMessage": error_message, + "currentNodeId": result.last_node_id, + "currentPhase": status_value, + "currentStepIndex": result.steps, + }) + current = self._status.get(workflow_id) or self._base_status(workflow_id) + current.update(summary) + current["lastRunAt"] = started_at_ms + current["lastDurationMs"] = duration_ms + current["lastRunId"] = result.run_id or exec_id + current["lastStatus"] = status_value + current["lastError"] = error_message + current["activeRuns"] = self._cleanup_done_runs(workflow_id) + if workflow_id in self._tasks and current.get("state") != "failed": + current["state"] = "running" + current["error"] = None + self._status[workflow_id] = current + except Exception as exc: + duration_ms = _now_ms() - started_at_ms + duration_s = max(0.0, time.time() - started_at_s) + status_value = "cancelled" if cancel_event.is_set() else "error" + finished_at_ms = _now_ms() + exec_data.update({ + "status": status_value, + "finishedAt": finished_at_ms, + "duration": duration_s, + "errorMessage": str(exc), + "executionLog": compact_history_for_storage(exec_data.get("executionLog")), + "currentPhase": status_value, + }) + current = self._status.get(workflow_id) or self._base_status(workflow_id) + current["lastRunAt"] = started_at_ms + current["lastDurationMs"] = duration_ms + current["lastStatus"] = status_value + current["lastError"] = str(exc) + current["activeRuns"] = self._cleanup_done_runs(workflow_id) + if workflow_id in self._tasks and current.get("state") != "failed": + current["state"] = "running" + current["error"] = None + self._status[workflow_id] = current + log.warning("poller.run_failed", {"workflow_id": workflow_id, "error": str(exc)}) + finally: + try: + await record_execution_result(workflow_id, exec_id, exec_data) + except Exception as exc: + log.warning( + "poller.exec_record_failed", + {"workflow_id": workflow_id, "exec_id": exec_id, "error": str(exc)}, + ) + cancel_events.discard(cancel_event) + if not cancel_events: + self._run_cancel_events.pop(workflow_id, None) + current = self._status.get(workflow_id) or self._base_status(workflow_id) + current["activeRuns"] = self._cleanup_done_runs(workflow_id) + if workflow_id not in self._tasks and current.get("state") == "running": + current["state"] = "stopped" + current["nextRunAt"] = None + self._status[workflow_id] = current + + return self.get_status(workflow_id) + + +default_manager = WorkflowPollerManager() diff --git a/flocks/workflow/repl_runtime.py b/flocks/workflow/repl_runtime.py index dca38a334..5e43c8c2f 100644 --- a/flocks/workflow/repl_runtime.py +++ b/flocks/workflow/repl_runtime.py @@ -14,7 +14,7 @@ import uuid from concurrent.futures import TimeoutError as _FuturesTimeoutError from dataclasses import dataclass, field -from typing import Any, Callable, Dict, Optional, TextIO, Tuple +from typing import Any, Callable, ClassVar, Dict, Optional, TextIO, Tuple from .errors import NodeExecutionError, RunCancelledError from .llm import get_lazy_llm @@ -49,6 +49,20 @@ class PythonExecRuntime(Runtime): globals: Dict[str, Any] = field(default_factory=dict) tool_registry: Optional[Any] = None # FlocksToolAdapter or compatible cancel_checker: Optional[Callable[[], bool]] = None + cleanup_globals_after_execute: bool = False + + _RUNTIME_GLOBAL_KEYS: ClassVar[frozenset[str]] = frozenset( + { + "__builtins__", + "inputs", + "outputs", + "cancelled", + "is_cancelled", + "llm", + "tool", + "get_path", + } + ) def execute(self, code: str, inputs: Dict[str, Any]) -> Tuple[Dict[str, Any], str]: if not isinstance(code, str): @@ -189,6 +203,10 @@ def _cancel_trace(_frame: Any, event: str, _arg: Any) -> Any: out_obj = {} if not isinstance(out_obj, dict): raise NodeExecutionError(node_id="", message="`outputs` must be a dict") + if self.cleanup_globals_after_execute: + for key in list(g.keys()): + if key not in self._RUNTIME_GLOBAL_KEYS: + g.pop(key, None) return out_obj, buf.getvalue() def reset(self) -> None: diff --git a/flocks/workflow/runner.py b/flocks/workflow/runner.py index 6588e4301..abed73d2f 100644 --- a/flocks/workflow/runner.py +++ b/flocks/workflow/runner.py @@ -296,6 +296,7 @@ def run_workflow( on_step_start: Optional[Any] = None, on_step_complete: Optional[Any] = None, max_parallel_workers: int = 4, + history_mode: Literal["full", "summary"] = "full", cancel: Optional[Callable[[], bool]] = None, ) -> RunWorkflowResult: # 确保日志已配置 @@ -414,14 +415,18 @@ def run_workflow( _logger.info("workflow runtime: host forced by sandbox.mode=off or runtime override") if ensure_requirements and reqs: (requirements_installer or RequirementsInstaller(installer="auto")).ensure_installed(reqs) - rt = PythonExecRuntime(tool_registry=registry) + rt = PythonExecRuntime( + tool_registry=registry, + cleanup_globals_after_execute=(history_mode == "summary"), + ) _logger.info( - "创建执行引擎 (use_llm=%s, trace=%s, node_timeout=%ss, parallel_workers=%s)", + "创建执行引擎 (use_llm=%s, trace=%s, node_timeout=%ss, parallel_workers=%s, history_mode=%s)", effective_use_llm, trace, effective_node_timeout_s, max_parallel_workers, + history_mode, ) engine = WorkflowEngine( wf, @@ -431,6 +436,7 @@ def run_workflow( workflow_path=workflow_path_for_engine, node_timeout_s=effective_node_timeout_s, max_parallel_workers=max_parallel_workers, + history_mode=history_mode, ) initial_inputs = _build_initial_inputs(inputs, workflow_path_for_engine) @@ -469,7 +475,9 @@ def _on_step_end(_token, step_result): if history_from_error and hasattr(history_from_error[0], 'model_dump'): history_from_error = [s.model_dump(mode="json") for s in history_from_error] - last_outputs = history_from_error[-1].get('outputs', {}) if history_from_error else {} + last_outputs = exec_ctx.get('outputs') or ( + history_from_error[-1].get('outputs', {}) if history_from_error else {} + ) status = "FAILED" if isinstance(e, RunCancelledError): @@ -488,7 +496,7 @@ def _on_step_end(_token, step_result): ) history = [s.model_dump(mode="json") for s in result.history] - last_outputs = result.history[-1].outputs if result.history else {} + last_outputs = result.outputs if result.outputs else (result.history[-1].outputs if result.history else {}) if cancel is not None and cancel(): return RunWorkflowResult( diff --git a/pyproject.toml b/pyproject.toml index 2ea779162..86b96896c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "flocks" -version = "v2026.5.27" +version = "v2026.6.3" description = "AI-Native SecOps platform with multi-agent collaboration" authors = [ {name = "Flocks Team", email = "team@example.com"} @@ -84,6 +84,8 @@ dependencies = [ "cdp-use>=1.4.5", "pillow>=12.2.0", "datasketch>=1.10.0", + # Kafka ingest (workflow input/output) + "aiokafka>=0.14.0", ] [dependency-groups] diff --git a/tests/agent/test_agent.py b/tests/agent/test_agent.py index ce38de5b4..54ebe3378 100644 --- a/tests/agent/test_agent.py +++ b/tests/agent/test_agent.py @@ -186,6 +186,23 @@ async def delegatable(name: str) -> bool: assert await delegatable("hephaestus") is True assert await delegatable("oracle") is True + @pytest.mark.asyncio + async def test_is_delegatable_respects_sidecar_override(self, tmp_path, monkeypatch): + settings_file = tmp_path / "agent_delegatable_settings.json" + monkeypatch.setattr("flocks.agent.delegatable_settings.settings_path", lambda: settings_file) + + import flocks.agent.delegatable_settings as delegatable_settings + from flocks.agent.registry import Agent as AgentRegistry, is_delegatable + + AgentRegistry._delegatable_settings_mtime = 0.0 + delegatable_settings.set_override("explore", False) + AgentRegistry.invalidate_cache() + + agent = await AgentRegistry.get("explore") + assert agent is not None + assert agent.delegatable is False + assert is_delegatable("explore") is False + @pytest.mark.asyncio async def test_list_names(self): names = await Agent.list_names() diff --git a/tests/agent/test_agent_factory.py b/tests/agent/test_agent_factory.py index 893135910..30ce129e6 100644 --- a/tests/agent/test_agent_factory.py +++ b/tests/agent/test_agent_factory.py @@ -24,6 +24,9 @@ scan_and_load, inject_dynamic_prompts, yaml_to_agent_info, + find_yaml_agent, + read_yaml_agent, + update_yaml_agent, delete_yaml_agent, _parse_prompt_metadata, ) @@ -778,6 +781,54 @@ def test_user_plugin_agent_is_not_native( assert result["custom-agent"].native is False +# =========================================================================== +# Project-level plugin agent CRUD +# =========================================================================== + +class TestProjectLevelYamlCrud: + """Project-level plugin agents should be readable and updatable.""" + + def _write_project_agent(self, root: Path, name: str) -> Path: + agent_dir = root / ".flocks" / "plugins" / "agents" / name + agent_dir.mkdir(parents=True) + yaml_path = agent_dir / "agent.yaml" + yaml_path.write_text( + textwrap.dedent(f"""\ + name: {name} + description: Project agent {name} + mode: subagent + delegatable: true + """), + encoding="utf-8", + ) + return yaml_path + + def test_find_and_read_project_level_agent( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ): + yaml_path = self._write_project_agent(tmp_path, "proj-editable") + monkeypatch.chdir(tmp_path) + + assert find_yaml_agent("proj-editable") == yaml_path + raw = read_yaml_agent("proj-editable") + assert raw is not None + assert raw["name"] == "proj-editable" + assert raw["delegatable"] is True + + def test_update_project_level_agent_yaml( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ): + yaml_path = self._write_project_agent(tmp_path, "proj-toggle") + monkeypatch.chdir(tmp_path) + + result = update_yaml_agent("proj-toggle", {"delegatable": False, "temperature": 0.2}) + + assert result is True + updated = yaml_path.read_text(encoding="utf-8") + assert "delegatable: false" in updated + assert "temperature: 0.2" in updated + + # =========================================================================== # delete_yaml_agent — subdirectory layout # =========================================================================== diff --git a/tests/channel/test_channel.py b/tests/channel/test_channel.py index 94be5144c..90d191607 100644 --- a/tests/channel/test_channel.py +++ b/tests/channel/test_channel.py @@ -570,6 +570,8 @@ async def fake_deliver(ctx, session_id=None): ) ) monkeypatch.setattr("flocks.session.session.Session.create", create_mock) + update_mock = AsyncMock(return_value=None) + monkeypatch.setattr("flocks.session.session.Session.update", update_mock) handled = await dispatcher._handle_feishu_native_command( binding=binding, @@ -587,6 +589,11 @@ async def fake_deliver(ctx, session_id=None): assert create_kwargs["title"] == "[Feishu] oc_group" assert "session_new" in delivered[0] assert "已开始全新对话。" in delivered[0] + # The previous session must be archived so it no longer appears in the + # active IM session list used for scheduled-task target resolution. + update_mock.assert_awaited_once() + assert update_mock.await_args.args == ("channel", "session_old") + assert update_mock.await_args.kwargs["status"] == "archived" @pytest.mark.asyncio async def test_reset_alias_matches_new_semantics(self, monkeypatch): @@ -727,6 +734,61 @@ async def fake_deliver(ctx, session_id=None): assert delivered assert "Available / commands:" in delivered[0] + @pytest.mark.asyncio + async def test_clear_command_clears_channel_session_history(self, monkeypatch): + from flocks.channel.inbound.dispatcher import InboundDispatcher + from flocks.channel.inbound.session_binding import SessionBinding + + dispatcher = InboundDispatcher() + binding = SessionBinding( + channel_id="wecom", + account_id="default", + chat_id="room_1", + chat_type=ChatType.DIRECT, + thread_id=None, + session_id="session_1", + agent_id="rex", + created_at=0, + last_message_at=0, + ) + msg = InboundMessage( + channel_id="wecom", + account_id="default", + message_id="msg_1", + sender_id="user_1", + chat_id="room_1", + chat_type=ChatType.DIRECT, + text="/clear", + mention_text="/clear", + ) + + delivered: list[str] = [] + clear_history = AsyncMock(return_value=3) + + async def fake_deliver(ctx, session_id=None): + delivered.append(ctx.text) + + monkeypatch.setattr( + "flocks.channel.outbound.deliver.OutboundDelivery.deliver", + fake_deliver, + ) + monkeypatch.setattr( + "flocks.server.routes.session._clear_session_history", + clear_history, + ) + + handled = await dispatcher._handle_feishu_native_command( + binding=binding, + msg=msg, + channel_config=ChannelConfig(enabled=True), + user_text="/clear", + scope_override=None, + ) + + assert handled is True + clear_history.assert_awaited_once_with("session_1") + assert delivered == ["已清空当前会话历史,共删除 3 条消息。"] + @pytest.mark.asyncio async def test_append_user_message_stores_feishu_media_part(self, monkeypatch): from flocks.channel.inbound.dispatcher import InboundDispatcher diff --git a/tests/channel/test_feishu.py b/tests/channel/test_feishu.py index 5b2e24b63..7681b7778 100644 --- a/tests/channel/test_feishu.py +++ b/tests/channel/test_feishu.py @@ -22,7 +22,11 @@ from flocks.channel.builtin.feishu.dedup import FeishuDedup from flocks.channel.builtin.feishu.inbound_media import download_inbound_media from flocks.channel.builtin.feishu.media import send_media_feishu -from flocks.channel.builtin.feishu.monitor import _build_ws_client, _start_single_websocket +from flocks.channel.builtin.feishu.monitor import ( + _build_ws_client, + _start_single_websocket, + start_websocket, +) from flocks.channel.builtin.feishu.send import send_message_feishu from flocks.channel.inbound.dispatcher import _resolve_feishu_group_overrides from flocks.config.config import ChannelConfig, FeishuGroupConfig @@ -390,20 +394,24 @@ async def test_handle_webhook_skips_replayed_requests(monkeypatch) -> None: @pytest.mark.asyncio async def test_websocket_card_action_events_are_deduplicated(monkeypatch, tmp_path) -> None: abort_event = asyncio.Event() + loop = asyncio.get_running_loop() on_message = AsyncMock() dedup = FeishuDedup(account_id="main", data_dir=tmp_path) class FakeWSClient: def __init__(self, app_id, app_secret, event_handler, log_level): self._event_handler = event_handler + self._stop_event = threading.Event() def start(self): payload = {"header": {"event_type": "card.action.trigger"}, "event": {}} self._event_handler(payload) self._event_handler(payload) - asyncio.get_running_loop().call_later(0.05, abort_event.set) + loop.call_soon_threadsafe(abort_event.set) + self._stop_event.wait(timeout=1) def stop(self): + self._stop_event.set() return None fake_lark = types.ModuleType("lark_oapi") @@ -452,6 +460,45 @@ def stop(self): assert on_message.await_count == 1 +@pytest.mark.asyncio +async def test_build_ws_client_adapter_branch_surfaces_disconnect_error(monkeypatch) -> None: + class FakeWSClient: + def __init__(self, app_id, app_secret, event_handler, log_level): + self._stop_event = threading.Event() + + def start(self): + self._stop_event.wait(timeout=0.05) + raise RuntimeError("socket closed") + + def stop(self): + self._stop_event.set() + + fake_lark = types.ModuleType("lark_oapi") + fake_lark.LogLevel = types.SimpleNamespace(WARNING="warning") + fake_adapter = types.ModuleType("lark_oapi.adapter") + fake_websocket = types.ModuleType("lark_oapi.adapter.websocket") + fake_websocket.WSClient = FakeWSClient + + monkeypatch.setitem(sys.modules, "lark_oapi", fake_lark) + monkeypatch.setitem(sys.modules, "lark_oapi.adapter", fake_adapter) + monkeypatch.setitem(sys.modules, "lark_oapi.adapter.websocket", fake_websocket) + + ws_client = _build_ws_client( + app_id="app-id", + app_secret="app-secret", + event_handler=lambda _data: None, + domain="https://open.feishu.cn", + ) + + ws_client.start() + disconnect_error = await asyncio.wait_for(ws_client.wait_disconnected(), timeout=1.0) + ws_client.stop() + + assert isinstance(disconnect_error, RuntimeError) + assert str(disconnect_error) == "socket closed" + assert ws_client.disconnected_error is disconnect_error + + def test_build_ws_client_falls_back_to_modern_sdk(monkeypatch) -> None: dispatched: list[dict] = [] captured: dict[str, object] = {} @@ -522,6 +569,75 @@ def fake_run_coroutine_threadsafe(coro, loop): assert dispatched == [{"header": {"event_type": "ping"}}] +@pytest.mark.asyncio +async def test_build_ws_client_surfaces_disconnect_error(monkeypatch) -> None: + captured: dict[str, object] = {} + + class _FakeConnection: + async def recv(self): + raise RuntimeError("ping timeout") + + class _FakeClient: + def __init__(self, **kwargs): + captured.update(kwargs) + captured["client"] = self + self._conn = None + self._auto_reconnect = kwargs["auto_reconnect"] + self.disconnect_calls = 0 + + def start(self): + loop = asyncio.get_event_loop() + self._conn = _FakeConnection() + loop.create_task(self._receive_message_loop()) + loop.run_forever() + + async def _handle_message(self, _msg): + return None + + async def _disconnect(self): + self.disconnect_calls += 1 + self._conn = None + + fake_lark = types.ModuleType("lark_oapi") + fake_lark.LogLevel = types.SimpleNamespace(WARNING="warning") + fake_ws_client = types.ModuleType("lark_oapi.ws.client") + fake_ws_client.Client = _FakeClient + fake_ws_client.loop = None + + real_import_module = __import__("importlib").import_module + + def fake_import_module(name, package=None): + if name == "lark_oapi": + return fake_lark + if name == "lark_oapi.ws.client": + return fake_ws_client + if name == "lark_oapi.adapter.websocket": + raise ImportError("legacy websocket adapter missing") + return real_import_module(name, package) + + monkeypatch.setattr( + "flocks.channel.builtin.feishu.monitor.importlib.import_module", + fake_import_module, + ) + + ws_client = _build_ws_client( + app_id="app-id", + app_secret="app-secret", + event_handler=lambda _data: None, + domain="https://open.feishu.cn", + ) + + ws_client.start() + disconnect_error = await asyncio.wait_for(ws_client.wait_disconnected(), timeout=1.0) + ws_client.stop() + + fake_client = captured["client"] + assert isinstance(disconnect_error, RuntimeError) + assert str(disconnect_error) == "ping timeout" + assert ws_client.disconnected_error is disconnect_error + assert fake_client.disconnect_calls >= 1 + + def test_build_ws_client_ignores_normal_close_during_stop(monkeypatch) -> None: captured: dict[str, object] = {} @@ -816,6 +932,201 @@ def fake_import_module(name, package=None): assert ws_client._client is None +@pytest.mark.asyncio +async def test_start_single_websocket_raises_when_ws_disconnects(monkeypatch) -> None: + on_message = AsyncMock() + stop_calls: list[str] = [] + + class _FakeWSClient: + start_error = None + + def start(self): + return None + + async def wait_disconnected(self): + return RuntimeError("keepalive ping timeout") + + def stop(self): + stop_calls.append("stop") + + monkeypatch.setattr( + "flocks.channel.builtin.feishu.identity.get_bot_identity", + AsyncMock(return_value=("ou_bot", "bot")), + ) + monkeypatch.setattr( + "flocks.channel.builtin.feishu.dedup.get_dedup", + AsyncMock(return_value=_FakeDedup()), + ) + monkeypatch.setattr( + "flocks.channel.builtin.feishu.monitor._build_ws_client", + lambda **kwargs: _FakeWSClient(), + ) + + with pytest.raises(RuntimeError, match="disconnected for account 'main'"): + await _start_single_websocket( + { + "appId": "main-id", + "appSecret": "main-secret", + "_account_id": "main", + }, + on_message, + asyncio.Event(), + ) + + assert stop_calls == ["stop"] + + +@pytest.mark.asyncio +async def test_start_single_websocket_abort_wins_over_disconnect_waiter(monkeypatch) -> None: + on_message = AsyncMock() + stop_calls: list[str] = [] + abort_event = asyncio.Event() + + class _FakeWSClient: + start_error = None + + def start(self): + asyncio.get_running_loop().call_soon(abort_event.set) + + async def wait_disconnected(self): + await asyncio.Event().wait() + + def stop(self): + stop_calls.append("stop") + + monkeypatch.setattr( + "flocks.channel.builtin.feishu.identity.get_bot_identity", + AsyncMock(return_value=("ou_bot", "bot")), + ) + monkeypatch.setattr( + "flocks.channel.builtin.feishu.dedup.get_dedup", + AsyncMock(return_value=_FakeDedup()), + ) + monkeypatch.setattr( + "flocks.channel.builtin.feishu.monitor._build_ws_client", + lambda **kwargs: _FakeWSClient(), + ) + + await _start_single_websocket( + { + "appId": "main-id", + "appSecret": "main-secret", + "_account_id": "main", + }, + on_message, + abort_event, + ) + + assert stop_calls == ["stop"] + + +@pytest.mark.asyncio +async def test_start_websocket_keeps_other_accounts_running_after_one_fails(monkeypatch) -> None: + events: list[str] = [] + abort_event = asyncio.Event() + + async def fake_start_single_websocket(config, on_message, abort_event=None): + account_id = config["_account_id"] + if account_id == "main": + await asyncio.sleep(0) + events.append("main_failed") + raise RuntimeError("socket closed") + events.append("backup_started") + await abort_event.wait() + events.append("backup_stopped") + + monkeypatch.setattr( + "flocks.channel.builtin.feishu.monitor._start_single_websocket", + fake_start_single_websocket, + ) + + task = asyncio.create_task(start_websocket( + { + "accounts": { + "main": { + "appId": "main-id", + "appSecret": "main-secret", + "connectionMode": "websocket", + }, + "backup": { + "appId": "backup-id", + "appSecret": "backup-secret", + "connectionMode": "websocket", + }, + } + }, + AsyncMock(), + abort_event, + )) + for _ in range(10): + await asyncio.sleep(0) + if "main_failed" in events: + break + + assert "main_failed" in events + assert "backup_started" in events + assert not task.done() + + abort_event.set() + await asyncio.wait_for(task, timeout=1.0) + + assert "backup_stopped" in events + + +@pytest.mark.asyncio +async def test_start_websocket_restarts_failed_account_while_sibling_runs(monkeypatch) -> None: + events: list[str] = [] + abort_event = asyncio.Event() + restart_seen = asyncio.Event() + attempts: dict[str, int] = {"main": 0, "backup": 0} + + async def fake_start_single_websocket(config, on_message, abort_event=None): + account_id = config["_account_id"] + attempts[account_id] += 1 + events.append(f"{account_id}_start_{attempts[account_id]}") + if account_id == "main" and attempts[account_id] == 1: + raise RuntimeError("socket closed") + if account_id == "main": + restart_seen.set() + await abort_event.wait() + events.append(f"{account_id}_stopped") + + monkeypatch.setattr( + "flocks.channel.builtin.feishu.monitor._start_single_websocket", + fake_start_single_websocket, + ) + + task = asyncio.create_task(start_websocket( + { + "websocketReconnectDelaySeconds": 0, + "accounts": { + "main": { + "appId": "main-id", + "appSecret": "main-secret", + "connectionMode": "websocket", + }, + "backup": { + "appId": "backup-id", + "appSecret": "backup-secret", + "connectionMode": "websocket", + }, + }, + }, + AsyncMock(), + abort_event, + )) + + await asyncio.wait_for(restart_seen.wait(), timeout=1.0) + + assert attempts["main"] == 2 + assert attempts["backup"] == 1 + assert "backup_start_1" in events + assert not task.done() + + abort_event.set() + await asyncio.wait_for(task, timeout=1.0) + + @pytest.mark.asyncio async def test_parse_reaction_event_falls_back_to_user_id(monkeypatch) -> None: from flocks.channel.builtin.feishu.monitor import _parse_reaction_event diff --git a/tests/cli/test_service_manager.py b/tests/cli/test_service_manager.py index cc74e9a10..86747f900 100644 --- a/tests/cli/test_service_manager.py +++ b/tests/cli/test_service_manager.py @@ -1357,6 +1357,27 @@ def fake_popen(*args, **kwargs): assert "startupinfo" not in captured["kwargs"] +def test_spawn_process_rotates_large_log_before_append(monkeypatch, tmp_path: Path) -> None: + log_path = tmp_path / "logs" / "backend.log" + log_path.parent.mkdir(parents=True) + log_path.write_text("x" * 20, encoding="utf-8") + + def fake_popen(*args, **kwargs): + kwargs["stdout"].write("new\n") + kwargs["stdout"].flush() + return SimpleNamespace(pid=9876) + + monkeypatch.setenv("FLOCKS_LOG_MAX_BYTES", "10") + monkeypatch.setenv("FLOCKS_LOG_BACKUP_COUNT", "1") + monkeypatch.setattr(service_manager.sys, "platform", "darwin") + monkeypatch.setattr(service_manager.subprocess, "Popen", fake_popen) + + service_manager._spawn_process(["python", "-m", "uvicorn"], cwd=tmp_path, log_path=log_path) + + assert log_path.read_text(encoding="utf-8") == "new\n" + assert (tmp_path / "logs" / "backend.log.1").read_text(encoding="utf-8") == "x" * 20 + + def test_spawn_process_passes_custom_environment(monkeypatch, tmp_path: Path) -> None: captured = {} log_path = tmp_path / "logs" / "backend.log" diff --git a/tests/ingest/test_kafka_manager.py b/tests/ingest/test_kafka_manager.py new file mode 100644 index 000000000..bdf21378f --- /dev/null +++ b/tests/ingest/test_kafka_manager.py @@ -0,0 +1,359 @@ +"""Unit tests for the Kafka → workflow ingest pipeline. + +These tests exercise :class:`KafkaManager` in isolation (no real broker) by +driving the bounded queue and worker pool directly, plus the connection-failure +path of ``restart_workflow``. They verify the same backpressure invariants as +the syslog manager: + +1. A fixed worker pool bounds the number of in-flight workflow dispatches. +2. ``stop_workflow`` cancels and drains the worker pool cleanly. +3. A consumer that cannot connect surfaces ``state == "failed"`` instead of + pretending to be running. +""" + +from __future__ import annotations + +import asyncio +import json +import sys +from types import SimpleNamespace + +import pytest + +from flocks.ingest.kafka import manager as kafka_manager + + +@pytest.mark.asyncio +async def test_worker_pool_bounds_in_flight_dispatches(monkeypatch: pytest.MonkeyPatch) -> None: + """The fixed worker pool must cap concurrent ``_trigger_workflow`` calls.""" + + manager = kafka_manager.KafkaManager() + pool_size = kafka_manager._MAX_CONCURRENT_EXECUTIONS + + in_flight = 0 + max_in_flight = 0 + completed = 0 + lock = asyncio.Lock() + + async def _fake_trigger(workflow_id, workflow_json, msg, input_key, producer=None, output_topic=""): # noqa: ANN001 + nonlocal in_flight, max_in_flight, completed + async with lock: + in_flight += 1 + if in_flight > max_in_flight: + max_in_flight = in_flight + await asyncio.sleep(0.01) + async with lock: + in_flight -= 1 + completed += 1 + + monkeypatch.setattr(manager, "_trigger_workflow", _fake_trigger) + + workflow_id = "test-wf" + queue: asyncio.Queue = asyncio.Queue(maxsize=kafka_manager._MAX_QUEUE_SIZE) + abort = asyncio.Event() + + manager._queues[workflow_id] = queue + manager._abort_events[workflow_id] = abort + workers = [ + asyncio.create_task( + manager._worker_loop(workflow_id, {}, "kafka_message", {}, queue, abort), + name=f"test-worker-{i}", + ) + for i in range(pool_size) + ] + manager._worker_pools[workflow_id] = workers + + burst_size = pool_size * 6 + for i in range(burst_size): + queue.put_nowait({"_seq": i}) + + deadline = asyncio.get_event_loop().time() + 5.0 + while completed < burst_size and asyncio.get_event_loop().time() < deadline: + await asyncio.sleep(0.02) + + abort.set() + for w in workers: + w.cancel() + await asyncio.gather(*workers, return_exceptions=True) + + assert completed == burst_size, f"expected {burst_size} dispatches, got {completed}" + assert max_in_flight <= pool_size, ( + f"in-flight dispatches exceeded worker pool size: " + f"max_in_flight={max_in_flight}, pool_size={pool_size}" + ) + + +@pytest.mark.asyncio +async def test_worker_decodes_queued_raw_message(monkeypatch: pytest.MonkeyPatch) -> None: + """Kafka workers should decode raw bytes only when a worker is ready.""" + + manager = kafka_manager.KafkaManager() + workflow_id = "test-wf-raw-queue" + queue: asyncio.Queue = asyncio.Queue(maxsize=8) + abort = asyncio.Event() + captured: list[dict] = [] + + async def _fake_trigger(workflow_id, workflow_json, msg, input_key, producer=None, output_topic=""): # noqa: ANN001 + captured.append(msg) + abort.set() + + monkeypatch.setattr(manager, "_trigger_workflow", _fake_trigger) + queue.put_nowait( + kafka_manager._QueuedKafkaMessage( # noqa: SLF001 + raw_value=b'{"ok": true}', + size_bytes=len(b'{"ok": true}'), + ) + ) + + worker = asyncio.create_task( + manager._worker_loop(workflow_id, {}, "kafka_message", {}, queue, abort), + name="test-worker-raw-queue", + ) + await asyncio.wait_for(worker, timeout=1.0) + + assert captured == [{"ok": True}] + + +@pytest.mark.asyncio +async def test_stop_workflow_cancels_worker_pool() -> None: + """``stop_workflow`` must cancel and drain the worker pool cleanly.""" + + manager = kafka_manager.KafkaManager() + workflow_id = "test-wf-stop" + queue: asyncio.Queue = asyncio.Queue(maxsize=8) + abort = asyncio.Event() + manager._queues[workflow_id] = queue + manager._abort_events[workflow_id] = abort + manager._status[workflow_id] = {"state": "running", "error": None} + + async def _noop_trigger(*args, **kwargs): # noqa: ANN001 + return None + + manager._trigger_workflow = _noop_trigger # type: ignore[assignment] + + workers = [ + asyncio.create_task( + manager._worker_loop(workflow_id, {}, "kafka_message", {}, queue, abort), + name=f"stop-worker-{i}", + ) + for i in range(3) + ] + manager._worker_pools[workflow_id] = workers + + await asyncio.sleep(0.05) + await manager.stop_workflow(workflow_id) + + for w in workers: + assert w.done(), "stop_workflow must terminate every worker in the pool" + assert workflow_id not in manager._worker_pools + assert workflow_id not in manager._queues + assert manager._status[workflow_id]["state"] == "stopped" + + +@pytest.mark.asyncio +async def test_restart_disabled_config_reports_stopped(monkeypatch: pytest.MonkeyPatch) -> None: + """A disabled (or missing) config must leave the consumer ``stopped``.""" + + manager = kafka_manager.KafkaManager() + + async def _fake_read(key): # noqa: ANN001 + return {"enabled": False} + + monkeypatch.setattr(kafka_manager.Storage, "read", _fake_read) + + status = await manager.restart_workflow("wf-disabled") + assert status == {"state": "stopped", "error": None} + + +@pytest.mark.asyncio +async def test_restart_missing_broker_reports_failed(monkeypatch: pytest.MonkeyPatch) -> None: + """Enabled config without broker/topic must fail fast (no real connect).""" + + manager = kafka_manager.KafkaManager() + + async def _fake_read(key): # noqa: ANN001 + return {"enabled": True, "inputBroker": "", "inputTopic": ""} + + monkeypatch.setattr(kafka_manager.Storage, "read", _fake_read) + + status = await manager.restart_workflow("wf-no-broker") + assert status["state"] == "failed" + assert status["error"] == "missing_input_broker_or_topic" + + +@pytest.mark.asyncio +async def test_restart_workflow_cleans_resources_after_connect_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A failed consumer start must not leave workers or producers behind.""" + + manager = kafka_manager.KafkaManager() + workflow_id = "wf-connect-failed" + + async def _fake_read(key): # noqa: ANN001 + return { + "enabled": True, + "inputBroker": "localhost:9092", + "inputTopic": "workflow-input", + "inputGroupId": "wf-group", + "inputKey": "kafka_message", + } + + class _Consumer: + def __init__(self, *args, **kwargs) -> None: # noqa: ANN002, ANN003 + self.stopped = False + + async def start(self) -> None: + raise RuntimeError("broker unreachable") + + async def stop(self) -> None: + self.stopped = True + + monkeypatch.setattr(kafka_manager.Storage, "read", _fake_read) + monkeypatch.setattr( + kafka_manager, + "read_workflow_from_fs", + lambda _workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}}, + ) + monkeypatch.setitem(sys.modules, "aiokafka", SimpleNamespace(AIOKafkaConsumer=_Consumer)) + + status = await manager.restart_workflow(workflow_id) + + assert status["state"] == "failed" + assert status["error"] == "broker unreachable" + assert workflow_id not in manager._tasks + assert workflow_id not in manager._worker_pools + assert workflow_id not in manager._queues + assert workflow_id not in manager._abort_events + + +def test_decode_message_variants() -> None: + """``_decode_message`` decodes JSON, falls back to text, then hex.""" + + assert kafka_manager._decode_message(b'{"a": 1}') == {"a": 1} + assert kafka_manager._decode_message(b"plain text") == "plain text" + assert kafka_manager._decode_message(None) is None + # Invalid UTF-8 bytes fall back to a hex repr. + assert kafka_manager._decode_message(b"\xff\xfe") == "fffe" + + +@pytest.mark.asyncio +async def test_trigger_workflow_compacts_kafka_execution_record( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Kafka-triggered execution rows should not retain full raw alert bodies.""" + + manager = kafka_manager.KafkaManager() + captured_input_params: dict = {} + captured_exec_data: dict = {} + captured_run_kwargs: dict = {} + + async def _fake_create_execution_record(workflow_id, *, input_params=None, exec_id=None): # noqa: ANN001 + captured_input_params.update(input_params or {}) + return {"id": "exec-compact", "workflowId": workflow_id, "inputParams": input_params} + + async def _fake_record_execution_result(workflow_id, exec_id, exec_data): # noqa: ANN001 + captured_exec_data.update(exec_data) + + def _fake_run_workflow(**kwargs): # noqa: ANN003 + captured_run_kwargs.update(kwargs) + large_alert = {"raw_log_id": "alert-1", "req_body": "x" * 50_000} + return SimpleNamespace( + status="SUCCEEDED", + error=None, + outputs={ + "enriched_alerts": [large_alert], + "kafka_messages": [{"raw_log_id": "alert-1"}], + }, + history=[ + { + "node_id": "receive_alert", + "inputs": {"kafka_message": {"alarmData": "x" * 50_000}}, + "outputs": {"raw_alerts": [large_alert]}, + }, + { + "node_id": "dedup_and_write", + "inputs": {"filtered_alerts": [large_alert]}, + "outputs": {"enriched_alerts": [large_alert]}, + }, + ], + last_node_id="done", + steps=2, + ) + + monkeypatch.setattr(kafka_manager, "create_execution_record", _fake_create_execution_record) + monkeypatch.setattr(kafka_manager, "record_execution_result", _fake_record_execution_result) + monkeypatch.setattr(kafka_manager, "run_workflow", _fake_run_workflow) + + await manager._trigger_workflow( + "wf-compact", + {"start": "receive_alert", "nodes": [], "edges": []}, + {"alarmData": "x" * 50_000}, + "kafka_message", + ) + + assert captured_input_params["kafka_message"]["alarmData"]["_type"] == "string" + assert captured_input_params["kafka_message"]["alarmData"]["chars"] == 50_000 + assert captured_run_kwargs["history_mode"] == "summary" + assert captured_exec_data["outputResults"] == { + "_enriched_alerts_count": 1, + "_kafka_messages_count": 1, + } + assert captured_exec_data["executionLog"][0]["outputs"] == {"_raw_alerts_count": 1} + assert captured_exec_data["executionLog"][1]["inputs"] == {"_filtered_alerts_count": 1} + assert len(json.dumps(captured_exec_data, ensure_ascii=False)) < 10_000 + + +@pytest.mark.asyncio +async def test_trigger_workflow_merges_configured_inputs_with_consumed_message( + monkeypatch: pytest.MonkeyPatch, +) -> None: + manager = kafka_manager.KafkaManager() + captured_run_kwargs: dict = {} + recorded_input_params: dict = {} + + async def _fake_create_execution_record(workflow_id, *, input_params=None, exec_id=None): # noqa: ANN001 + recorded_input_params.update(input_params or {}) + return {"id": "exec-merge", "workflowId": workflow_id, "inputParams": input_params} + + async def _fake_record_execution_result(workflow_id, exec_id, exec_data): # noqa: ANN001 + return None + + def _fake_run_workflow(**kwargs): # noqa: ANN003 + captured_run_kwargs.update(kwargs) + return SimpleNamespace( + status="SUCCEEDED", + error=None, + outputs={"ok": True}, + history=[], + last_node_id="done", + steps=1, + ) + + monkeypatch.setattr(kafka_manager, "create_execution_record", _fake_create_execution_record) + monkeypatch.setattr(kafka_manager, "record_execution_result", _fake_record_execution_result) + monkeypatch.setattr(kafka_manager, "run_workflow", _fake_run_workflow) + + await manager._trigger_workflow( + "wf-merge", + {"start": "receive_alert", "nodes": [], "edges": []}, + {"alarmData": {"id": 1}}, + "kafka_message", + { + "_comment": "remove me", + "kafka_message": {"should": "be overridden"}, + "kafka_output_enabled": True, + "kafka_output_topic": "topic_soc_flocks_result_log", + }, + ) + + assert captured_run_kwargs["inputs"] == { + "kafka_message": {"alarmData": {"id": 1}}, + "kafka_output_enabled": True, + "kafka_output_topic": "topic_soc_flocks_result_log", + } + assert recorded_input_params["_trigger"] == "kafka" + assert recorded_input_params["kafka_output_enabled"] is True + assert recorded_input_params["kafka_output_topic"] == "topic_soc_flocks_result_log" + assert recorded_input_params["kafka_message"]["_type"] == "dict" + assert recorded_input_params["kafka_message"]["keys"] == ["alarmData"] diff --git a/tests/provider/test_chinese_providers.py b/tests/provider/test_chinese_providers.py index 8938fc501..d0faba837 100644 --- a/tests/provider/test_chinese_providers.py +++ b/tests/provider/test_chinese_providers.py @@ -208,12 +208,23 @@ def test_zhipu_catalog(self): def test_minimax_catalog(self): models = get_provider_model_definitions("minimax") assert {m.id for m in models} == { + "minimax-m3", "minimax-m2.7", "minimax-m2.5", } + m3 = next(m for m in models if m.id == "minimax-m3") + assert m3.capabilities.supports_reasoning is True + assert m3.capabilities.interleaved["field"] == "reasoning_details" + assert m3.limits.context_window == 512000 + assert m3.limits.max_output_tokens == 512000 m27 = next(m for m in models if m.id == "minimax-m2.7") assert m27.capabilities.supports_reasoning is True assert m27.capabilities.interleaved["field"] == "reasoning_details" + assert m27.limits.context_window == 196608 + assert m27.limits.max_output_tokens == 128000 + m25 = next(m for m in models if m.id == "minimax-m2.5") + assert m25.limits.context_window == 196608 + assert m25.limits.max_output_tokens == 128000 def test_stepfun_catalog(self): models = get_provider_model_definitions("stepfun") diff --git a/tests/provider/test_openai_compatible_provider.py b/tests/provider/test_openai_compatible_provider.py index bab862dd4..80a16f1d9 100644 --- a/tests/provider/test_openai_compatible_provider.py +++ b/tests/provider/test_openai_compatible_provider.py @@ -119,11 +119,12 @@ async def test_chat_passes_explicit_temperature(self): class TestOpenAICompatibleProviderMiniMaxFallback: - def test_is_minimax_empty_response_target_matches_supported_aliases(self): + def test_is_minimax_empty_response_target_matches_all_minimax_aliases(self): assert OpenAICompatibleProvider._is_minimax_empty_response_target("MiniMax-M2.5") is True assert OpenAICompatibleProvider._is_minimax_empty_response_target("minimax_m2.7") is True + assert OpenAICompatibleProvider._is_minimax_empty_response_target("minimax-m3") is True assert OpenAICompatibleProvider._is_minimax_empty_response_target("custom-minimax-m2.5-prod") is True - assert OpenAICompatibleProvider._is_minimax_empty_response_target("foo/minimax-m2.7-202506") is True + assert OpenAICompatibleProvider._is_minimax_empty_response_target("foo/minimax-m3-202506") is True assert OpenAICompatibleProvider._is_minimax_empty_response_target("gpt-4o-mini") is False @pytest.mark.asyncio diff --git a/tests/provider/test_provider.py b/tests/provider/test_provider.py index 0147e8b99..0a6155c37 100644 --- a/tests/provider/test_provider.py +++ b/tests/provider/test_provider.py @@ -187,7 +187,7 @@ def test_resolve_model_does_not_infer_interleaved_for_non_reasoning_model(monkey ("openai-compatible", "kimi-k2-thinking-turbo", "https://api.example.com/v1", "reasoning_content"), ("openai-compatible", "deepseek-v4-pro", "https://api.deepseek.com/v1", "reasoning_content"), ("openai-compatible", "glm-4.7", "https://api.example.com/v1", "reasoning_content"), - ("openai-compatible", "minimax-m2.1", "https://api.example.com/v1", "reasoning_details"), + ("openai-compatible", "minimax-m3", "https://api.example.com/v1", "reasoning_details"), ("openai-compatible", "gemini-3.1-pro-preview", "https://api.example.com/v1", "reasoning_details"), ("openai-compatible", "step-3.5-flash", "https://api.example.com/v1", "reasoning_content"), ("google-vertex-anthropic", "claude-sonnet-4-6", "https://example.com", "thinking"), diff --git a/tests/server/routes/test_agent_routes.py b/tests/server/routes/test_agent_routes.py index 0fa1f98b2..d222ec0d3 100644 --- a/tests/server/routes/test_agent_routes.py +++ b/tests/server/routes/test_agent_routes.py @@ -13,6 +13,9 @@ from __future__ import annotations +import json +from pathlib import Path + import pytest from fastapi import status from httpx import AsyncClient @@ -27,6 +30,26 @@ "prompt": "You are a test assistant.", } +_SUBAGENT_PAYLOAD = { + **_AGENT_PAYLOAD, + "name": "test-subagent", + "mode": "subagent", +} + + +@pytest.fixture(autouse=True) +def _isolated_delegatable_settings(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + settings_file = tmp_path / "agent_delegatable_settings.json" + monkeypatch.setattr("flocks.agent.delegatable_settings.settings_path", lambda: settings_file) + + from flocks.agent.registry import Agent + + Agent._delegatable_settings_mtime = 0.0 + Agent.invalidate_cache() + yield settings_file + Agent._delegatable_settings_mtime = 0.0 + Agent.invalidate_cache() + # =========================================================================== # List @@ -130,6 +153,13 @@ async def test_created_agent_survives_registry_reload(self, client: AsyncClient) assert list_resp.status_code == status.HTTP_200_OK assert "test-agent" in [agent["name"] for agent in list_resp.json()] + @pytest.mark.asyncio + async def test_create_subagent_defaults_to_delegatable(self, client: AsyncClient): + """Sub-agents default to delegatable=true when the field is omitted.""" + resp = await client.post("/api/agent", json=_SUBAGENT_PAYLOAD) + assert resp.status_code == status.HTTP_200_OK + assert resp.json()["delegatable"] is True + # =========================================================================== # Update @@ -157,6 +187,116 @@ async def test_update_nonexistent_agent_returns_404(self, client: AsyncClient): ) assert resp.status_code == status.HTTP_404_NOT_FOUND + @pytest.mark.asyncio + async def test_update_subagent_delegatable(self, client: AsyncClient): + """PUT /api/agent/{name} can disable delegation for a sub-agent.""" + create_resp = await client.post("/api/agent", json=_SUBAGENT_PAYLOAD) + assert create_resp.status_code == status.HTTP_200_OK + assert create_resp.json()["delegatable"] is True + + resp = await client.put( + "/api/agent/test-subagent", + json={"delegatable": False}, + ) + assert resp.status_code == status.HTTP_200_OK + assert resp.json()["delegatable"] is False + + get_resp = await client.get("/api/agent/test-subagent") + assert get_resp.status_code == status.HTTP_200_OK + assert get_resp.json()["delegatable"] is False + + @pytest.mark.asyncio + async def test_update_subagent_delegatable_survives_registry_reload(self, client: AsyncClient): + """Storage-backed delegatable updates survive a fresh registry load.""" + from flocks.agent.registry import Agent + + create_resp = await client.post("/api/agent", json=_SUBAGENT_PAYLOAD) + assert create_resp.status_code == status.HTTP_200_OK + + update_resp = await client.put( + "/api/agent/test-subagent", + json={"delegatable": False}, + ) + assert update_resp.status_code == status.HTTP_200_OK + assert update_resp.json()["delegatable"] is False + + Agent._custom_agents.clear() + Agent.invalidate_cache() + + get_resp = await client.get("/api/agent/test-subagent") + assert get_resp.status_code == status.HTTP_200_OK + assert get_resp.json()["delegatable"] is False + + @pytest.mark.asyncio + async def test_patch_delegatable_updates_storage_custom_agent_without_sidecar( + self, + client: AsyncClient, + _isolated_delegatable_settings: Path, + ): + create_resp = await client.post("/api/agent", json=_SUBAGENT_PAYLOAD) + assert create_resp.status_code == status.HTTP_200_OK + + patch_resp = await client.patch( + "/api/agent/test-subagent/delegatable", + json={"delegatable": False}, + ) + assert patch_resp.status_code == status.HTTP_200_OK + assert patch_resp.json()["delegatable"] is False + + get_resp = await client.get("/api/agent/test-subagent") + assert get_resp.status_code == status.HTTP_200_OK + assert get_resp.json()["delegatable"] is False + + if _isolated_delegatable_settings.exists(): + payload = json.loads(_isolated_delegatable_settings.read_text(encoding="utf-8")) + assert payload.get("delegatable_overrides", {}).get("test-subagent") is None + + @pytest.mark.asyncio + async def test_patch_delegatable_overrides_builtin_agent_without_rewriting_yaml( + self, + client: AsyncClient, + _isolated_delegatable_settings: Path, + ): + patch_resp = await client.patch( + "/api/agent/explore/delegatable", + json={"delegatable": False}, + ) + assert patch_resp.status_code == status.HTTP_200_OK + assert patch_resp.json()["delegatable"] is False + + get_resp = await client.get("/api/agent/explore") + assert get_resp.status_code == status.HTTP_200_OK + assert get_resp.json()["delegatable"] is False + + payload = json.loads(_isolated_delegatable_settings.read_text(encoding="utf-8")) + assert payload["delegatable_overrides"]["explore"] is False + + @pytest.mark.asyncio + async def test_patch_delegatable_syncs_is_delegatable_without_followup_list( + self, + client: AsyncClient, + _isolated_delegatable_settings: Path, + ): + """PATCH must refresh _agents_ref so delegate_task sees the new value immediately.""" + from flocks.agent.registry import Agent, is_delegatable + + await Agent.state() + assert is_delegatable("explore") is True + + patch_resp = await client.patch( + "/api/agent/explore/delegatable", + json={"delegatable": False}, + ) + assert patch_resp.status_code == status.HTTP_200_OK + assert is_delegatable("explore") is False + + patch_resp = await client.patch( + "/api/agent/explore/delegatable", + json={"delegatable": True}, + ) + assert patch_resp.status_code == status.HTTP_200_OK + assert is_delegatable("explore") is True + # =========================================================================== # Delete diff --git a/tests/server/routes/test_console_upgrade_routes.py b/tests/server/routes/test_console_upgrade_routes.py index 104a44392..8eb8380bd 100644 --- a/tests/server/routes/test_console_upgrade_routes.py +++ b/tests/server/routes/test_console_upgrade_routes.py @@ -886,7 +886,7 @@ async def test_refresh_approved_request_does_not_auto_activate_install( assert "auto_install_task_scheduled_at" not in payload["details"] -async def test_start_approved_request_streams_upgrade_and_marks_activated( +async def test_start_approved_request_streams_restart_without_marking_activated( client: AsyncClient, monkeypatch: pytest.MonkeyPatch, ): @@ -948,10 +948,71 @@ async def _fake_report(record: dict, *, install_result: str, error_message: str assert "Restarting service" in resp.text stored = await Storage.get(f"console:upgrade_request:{request_id}") - assert stored["status"] == "activated" + assert stored["status"] == "approved" assert stored["details"]["auto_install_result"] == "restarting" - assert stored["details"]["auto_install_version"] == "v2026.5.9" - assert reported == [("success", None)] + assert "auto_install_version" not in stored["details"] + assert reported == [] + + +async def test_start_approved_request_reports_error_after_restart_stage( + client: AsyncClient, + monkeypatch: pytest.MonkeyPatch, +): + from flocks.server.routes import console_upgrade as console_routes + from flocks.storage.storage import Storage + from flocks.updater.models import UpdateProgress + + monkeypatch.setenv("FLOCKS_CONSOLE_BASE_URL", "https://console.example.com") + monkeypatch.setattr(console_routes, "require_admin", lambda _req: _mock_admin()) + await _set_bound_console_session() + request_id = "req_start_restart_then_error" + await Storage.set( + f"console:upgrade_request:{request_id}", + { + "request_id": request_id, + "status": "approved", + "previous_request_id": None, + "reason": None, + "suggestion": None, + "activate_key": "key_start", + "manifest_url": "https://manifest.example.com/v1/manifest/latest", + "details": {"company": "acme"}, + "created_at": "2026-05-08T08:00:00+00:00", + "updated_at": "2026-05-08T08:00:00+00:00", + }, + "json", + ) + + async def _fake_perform_pro_bundle_install(*args, **kwargs): + yield UpdateProgress(stage="fetching", message="Downloading Flocks Pro bundle...", success=None) + yield UpdateProgress(stage="restarting", message="Restarting service...", success=None) + yield UpdateProgress(stage="error", message="Failed to build restart command: missing python", success=False) + + async def _noop(_record: dict): + return None + + reported: list[tuple[str, str | None]] = [] + + async def _fake_report(record: dict, *, install_result: str, error_message: str | None = None): + reported.append((install_result, error_message)) + + monkeypatch.setattr(console_routes, "perform_pro_bundle_install", _fake_perform_pro_bundle_install) + monkeypatch.setattr(console_routes, "_maybe_activate_pro_license", _noop) + monkeypatch.setattr(console_routes, "_maybe_refresh_pro_license", _noop) + monkeypatch.setattr(console_routes, "_report_pro_bundle_installation", _fake_report) + monkeypatch.setattr(console_routes, "_mark_console_upgrade_activated", _noop) + monkeypatch.setattr(console_routes, "_get_pro_capability_status", lambda: {"pro_enabled": True, "active": True}) + + resp = await client.post(f"/api/console/upgrade-requests/{request_id}/start") + + assert resp.status_code == status.HTTP_200_OK + assert "Restarting service" in resp.text + assert "Failed to build restart command" in resp.text + stored = await Storage.get(f"console:upgrade_request:{request_id}") + assert stored["status"] == "approved" + assert stored["details"]["auto_install_result"] == "failed" + assert stored["details"]["auto_install_error"] == "Failed to build restart command: missing python" + assert reported == [("failed", "Failed to build restart command: missing python")] async def test_start_activated_request_reinstalls_when_pro_package_missing( @@ -1178,3 +1239,48 @@ async def _noop(_record: dict): assert payload["details"]["auto_install_result"] == "done" assert payload["details"]["auto_install_version"] == "v2026.5.9" + +async def test_report_pro_bundle_installation_uses_license_id( + monkeypatch: pytest.MonkeyPatch, +): + from flocks.server.routes import console_upgrade as console_routes + + posted_payloads: list[dict] = [] + + class _Response: + def raise_for_status(self): + return None + + class _Client: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def post(self, url, json=None, headers=None): + posted_payloads.append(json) + assert url == "https://console.example.com/v1/pro-bundles/installations" + assert headers == {"Authorization": "Bearer token_abc"} + return _Response() + + monkeypatch.setenv("FLOCKS_CONSOLE_BASE_URL", "https://console.example.com") + await _set_bound_console_session() + monkeypatch.setattr(console_routes.httpx, "AsyncClient", lambda timeout=10: _Client()) + monkeypatch.setattr( + console_routes, + "_read_pro_bundle_install_marker", + lambda: {"installed_version": "v2026.5.9"}, + ) + + record = { + "request_id": "req_receipt", + "status": "approved", + "activate_key": "activation_token", + "license_id": "lic_receipt", + "details": {"license_id": "lic_receipt"}, + } + + await console_routes._report_pro_bundle_installation(record, install_result="success") + + assert posted_payloads[0]["license_id"] == "lic_receipt" diff --git a/tests/server/routes/test_session_routes.py b/tests/server/routes/test_session_routes.py index f633fe30b..e649d7351 100644 --- a/tests/server/routes/test_session_routes.py +++ b/tests/server/routes/test_session_routes.py @@ -12,6 +12,7 @@ from __future__ import annotations +import asyncio from types import SimpleNamespace from unittest.mock import AsyncMock @@ -861,6 +862,90 @@ async def test_clear_session(self, client: AsyncClient, session_id: str): list_resp = await client.get(f"/api/session/{session_id}/message") assert list_resp.json() == [] + @pytest.mark.asyncio + async def test_clear_session_clears_prompt_queue(self, client: AsyncClient, session_id: str): + """POST /api/session/{id}/clear also drops queued prompts.""" + from flocks.session.interaction_queue import InteractionQueue + + await InteractionQueue.clear(session_id) + await InteractionQueue.enqueue( + session_id, + parts=[{"type": "text", "text": "queued after current reply"}], + ) + + clear_resp = await client.post(f"/api/session/{session_id}/clear") + + assert clear_resp.status_code == status.HTTP_200_OK + assert await InteractionQueue.list(session_id) == [] + + @pytest.mark.asyncio + async def test_clear_session_waits_for_idle_before_clearing_messages(self, monkeypatch): + """History is cleared only after abort drains the active session loop.""" + from flocks.server.routes import session as session_routes + from flocks.session.interaction_queue import InteractionQueue + + session_id = "ses_clear_waits_for_idle" + await InteractionQueue.clear(session_id) + await InteractionQueue.enqueue( + session_id, + parts=[{"type": "text", "text": "queued prompt"}], + ) + order: list[str] = [] + + async def fake_abort_session(_session_id: str) -> bool: + order.append("abort") + return True + + async def fake_wait_for_session_idle(_session_id: str) -> None: + order.append("wait") + assert await InteractionQueue.list(session_id) == [] + + async def fake_message_clear(_session_id: str) -> int: + order.append("message_clear") + return 2 + + async def fake_publish_event(_event: str, _payload: dict) -> None: + return None + + monkeypatch.setattr( + session_routes.Session, + "get_by_id", + AsyncMock(return_value=SimpleNamespace(id=session_id, directory="/tmp/project")), + ) + monkeypatch.setattr(session_routes, "abort_session", fake_abort_session) + monkeypatch.setattr(session_routes, "_wait_for_session_idle", fake_wait_for_session_idle) + monkeypatch.setattr("flocks.session.message.Message.clear", fake_message_clear) + monkeypatch.setattr("flocks.server.routes.event.publish_event", fake_publish_event) + + deleted = await session_routes._clear_session_history(session_id) + + assert deleted == 2 + assert order == ["abort", "wait", "message_clear"] + assert await InteractionQueue.list(session_id) == [] + + @pytest.mark.asyncio + async def test_clear_command_removes_messages(self, client: AsyncClient, session_id: str): + """Command route /clear removes messages via the dispatcher task.""" + from flocks.server.routes import session as session_routes + + await client.post( + f"/api/session/{session_id}/message", + json={"parts": [{"type": "text", "text": "msg"}], "noReply": True}, + ) + + clear_resp = await session_routes.send_session_command( + session_id, + session_routes.CommandRequest(command="clear"), + ) + assert clear_resp["status"] == "accepted" + + pending = list(getattr(session_routes.router, "_pending_tasks", set())) + assert pending + await asyncio.gather(*pending) + + list_resp = await client.get(f"/api/session/{session_id}/message") + assert list_resp.json() == [] + @pytest.mark.asyncio async def test_abort_session(self, client: AsyncClient, session_id: str): """POST /api/session/{id}/abort returns 200 (no active generation needed).""" diff --git a/tests/server/routes/test_update_routes.py b/tests/server/routes/test_update_routes.py new file mode 100644 index 000000000..f673fca72 --- /dev/null +++ b/tests/server/routes/test_update_routes.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import pytest +from fastapi import HTTPException, status +from starlette.requests import Request + + +pytestmark = pytest.mark.asyncio + + +def _request() -> Request: + return Request({"type": "http", "method": "GET", "path": "/api/update/check", "headers": []}) + + +async def test_check_version_requires_admin_for_flockspro(monkeypatch: pytest.MonkeyPatch): + from flocks.server.routes import update as update_routes + + called = False + + def _deny_admin(_request): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="请先登录") + + async def _fake_check_update(**kwargs): + nonlocal called + called = True + raise AssertionError("Pro update checks must not reach updater before admin auth") + + monkeypatch.setattr(update_routes, "require_admin", _deny_admin) + monkeypatch.setattr(update_routes, "check_update", _fake_check_update) + + with pytest.raises(HTTPException) as exc: + await update_routes.check_version(_request(), locale=None, edition="flockspro") + + assert exc.value.status_code == status.HTTP_401_UNAUTHORIZED + assert called is False + + +async def test_check_version_keeps_flocks_channel_public(monkeypatch: pytest.MonkeyPatch): + from flocks.server.routes import update as update_routes + from flocks.updater.models import VersionInfo + + def _deny_admin(_request): + raise AssertionError("Flocks channel check should not require admin at route level") + + async def _fake_check_update(**kwargs): + assert kwargs == {"locale": "zh-CN", "force_console_manifest": False} + return VersionInfo(current_version="v2026.5.9") + + monkeypatch.setattr(update_routes, "require_admin", _deny_admin) + monkeypatch.setattr(update_routes, "check_update", _fake_check_update) + + info = await update_routes.check_version(_request(), locale="zh-CN", edition="flocks") + + assert info.current_version == "v2026.5.9" diff --git a/tests/server/routes/test_workflow_poller_routes.py b/tests/server/routes/test_workflow_poller_routes.py new file mode 100644 index 000000000..9baf2d2dc --- /dev/null +++ b/tests/server/routes/test_workflow_poller_routes.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +from httpx import AsyncClient + +from flocks.server.routes import workflow as workflow_routes + + +@pytest.mark.asyncio +async def test_save_poller_config_restarts_manager( + client: AsyncClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + writes: list[tuple[str, dict[str, Any]]] = [] + + async def _fake_write(key: Any, value: dict[str, Any]) -> None: + writes.append((key, value)) + + async def _fake_restart(workflow_id: str) -> dict[str, Any]: + assert workflow_id == "wf-1" + return {"workflowId": workflow_id, "state": "running", "lastStatus": None} + + monkeypatch.setattr( + workflow_routes, + "_read_workflow_from_fs", + lambda workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}} if workflow_id == "wf-1" else None, + ) + monkeypatch.setattr(workflow_routes.Storage, "write", _fake_write) + monkeypatch.setattr( + "flocks.workflow.poller_manager.default_manager", + SimpleNamespace(restart_workflow=_fake_restart), + ) + + response = await client.post( + "/api/workflow/wf-1/poller-config", + json={ + "enabled": True, + "intervalSeconds": 45, + "timeoutSeconds": 3600, + "noOverlap": True, + "inputs": {"persist_triage_output": True}, + }, + ) + + assert response.status_code == 200, response.text + poller_writes = [(key, value) for key, value in writes if key == "workflow_poller_config/wf-1"] + assert poller_writes + key, payload = poller_writes[0] + assert key == "workflow_poller_config/wf-1" + assert payload["enabled"] is True + assert payload["intervalSeconds"] == 45 + assert payload["timeoutSeconds"] == 3600 + assert payload["inputs"] == {"persist_triage_output": True} + + +@pytest.mark.asyncio +async def test_get_poller_config_returns_saved_data( + client: AsyncClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + async def _fake_read(_key: Any, *_args: Any, **_kwargs: Any) -> dict[str, Any] | None: + if _key != "workflow_poller_config/wf-1": + return None + return { + "workflowId": "wf-1", + "enabled": True, + "intervalSeconds": 30, + "timeoutSeconds": 7200, + "noOverlap": True, + "inputs": {"dedup_source_workflow_name": "stream_alert_denoise_gt_fast"}, + } + + monkeypatch.setattr(workflow_routes.Storage, "read", _fake_read) + + response = await client.get("/api/workflow/wf-1/poller-config") + assert response.status_code == 200, response.text + assert response.json()["workflowId"] == "wf-1" + assert response.json()["intervalSeconds"] == 30 + + +@pytest.mark.asyncio +async def test_get_poller_status_returns_runtime_snapshot( + client: AsyncClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr( + "flocks.workflow.poller_manager.default_manager", + SimpleNamespace( + get_status=lambda workflow_id: { + "workflowId": workflow_id, + "state": "running", + "lastStatus": "success", + "selectedCount": 12, + }, + ), + ) + + response = await client.get("/api/workflow/wf-1/poller-status") + assert response.status_code == 200, response.text + assert response.json()["state"] == "running" + assert response.json()["selectedCount"] == 12 + + +@pytest.mark.asyncio +async def test_run_poller_once_returns_latest_status( + client: AsyncClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + async def _fake_run_once(workflow_id: str) -> dict[str, Any]: + return { + "workflowId": workflow_id, + "state": "stopped", + "lastStatus": "success", + "selectedCount": 5, + } + + monkeypatch.setattr( + workflow_routes, + "_read_workflow_from_fs", + lambda workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}} if workflow_id == "wf-1" else None, + ) + monkeypatch.setattr( + "flocks.workflow.poller_manager.default_manager", + SimpleNamespace(run_once=_fake_run_once), + ) + + response = await client.post("/api/workflow/wf-1/poller-run-once") + assert response.status_code == 200, response.text + assert response.json()["status"]["lastStatus"] == "success" + assert response.json()["status"]["selectedCount"] == 5 diff --git a/tests/server/routes/test_workflow_run_route.py b/tests/server/routes/test_workflow_run_route.py index a456c9343..77061dc1f 100644 --- a/tests/server/routes/test_workflow_run_route.py +++ b/tests/server/routes/test_workflow_run_route.py @@ -56,3 +56,49 @@ async def test_run_workflow_execution_task_reuses_existing_mcp_without_reinit( run_mock.assert_called_once() assert run_mock.call_args.kwargs["tool_context"] is tool_context record_result.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_save_kafka_config_persists_consumer_settings( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from flocks.ingest.kafka import manager as kafka_manager + + storage_write = AsyncMock(return_value=None) + restart_workflow = AsyncMock(return_value={"state": "running", "error": None}) + + monkeypatch.setattr(workflow_module, "_read_workflow_from_fs", lambda _workflow_id: {"workflowJson": {}}) + monkeypatch.setattr(workflow_module.Storage, "write", storage_write) + monkeypatch.setattr(kafka_manager.default_manager, "restart_workflow", restart_workflow) + + req = workflow_module.KafkaConfigRequest( + enabled=True, + inputBroker="localhost:9092", + inputTopic="workflow-input", + inputGroupId="wf-group", + inputKey="kafka_message", + inputs={ + "_comment": "remove me", + "kafka_output_enabled": True, + "kafka_output_topic": "topic_soc_flocks_result_log", + }, + ) + + response = await workflow_module.save_kafka_config("wf-input", req) + + assert response == {"ok": True, "consumer": {"state": "running", "error": None}} + storage_write.assert_awaited_once() + _, saved_config = storage_write.await_args.args + assert saved_config["enabled"] is True + assert saved_config["inputBroker"] == "localhost:9092" + assert saved_config["inputTopic"] == "workflow-input" + assert saved_config["inputGroupId"] == "wf-group" + assert saved_config["inputKey"] == "kafka_message" + assert saved_config["inputs"] == { + "kafka_output_enabled": True, + "kafka_output_topic": "topic_soc_flocks_result_log", + } + assert "outputEnabled" not in saved_config + assert "outputBroker" not in saved_config + assert "outputTopic" not in saved_config + restart_workflow.assert_awaited_once_with("wf-input") diff --git a/tests/server/test_input_dispatcher.py b/tests/server/test_input_dispatcher.py index 2d2472005..014272dba 100644 --- a/tests/server/test_input_dispatcher.py +++ b/tests/server/test_input_dispatcher.py @@ -50,13 +50,15 @@ async def test_direct_command_uses_direct_response(self): assert not llm @pytest.mark.asyncio - async def test_clear_without_clear_callback_sends_fallback_text(self): + async def test_clear_uses_history_callback_without_direct_response(self): direct = [] llm = [] + clear_history_calls = [] sink = CallbackOutputSink( "webui", direct_response=lambda _event, text: _append(direct, text), run_llm=lambda _event, prompt, display: _append(llm, (prompt, display)), + clear_history=lambda: _append(clear_history_calls, "cleared"), ) event = UserInputEvent( source_type="webui", @@ -68,7 +70,33 @@ async def test_clear_without_clear_callback_sends_fallback_text(self): result = await dispatch_user_input(event, sink) assert result.action == "direct" - assert direct == ["Screen cleared."] + assert clear_history_calls == ["cleared"] + assert not direct + assert not llm + + @pytest.mark.asyncio + async def test_clear_is_allowed_on_channel_surface(self): + direct = [] + llm = [] + clear_history_calls = [] + sink = CallbackOutputSink( + "channel", + direct_response=lambda _event, text: _append(direct, text), + run_llm=lambda _event, prompt, display: _append(llm, (prompt, display)), + clear_history=lambda: _append(clear_history_calls, "cleared"), + ) + event = UserInputEvent( + source_type="wecom", + sessionID="ses_test", + text="/clear", + parts=[{"type": "text", "text": "/clear"}], + ) + + result = await dispatch_user_input(event, sink) + + assert result.action == "direct" + assert clear_history_calls == ["cleared"] + assert not direct assert not llm @pytest.mark.asyncio @@ -256,5 +284,134 @@ async def fake_provide(*, directory, init, fn): assert event.display_text == "/plan investigate" +class TestPromptQueueRoutes: + @pytest.mark.asyncio + async def test_prompt_async_queues_when_session_running_without_creating_message(self, monkeypatch): + from flocks.server.routes import session as session_routes + from flocks.session.interaction_queue import InteractionQueue + + session_id = "ses_prompt_queue_running" + await InteractionQueue.clear(session_id) + + message_create = AsyncMock() + monkeypatch.setattr( + "flocks.session.session.Session.get_by_id", + AsyncMock(return_value=SimpleNamespace(id=session_id, directory="/tmp/project")), + ) + monkeypatch.setattr("flocks.session.session_loop.SessionLoop.is_running", lambda _sid: True) + monkeypatch.setattr("flocks.session.message.Message.create", message_create) + monkeypatch.setattr(session_routes, "_publish_prompt_queue", AsyncMock()) + + request = session_routes.PromptRequest(parts=[{"type": "text", "text": "second question"}]) + + resp = await session_routes.send_session_message_async(session_id, request) + + assert resp["status"] == "queued" + assert resp["queueID"] + items = await InteractionQueue.list(session_id) + assert len(items) == 1 + assert items[0].parts[0]["text"] == "second question" + message_create.assert_not_called() + + @pytest.mark.asyncio + async def test_prompt_queue_rejects_when_full(self, monkeypatch): + from fastapi import HTTPException + + from flocks.server.routes import session as session_routes + from flocks.session.interaction_queue import InteractionQueue, MAX_QUEUE_SIZE + + session_id = "ses_prompt_queue_full" + await InteractionQueue.clear(session_id) + monkeypatch.setattr( + "flocks.session.session.Session.get_by_id", + AsyncMock(return_value=SimpleNamespace(id=session_id, directory="/tmp/project")), + ) + monkeypatch.setattr("flocks.session.session_loop.SessionLoop.is_running", lambda _sid: True) + monkeypatch.setattr(session_routes, "_publish_prompt_queue", AsyncMock()) + + for idx in range(MAX_QUEUE_SIZE): + await InteractionQueue.enqueue( + session_id, + parts=[{"type": "text", "text": f"queued {idx}"}], + ) + + request = session_routes.PromptRequest(parts=[{"type": "text", "text": "overflow"}]) + with pytest.raises(HTTPException) as exc_info: + await session_routes.send_session_message_async(session_id, request) + + assert exc_info.value.status_code == 409 + + @pytest.mark.asyncio + async def test_run_now_aborts_and_schedules_drain(self, monkeypatch): + from flocks.server.routes import session as session_routes + from flocks.session.interaction_queue import InteractionQueue + + session_id = "ses_prompt_queue_run_now" + await InteractionQueue.clear(session_id) + item = await InteractionQueue.enqueue( + session_id, + parts=[{"type": "text", "text": "run this now"}], + ) + + abort_mock = AsyncMock(return_value=True) + wait_mock = AsyncMock() + drain_mock = AsyncMock() + monkeypatch.setattr( + "flocks.session.session.Session.get_by_id", + AsyncMock(return_value=SimpleNamespace(id=session_id, directory="/tmp/project")), + ) + monkeypatch.setattr("flocks.session.session_loop.SessionLoop.is_running", lambda _sid: True) + monkeypatch.setattr(session_routes, "abort_session", abort_mock) + monkeypatch.setattr(session_routes, "_wait_for_session_idle", wait_mock) + monkeypatch.setattr(session_routes, "_schedule_prompt_queue_drain", drain_mock) + monkeypatch.setattr(session_routes, "_publish_prompt_queue", AsyncMock()) + + resp = await session_routes.run_prompt_queue_item_now(session_id, item.id) + + assert resp["status"] == "accepted" + abort_mock.assert_awaited_once_with(session_id) + wait_mock.assert_awaited_once_with(session_id) + drain_mock.assert_awaited_once_with(session_id, "/tmp/project") + + @pytest.mark.asyncio + async def test_scheduled_drain_retries_until_session_idle(self, monkeypatch): + from flocks.server.routes import session as session_routes + from flocks.session.interaction_queue import InteractionQueue + + session_id = "ses_prompt_queue_retry" + await InteractionQueue.clear(session_id) + await InteractionQueue.enqueue( + session_id, + parts=[{"type": "text", "text": "run after idle"}], + ) + + running_states = [True, True, False] + dispatch_mock = AsyncMock() + original_sleep = asyncio.sleep + + async def fake_provide(*, directory, init, fn): + await fn() + + monkeypatch.setattr( + "flocks.session.session.Session.get_by_id", + AsyncMock(return_value=SimpleNamespace(id=session_id, directory="/tmp/project")), + ) + monkeypatch.setattr("flocks.project.instance.Instance.provide", fake_provide) + monkeypatch.setattr(session_routes, "_dispatch_sse_input", dispatch_mock) + monkeypatch.setattr(session_routes, "_publish_prompt_queue", AsyncMock()) + monkeypatch.setattr(asyncio, "sleep", AsyncMock()) + monkeypatch.setattr( + "flocks.session.session_loop.SessionLoop.is_running", + lambda _sid: running_states.pop(0) if running_states else False, + ) + + await session_routes._schedule_prompt_queue_drain(session_id, "/tmp/project") + await original_sleep(0) + await original_sleep(0) + + dispatch_mock.assert_awaited_once() + assert await InteractionQueue.list(session_id) == [] + + async def _append(target: list, value): target.append(value) diff --git a/tests/session/test_cli_title_generation.py b/tests/session/test_cli_title_generation.py index 2e548ba94..4693d3a8f 100644 --- a/tests/session/test_cli_title_generation.py +++ b/tests/session/test_cli_title_generation.py @@ -80,6 +80,48 @@ async def fake_stream(*args, **kwargs): assert title == "Debug Python" mock_update.assert_awaited_once_with("proj-1", "sess-1", title="Debug Python") + @pytest.mark.asyncio + async def test_rejects_tool_call_payload_title_and_falls_back(self): + """Tool-call payloads from the title model are not persisted as titles.""" + from flocks.session.lifecycle.title import SessionTitle + + question = ( + "based on ThreatBook Threat Intelligence, please give me reports for " + "cyber news related to Hong Kong for the past 7 days" + ) + mock_session = _make_session() + msg, part = _make_user_msg(question) + mock_provider = MagicMock() + + async def fake_stream(*args, **kwargs): + yield MagicMock(delta='[TOOL_CALL]\n{tool => "news", args => {\n --query: "Hong Kong"\n}}') + + mock_provider.chat_stream = fake_stream + mock_update = AsyncMock() + + patches = _patch_title_deps(mock_session, [msg], [part], mock_provider, mock_update) + with patches[0], patches[1], patches[2], patches[3], patches[4], patches[5], patches[6]: + title = await SessionTitle.generate_title_after_first_message( + session_id="sess-1", + model_id="minimax-m2.7", + provider_id="threatbook-cn-llm", + ) + + expected = SessionTitle._generate_simple_title(question) + assert title == expected + assert "TOOL_CALL" not in title + mock_update.assert_awaited_once_with("proj-1", "sess-1", title=expected) + + def test_rejects_json_function_call_title_candidate(self): + """Structured function-call shaped JSON is invalid as a thread title.""" + from flocks.session.lifecycle.title import SessionTitle + + title = SessionTitle._sanitize_generated_title( + '{"name": "news", "arguments": {"query": "Hong Kong"}}' + ) + + assert title == "" + @pytest.mark.asyncio async def test_skips_when_more_than_one_user_message(self): """Returns None when there are 2+ user messages (not the first turn).""" diff --git a/tests/session/test_message_parts_persistence.py b/tests/session/test_message_parts_persistence.py index a36c9b508..a7229b285 100644 --- a/tests/session/test_message_parts_persistence.py +++ b/tests/session/test_message_parts_persistence.py @@ -45,6 +45,16 @@ async def _write_legacy_session(session_id: str, messages: dict[str, str]) -> No Message.invalidate_cache(session_id) +async def _write_raw_legacy_payload( + session_id: str, + messages: list[dict], + parts: dict[str, list], +) -> None: + await Storage.set(f"message:{session_id}", messages, "message") + await Storage.set(f"message_parts:{session_id}", parts, "message_parts") + Message.invalidate_cache(session_id) + + @pytest.mark.asyncio async def test_new_sessions_write_per_message_parts_keys() -> None: session_id = "ses_parts_per_message_new" @@ -153,3 +163,164 @@ async def test_clear_removes_legacy_blob_and_per_message_keys() -> None: assert await Storage.get(f"message_parts:{per_message_session_id}") is None assert await Storage.list_keys(prefix=f"message_parts:{per_message_session_id}:") == [] + + +def test_deserialize_legacy_text_part_normalizes_content_and_time() -> None: + part = Message.deserialize_part( + { + "id": "part_legacy_text", + "sessionID": "ses_legacy_text", + "messageID": "msg_legacy_text", + "type": "text", + "content": "hello legacy", + "time": {"created": 7}, + } + ) + + assert part.text == "hello legacy" + assert part.time is not None + assert part.time.start == 7 + assert part.time.end == 7 + + +@pytest.mark.asyncio +async def test_ensure_cache_loads_legacy_assistant_missing_fields() -> None: + session_id = "ses_legacy_assistant_missing_fields" + await _write_raw_legacy_payload( + session_id, + messages=[ + { + "id": "msg_assistant_legacy", + "role": "assistant", + "time": {"created": 2}, + "path": [], + "content": "", + } + ], + parts={ + "msg_assistant_legacy": [ + { + "id": "part_assistant_legacy", + "type": "text", + "content": "restored assistant text", + "time": {"created": 2}, + } + ] + }, + ) + + messages = await Message.list_with_parts(session_id) + + assert len(messages) == 1 + info = messages[0].info + assert info.sessionID == session_id + assert info.agent == "rex" + assert info.parentID == "" + assert info.modelID == "" + assert info.providerID == "" + assert info.path.cwd == "./" + assert info.tokens.input == 0 + assert messages[0].parts[0].text == "restored assistant text" + assert messages[0].parts[0].time is not None + assert messages[0].parts[0].time.start == 2 + + +@pytest.mark.asyncio +async def test_ensure_cache_preserves_zero_created_timestamp() -> None: + session_id = "ses_legacy_zero_created" + await _write_raw_legacy_payload( + session_id, + messages=[ + { + "id": "msg_assistant_zero", + "role": "assistant", + "time": {"created": 0}, + "path": [], + "content": "", + } + ], + parts={ + "msg_assistant_zero": [ + { + "id": "part_assistant_zero", + "type": "text", + "content": "zero timestamp text", + "time": {"created": 0}, + } + ] + }, + ) + + messages = await Message.list_with_parts(session_id) + + assert len(messages) == 1 + assert messages[0].info.time["created"] == 0 + assert messages[0].parts[0].time is not None + assert messages[0].parts[0].time.start == 0 + assert messages[0].parts[0].time.end == 0 + + +@pytest.mark.asyncio +async def test_ensure_cache_loads_legacy_tool_part_without_time() -> None: + session_id = "ses_legacy_tool_missing_time" + await _write_raw_legacy_payload( + session_id, + messages=[ + { + "id": "msg_assistant_tool", + "role": "assistant", + "time": {"created": 0}, + "path": [], + "content": "", + } + ], + parts={ + "msg_assistant_tool": [ + { + "id": "part_tool_legacy", + "type": "tool", + "tool": "bash", + "callID": "call_legacy", + "state": { + "status": "completed", + "output": "legacy output", + }, + } + ] + }, + ) + + messages = await Message.list_with_parts(session_id) + + assert len(messages) == 1 + assert len(messages[0].parts) == 1 + tool_part = messages[0].parts[0] + assert tool_part.type == "tool" + assert tool_part.state.status == "completed" + assert tool_part.state.time == {"start": 0, "end": 0} + + +@pytest.mark.asyncio +async def test_ensure_cache_skips_invalid_part_keeps_siblings() -> None: + session_id = "ses_legacy_bad_part_skip" + await _write_raw_legacy_payload( + session_id, + messages=[_user_message(session_id, "msg_a").model_dump()], + parts={ + "msg_a": [ + "not-a-dict-part", + { + "id": "part_good", + "sessionID": session_id, + "messageID": "msg_a", + "type": "text", + "text": "still here", + }, + ] + }, + ) + + messages = await Message.list_with_parts(session_id) + + assert len(messages) == 1 + assert [part.text for part in messages[0].parts] == ["still here"] diff --git a/tests/tool/test_360_waf_device_plugin.py b/tests/tool/test_360_waf_device_plugin.py new file mode 100644 index 000000000..356f42bb6 --- /dev/null +++ b/tests/tool/test_360_waf_device_plugin.py @@ -0,0 +1,637 @@ +from __future__ import annotations + +import importlib.util +import json +import shutil +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import pytest +import yaml + +from flocks.config.api_versioning import derive_storage_key +from flocks.tool.registry import ToolContext, ToolResult +from flocks.tool.tool_loader import yaml_to_tool + + +_ROOT = Path(__file__).resolve().parents[2] +_PLUGIN_DIR = _ROOT / ".flocks" / "flockshub" / "plugins" / "tools" / "device" / "360_waf_v5_5" +_HANDLER_PATH = _PLUGIN_DIR / "360_waf.handler.py" + + +def _installed_plugin_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + project_root = tmp_path / "project" + install_dir = project_root / ".flocks" / "plugins" / "tools" / "device" / "360_waf_v5_5" + shutil.copytree(_PLUGIN_DIR, install_dir) + monkeypatch.chdir(project_root) + return install_dir + + +def _load_handler(): + spec = importlib.util.spec_from_file_location("_test_360_waf_handler", _HANDLER_PATH) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_provider_metadata_declares_360_waf_v5_5_device_plugin(): + raw = yaml.safe_load((_PLUGIN_DIR / "_provider.yaml").read_text(encoding="utf-8")) + + assert raw["name"] == "360_waf" + assert raw["service_id"] == "360_waf" + assert raw["version"] == "5.5" + assert raw["integration_type"] == "device" + assert raw["description_cn"] + assert derive_storage_key(raw["service_id"], raw["version"]) == "360_waf_v5_5" + assert "allow_mutation" not in raw["defaults"] + assert "allow_dangerous_ops" not in raw["defaults"] + + credential_keys = {field["key"] for field in raw["credential_fields"]} + assert {"base_url", "username", "password"} <= credential_keys + + +def test_probe_manifest_declares_connectivity_and_fixtures(): + raw = yaml.safe_load((_PLUGIN_DIR / "_test.yaml").read_text(encoding="utf-8")) + + assert raw["connectivity"]["tool"] == "360_waf_system" + assert raw["connectivity"]["params"] == {"action": "waf_check_login"} + + expected_tools = { + "360_waf_system", + "360_waf_site", + "360_waf_policy_ops", + "360_waf_observability", + "360_waf_api_readonly", + "360_waf_api_mutation", + "360_waf_file", + } + for tool_name in expected_tools: + assert raw["fixtures"][tool_name], tool_name + + +def test_mutation_manifest_uses_json_string_body_without_framework_schema_extensions(): + raw = yaml.safe_load( + (_PLUGIN_DIR / "360_waf_api_mutation.yaml").read_text(encoding="utf-8") + ) + + body_schema = raw["inputSchema"]["properties"]["body"] + + assert raw["requires_confirmation"] is True + assert body_schema["type"] == "string" + assert "oneOf" not in body_schema + assert "confirm" not in raw["inputSchema"]["properties"] + + +def test_json_payload_strings_are_parsed_by_handler(): + handler = _load_handler() + payload = [ + { + "siteId": 2147483647, + "type": 1, + "content": "192.0.2.236", + "is_permanent": "1", + } + ] + object_payload = {"conditions": []} + + assert handler.require_payload(json.dumps(payload)) == payload + assert handler.require_payload(json.dumps(object_payload)) == object_payload + assert handler.optional_payload("") is None + + +def test_policy_ops_manifest_exposes_business_mutation_actions(): + raw = yaml.safe_load((_PLUGIN_DIR / "360_waf_policy_ops.yaml").read_text(encoding="utf-8")) + + action_enum = set(raw["inputSchema"]["properties"]["action"]["enum"]) + body_schema = raw["inputSchema"]["properties"]["body"] + + assert raw["requires_confirmation"] is True + assert { + "waf_blacklist_create", + "waf_blacklist_delete", + "waf_site_global_blacklist_create", + "waf_site_global_blacklist_delete", + "waf_whitelist_create", + "waf_whitelist_delete", + "waf_site_global_whitelist_create", + "waf_site_global_whitelist_delete", + "waf_exception_rule_create", + "waf_exception_rule_update", + "waf_exception_rule_delete", + } <= action_enum + assert body_schema["type"] == "string" + assert "oneOf" not in body_schema + assert "confirm" not in raw["inputSchema"]["properties"] + + +def test_observability_manifest_exposes_log_query_filters(): + raw = yaml.safe_load((_PLUGIN_DIR / "360_waf_observability.yaml").read_text(encoding="utf-8")) + + action_enum = set(raw["inputSchema"]["properties"]["action"]["enum"]) + properties = raw["inputSchema"]["properties"] + + assert "waf_configuration_log_search" in action_enum + assert {"time_start", "time_end", "http_url", "action_filter", "msg"} <= set(properties) + + +@pytest.mark.parametrize( + ("yaml_name", "function_name"), + [ + ("360_waf_system.yaml", "system"), + ("360_waf_site.yaml", "site"), + ("360_waf_policy_ops.yaml", "policy_ops"), + ("360_waf_observability.yaml", "observability"), + ("360_waf_api_readonly.yaml", "api_readonly"), + ("360_waf_api_mutation.yaml", "api_mutation"), + ("360_waf_file.yaml", "file_ops"), + ], +) +def test_group_manifest_loads_as_device_tool( + yaml_name: str, + function_name: str, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +): + install_dir = _installed_plugin_dir(tmp_path, monkeypatch) + yaml_path = install_dir / yaml_name + raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8")) + tool = yaml_to_tool(raw, yaml_path) + + assert tool.info.provider == "360_waf_v5_5" + assert tool.info.source == "device" + assert tool.info.provider_version == "5.5" + assert raw["provider"] == "360_waf" + assert raw["handler"]["script_file"] == "360_waf.handler.py" + assert raw["handler"]["function"] == function_name + assert "action" in raw["inputSchema"]["required"] + assert raw["inputSchema"]["properties"]["action"]["enum"] + + +def test_group_manifests_use_official_confirmation_flags(): + expected = { + "360_waf_system.yaml": False, + "360_waf_site.yaml": False, + "360_waf_policy_ops.yaml": True, + "360_waf_observability.yaml": False, + "360_waf_api_readonly.yaml": False, + "360_waf_api_mutation.yaml": True, + "360_waf_file.yaml": True, + } + + for yaml_name, requires_confirmation in expected.items(): + raw = yaml.safe_load((_PLUGIN_DIR / yaml_name).read_text(encoding="utf-8")) + assert raw["requires_confirmation"] is requires_confirmation + + +def test_runtime_config_resolves_configwriter_and_secret_refs(monkeypatch): + handler = _load_handler() + raw_service = { + "base_url": "https://waf.example.com/", + "username": "{secret:360_waf_v5_5_username}", + "password": "{secret:360_waf_v5_5_password}", + "timeout": "12", + "verify_ssl": "true", + } + secrets = { + "360_waf_v5_5_username": "admin", + "360_waf_v5_5_password": "pass", + } + + monkeypatch.setattr( + handler.ConfigWriter, + "get_api_service_raw", + staticmethod(lambda service_id: raw_service if service_id == "360_waf" else None), + ) + monkeypatch.setattr(handler, "get_secret_manager", lambda: SimpleNamespace(get=secrets.get)) + + config = handler._load_runtime_config() + + assert config.base_url == "https://waf.example.com" + assert config.username == "admin" + assert config.password == "pass" + assert config.timeout == 12 + assert config.verify_ssl is True + + +def test_client_cache_key_does_not_store_plaintext_password(): + handler = _load_handler() + config = handler.RuntimeConfig( + base_url="https://waf.example.com", + username="admin", + password="secret-password", + verify_ssl=False, + timeout=30, + ) + + key = handler._client_cache_key(config) + + assert "secret-password" not in key + assert key == ("https://waf.example.com", "admin", False) + + +def test_ssl_cipher_downgrade_only_for_unverified_connections(monkeypatch): + handler = _load_handler() + + class _FakeContext: + def __init__(self) -> None: + self.ciphers: list[str] = [] + + def set_ciphers(self, value: str) -> None: + self.ciphers.append(value) + + verified = _FakeContext() + unverified = _FakeContext() + monkeypatch.setattr(handler.ssl, "create_default_context", lambda: verified) + monkeypatch.setattr(handler.ssl, "_create_unverified_context", lambda: unverified) + + base = { + "base_url": "https://waf.example.com", + "username": "admin", + "password": "secret", + "timeout": 30, + } + + handler.WafClient(handler.RuntimeConfig(**base, verify_ssl=True)) + handler.WafClient(handler.RuntimeConfig(**base, verify_ssl=False)) + + assert verified.ciphers == [] + assert unverified.ciphers == ["DEFAULT:@SECLEVEL=0"] + + +@pytest.mark.asyncio +async def test_unified_ops_runs_sync_handlers_in_worker_thread(monkeypatch): + handler = _load_handler() + calls: list[tuple[Any, tuple[Any, ...]]] = [] + + async def fake_to_thread(func: Any, *args: Any) -> ToolResult: + calls.append((func, args)) + return func(*args) + + def fake_handler(params: dict[str, Any]) -> ToolResult: + return ToolResult(success=True, output=params) + + monkeypatch.setattr(handler.asyncio, "to_thread", fake_to_thread) + monkeypatch.setitem(handler._ACTION_MAP, "fake_thread_action", fake_handler) + + result = await handler.unified_ops( + ToolContext(session_id="s", message_id="m"), + action="fake_thread_action", + value=1, + ) + + assert result.success is True + assert result.output == {"value": 1} + assert calls == [(fake_handler, ({"value": 1},))] + + +@pytest.mark.asyncio +async def test_api_readonly_group_dispatches_to_original_waf_action(monkeypatch): + handler = _load_handler() + calls: list[tuple[str, dict[str, Any] | None]] = [] + + class _FakeClient: + def call_readonly(self, path: str, query: dict[str, Any] | None = None) -> dict[str, Any]: + calls.append((path, query)) + return {"success": True, "result": [{"hostname": "waf01"}]} + + monkeypatch.setattr(handler, "get_client", lambda: _FakeClient()) + + result: ToolResult = await handler.api_readonly( + ToolContext(session_id="s", message_id="m"), + action="waf_call_raw_readonly", + path="rest/api/sysinfo", + query={"conditions": []}, + ) + + assert result.success is True + assert result.output == {"success": True, "result": [{"hostname": "waf01"}]} + assert calls == [("/rest/api/sysinfo", {"conditions": []})] + + +@pytest.mark.asyncio +async def test_policy_ops_builds_blacklist_and_whitelist_payloads(monkeypatch): + handler = _load_handler() + calls: list[tuple[str, str, dict[str, Any] | None, Any]] = [] + + class _FakeClient: + def request( + self, + method: str, + path: str, + query: dict[str, Any] | None = None, + body: Any = None, + ) -> dict[str, Any]: + calls.append((method, path, query, body)) + return {"success": True, "result": []} + + monkeypatch.setattr(handler, "get_client", lambda: _FakeClient()) + + ctx = ToolContext(session_id="s", message_id="m") + create_blacklist = await handler.policy_ops( + ctx, + action="waf_blacklist_create", + siteId=2147483647, + content="192.0.2.10", + ) + delete_blacklist = await handler.policy_ops( + ctx, + action="waf_blacklist_delete", + siteId=2147483647, + content="192.0.2.10", + ) + create_whitelist = await handler.policy_ops( + ctx, + action="waf_whitelist_create", + id=2147483647, + ip_start="192.0.2.11", + desc="allow scanner", + ) + delete_whitelist = await handler.policy_ops( + ctx, + action="waf_whitelist_delete", + id=2147483647, + ip_start="192.0.2.11", + ) + + assert create_blacklist.success is True + assert delete_blacklist.success is True + assert create_whitelist.success is True + assert delete_whitelist.success is True + assert calls == [ + ( + "POST", + "/rest/api/blacklist", + None, + [{"siteId": 2147483647, "type": 1, "content": "192.0.2.10", "is_permanent": "1"}], + ), + ( + "DELETE", + "/rest/api/blacklist", + None, + [{"siteId": 2147483647, "type": 1, "content": "192.0.2.10"}], + ), + ( + "POST", + "/rest/api/whitelist", + None, + { + "id": 2147483647, + "ip_whitelist": { + "ip_ver": "0", + "type": "0", + "ip_start": "192.0.2.11", + "desc": "allow scanner", + }, + }, + ), + ( + "DELETE", + "/rest/api/whitelist", + None, + { + "id": 2147483647, + "ip_whitelist": { + "ip_ver": "0", + "type": "0", + "ip_start": "192.0.2.11", + "ip_end": "0", + "netmask": 32, + }, + }, + ), + ] + + +@pytest.mark.asyncio +async def test_policy_ops_builds_global_list_and_exception_payloads(monkeypatch): + handler = _load_handler() + calls: list[tuple[str, str, dict[str, Any] | None, Any]] = [] + + class _FakeClient: + def request( + self, + method: str, + path: str, + query: dict[str, Any] | None = None, + body: Any = None, + ) -> dict[str, Any]: + calls.append((method, path, query, body)) + return {"success": True, "result": []} + + monkeypatch.setattr(handler, "get_client", lambda: _FakeClient()) + + ctx = ToolContext(session_id="s", message_id="m") + await handler.policy_ops( + ctx, + action="waf_site_global_blacklist_create", + content="192.0.2.12", + is_permanent=0, + block_time=120, + ) + await handler.policy_ops( + ctx, + action="waf_site_global_blacklist_delete", + content="192.0.2.12", + is_permanent=0, + ) + await handler.policy_ops( + ctx, + action="waf_site_global_whitelist_create", + ip_start="192.0.2.13", + ) + await handler.policy_ops( + ctx, + action="waf_site_global_whitelist_delete", + ip_start="192.0.2.13", + ) + payload = {"rule_id": "1000000015", "protection_sub_type": "10000"} + await handler.policy_ops(ctx, action="waf_exception_rule_create", body=payload) + await handler.policy_ops(ctx, action="waf_exception_rule_update", body=payload) + await handler.policy_ops(ctx, action="waf_exception_rule_delete", body=payload) + + assert calls == [ + ( + "POST", + "/rest/api/site_global_blacklist", + None, + [{"type": 1, "content": "192.0.2.12", "is_permanent": "0", "block_time": 120}], + ), + ( + "DELETE", + "/rest/api/site_global_blacklist", + None, + [{"type": 1, "content": "192.0.2.12"}], + ), + ( + "POST", + "/rest/api/site_global_whitelist", + None, + [{"type": "0", "ip_ver": "0", "ip_start": "192.0.2.13"}], + ), + ( + "DELETE", + "/rest/api/site_global_whitelist", + None, + [{"type": "0", "ip_ver": "0", "ip_start": "192.0.2.13", "ip_end": "0", "netmask": 32}], + ), + ("POST", "/rest/api/exceptionlist", None, payload), + ("PUT", "/rest/api/exceptionlist", None, payload), + ("DELETE", "/rest/api/exceptionlist", None, payload), + ] + + +@pytest.mark.asyncio +async def test_observability_filters_security_and_configuration_logs(monkeypatch): + handler = _load_handler() + calls: list[tuple[str, dict[str, Any] | None]] = [] + + class _FakeClient: + def get(self, path: str, query: dict[str, Any] | None = None) -> dict[str, Any]: + calls.append((path, query)) + return {"success": True, "result": []} + + monkeypatch.setattr(handler, "get_client", lambda: _FakeClient()) + + ctx = ToolContext(session_id="s", message_id="m") + security_result = await handler.observability( + ctx, + action="waf_security_log_search", + time_start="2026/05/29 08:00:00", + time_end="2026/05/29 09:00:00", + http_url="/login", + action_filter="deny", + start=5, + limit=10, + ) + config_result = await handler.observability( + ctx, + action="waf_configuration_log_search", + time_start="2026/05/29 08:00:00", + time_end="2026/05/29 09:00:00", + msg="blacklist", + start=0, + limit=20, + ) + + assert security_result.success is True + assert config_result.success is True + assert calls == [ + ( + "/rest/api/websecuritylog", + { + "conditions": [ + {"field": "time_start", "operator": 0, "value": "2026/05/29 08:00:00"}, + {"field": "time_end", "operator": 0, "value": "2026/05/29 09:00:00"}, + {"field": "http_url", "operator": 0, "value": "/login"}, + {"field": "action", "operator": 0, "value": "deny"}, + ], + "start": 5, + "limit": 10, + }, + ), + ( + "/rest/api/configurationlog", + { + "lifeTime": { + "interval": "custom", + "start": "2026/05/29 08:00:00", + "end": "2026/05/29 09:00:00", + }, + "conditions": [{"field": "msg", "operator": 0, "value": "blacklist"}], + "start": 0, + "limit": 20, + }, + ), + ] + + +@pytest.mark.parametrize( + ("method", "path"), + [ + ("POST", "/rest/api/reboot_system"), + ("POST", "/rest/api/mgmt_image"), + ("DELETE", "/rest/api/mgmt_image"), + ("POST", "/rest/api/signature"), + ("PUT", "/rest/api/waf_deploy_mode"), + ("PUT", "/rest/api/licenseManagementAgent"), + ("POST", "/rest/api/interface"), + ("DELETE", "/rest/api/zone"), + ], +) +def test_raw_mutation_rejects_waf_device_state_changes(monkeypatch, method: str, path: str): + handler = _load_handler() + monkeypatch.setattr(handler, "get_client", lambda: pytest.fail("blocked raw mutation must not call WAF")) + + with pytest.raises(handler.WafApiError, match="does not support modifying WAF device state"): + handler.waf_call_mutation({"method": method, "path": path, "body": []}) + + +@pytest.mark.parametrize( + ("action", "params"), + [ + ( + "waf_file_upload", + { + "path": "/rest/file/signature_import", + "file_path": "signature.dat", + }, + ), + ( + "waf_file_request", + {"method": "DELETE", "path": "/rest/file?fileName=tmp"}, + ), + ], +) +def test_file_ops_reject_upgrade_and_import_helpers(monkeypatch, action: str, params: dict[str, Any]): + handler = _load_handler() + monkeypatch.setattr(handler, "get_client", lambda: pytest.fail("blocked file helper must not call WAF")) + + with pytest.raises(handler.WafApiError, match="does not support WAF upgrade or import file operations"): + handler._ACTION_MAP[action](params) + + +@pytest.mark.asyncio +async def test_observability_test_action_uses_readonly_security_log_probe(monkeypatch): + handler = _load_handler() + calls: list[tuple[str, dict[str, Any] | None]] = [] + + class _FakeClient: + def get(self, path: str, query: dict[str, Any] | None = None) -> dict[str, Any]: + calls.append((path, query)) + return {"success": True, "result": []} + + monkeypatch.setattr(handler, "get_client", lambda: _FakeClient()) + + result: ToolResult = await handler.observability( + ToolContext(session_id="s", message_id="m"), + action="test", + ) + + assert result.success is True + assert calls == [ + ( + "/rest/api/websecuritylog", + {"conditions": [{"field": "interval", "operator": 0, "value": "hour"}], "start": 0, "limit": 50}, + ) + ] + + +@pytest.mark.asyncio +async def test_file_ops_test_action_returns_clear_no_probe_error(monkeypatch): + handler = _load_handler() + monkeypatch.setattr( + handler, + "get_client", + lambda: pytest.fail("file_ops action=test must not touch the WAF"), + ) + + result: ToolResult = await handler.file_ops( + ToolContext(session_id="s", message_id="m"), + action="test", + ) + + assert result.success is False + assert "does not define a zero-argument connectivity probe" in result.error diff --git a/tests/tool/test_device_context_prompt.py b/tests/tool/test_device_context_prompt.py index 616d2531a..ba1a420a5 100644 --- a/tests/tool/test_device_context_prompt.py +++ b/tests/tool/test_device_context_prompt.py @@ -6,6 +6,30 @@ from flocks.tool.device.prompt import build_device_context_section from flocks.tool.registry import ParameterType, ToolCategory, ToolInfo, ToolParameter +def _stub_groups(monkeypatch: pytest.MonkeyPatch, groups): + monkeypatch.setattr( + "flocks.tool.device.prompt.list_groups", AsyncMock(return_value=groups) + ) + + +def _stub_devices(monkeypatch: pytest.MonkeyPatch, devices): + monkeypatch.setattr( + "flocks.tool.device.prompt.list_devices", AsyncMock(return_value=devices) + ) + + +def _stub_per_device(monkeypatch: pytest.MonkeyPatch, mapping): + monkeypatch.setattr( + "flocks.tool.device.prompt.list_all_device_tool_settings", + AsyncMock(return_value=mapping), + ) + + +def _stub_tools(monkeypatch: pytest.MonkeyPatch, tools): + monkeypatch.setattr( + "flocks.tool.registry.ToolRegistry.list_tools", lambda: tools + ) + @pytest.mark.asyncio async def test_device_context_deduplicates_tool_sets_and_references_them_from_devices( @@ -81,3 +105,86 @@ async def test_device_context_deduplicates_tool_sets_and_references_them_from_de assert "`tdp_alert_list`" in content assert "**TDP-A**" in content assert "**TDP-B**" in content + + +@pytest.mark.asyncio +async def test_device_context_shows_per_device_disabled_tools_only_for_their_device( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Per-device disabled tools must be annotated only under the device that + has the override — not under sibling devices sharing the same storage_key. + """ + _stub_groups(monkeypatch, [SimpleNamespace(id="room-1", name="上海机房")]) + _stub_devices(monkeypatch, [ + SimpleNamespace( + id="dev-a", group_id="room-1", name="TDP-A", + storage_key="tdp_v3_3_10", enabled=True, + ), + SimpleNamespace( + id="dev-b", group_id="room-1", name="TDP-B", + storage_key="tdp_v3_3_10", enabled=True, + ), + ]) + # Per-device override: TDP-A has tdp_alert_list disabled; TDP-B has no overrides. + _stub_per_device(monkeypatch, { + "dev-a": {"tdp_alert_list": False}, + }) + _stub_tools(monkeypatch, [ + ToolInfo( + name="tdp_event_list", description="List TDP events.", + category=ToolCategory.CUSTOM, parameters=[], + enabled=True, source="device", provider="tdp_v3_3_10", + ), + ToolInfo( + name="tdp_alert_list", description="List TDP alerts.", + category=ToolCategory.CUSTOM, parameters=[], + enabled=True, source="device", provider="tdp_v3_3_10", + ), + ]) + + content = await build_device_context_section() + assert content is not None + + # Locate the per-device blocks by anchoring on the device name line. + a_block_start = content.index("**TDP-A**") + b_block_start = content.index("**TDP-B**") + a_block = content[a_block_start:b_block_start] + b_block = content[b_block_start:] + + # TDP-A block must mention the disabled tool with its OWN device name & id. + assert "已单独禁用" in a_block + assert "`tdp_alert_list`" in a_block + assert "TDP-A" in a_block + assert "dev-a" in a_block + + # TDP-B block must NOT mention any per-device disable. + assert "已单独禁用" not in b_block + + # Defence-in-depth: TDP-A's notice must not leak the wrong device name. + assert "TDP-B" not in a_block.split("已单独禁用", 1)[1].split("\n", 1)[0] + + +@pytest.mark.asyncio +async def test_device_context_omits_notice_when_no_per_device_overrides( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """No notice line should appear when a device has no per-device overrides.""" + _stub_groups(monkeypatch, [SimpleNamespace(id="room-1", name="上海机房")]) + _stub_devices(monkeypatch, [ + SimpleNamespace( + id="dev-a", group_id="room-1", name="TDP-A", + storage_key="tdp_v3_3_10", enabled=True, + ), + ]) + _stub_per_device(monkeypatch, {}) + _stub_tools(monkeypatch, [ + ToolInfo( + name="tdp_event_list", description="List TDP events.", + category=ToolCategory.CUSTOM, parameters=[], + enabled=True, source="device", provider="tdp_v3_3_10", + ), + ]) + + content = await build_device_context_section() + assert content is not None + assert "已单独禁用" not in content diff --git a/tests/tool/test_device_tool_isolation.py b/tests/tool/test_device_tool_isolation.py new file mode 100644 index 000000000..c519222bb --- /dev/null +++ b/tests/tool/test_device_tool_isolation.py @@ -0,0 +1,392 @@ +"""Tests for per-device tool enable/disable isolation (DB-backed). + +Regression suite for the bug where two device instances sharing the same +``storage_key`` (same product version, different names) would have their +tool on/off state coupled — toggling a tool "for Device A" also affected +Device B. + +Fix: store per-device tool overrides in the ``device_tool_settings`` SQLite +table (ON DELETE CASCADE cleans up automatically on device removal). The +override is checked at ToolRegistry.execute() time, AFTER the shared global +tool_settings have been applied. The in-memory ToolInfo.enabled remains a +global/shared concept; per-device gates live exclusively in the execution path. +""" + +from __future__ import annotations + +import uuid +from contextlib import asynccontextmanager +from pathlib import Path +from types import SimpleNamespace +from typing import Optional +from unittest.mock import AsyncMock + +import pytest + +from flocks.tool.registry import Tool, ToolCategory, ToolContext, ToolInfo, ToolRegistry, ToolResult + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +async def db_env(tmp_path: Path, monkeypatch): + """Isolated SQLite DB for each test. + + Importing ``flocks.tool.device.models`` ensures all DDLs (including + device_tool_settings) are registered before Storage.init() runs. + """ + from flocks.config.config import Config + from flocks.storage.storage import Storage + import flocks.tool.device.models # noqa: F401 — registers DDLs + + data_dir = tmp_path / "flocks_data" + data_dir.mkdir(parents=True, exist_ok=True) + monkeypatch.setenv("FLOCKS_DATA_DIR", str(data_dir)) + + Config._global_config = None + Config._cached_config = None + Storage._db_path = None + Storage._initialized = False + + await Storage.init() + yield data_dir + + +@pytest.fixture +def isolated_registry(monkeypatch): + saved_tools = dict(ToolRegistry._tools) + saved_defaults = dict(ToolRegistry._enabled_defaults) + saved_plugin_names = list(ToolRegistry._plugin_tool_names) + saved_dynamic = dict(ToolRegistry._dynamic_tools_by_module) + monkeypatch.setattr(ToolRegistry, "_tools", {}) + monkeypatch.setattr(ToolRegistry, "_enabled_defaults", {}) + monkeypatch.setattr(ToolRegistry, "_plugin_tool_names", []) + monkeypatch.setattr(ToolRegistry, "_dynamic_tools_by_module", {}) + yield + ToolRegistry._tools = saved_tools + ToolRegistry._enabled_defaults = saved_defaults + ToolRegistry._plugin_tool_names = saved_plugin_names + ToolRegistry._dynamic_tools_by_module = saved_dynamic + + +def _device_tool(name: str, storage_key: str, *, enabled: bool = True) -> Tool: + async def _handler(_ctx: ToolContext, **_kwargs) -> ToolResult: + return ToolResult(success=True, output="ok") + + return Tool( + info=ToolInfo( + name=name, + description=f"stub device tool {name}", + category=ToolCategory.CUSTOM, + enabled=enabled, + source="device", + provider=storage_key, + ), + handler=_handler, + ) + + +async def _insert_stub_device(device_id: str, storage_key: str) -> None: + """Insert a minimal device row so FK constraints on device_tool_settings pass.""" + from flocks.storage.storage import Storage + + now = 1_700_000_000_000 + async with Storage.connect(Storage.get_db_path()) as db: + await db.execute(""" + INSERT OR IGNORE INTO device_groups (id, name, sort_order, created_at, updated_at) + VALUES ('default-room', '默认机房', 0, ?, ?) + """, (now, now)) + await db.execute(""" + INSERT OR IGNORE INTO device_integrations + (id, group_id, name, storage_key, service_id, enabled, + verify_ssl, fields, status, created_at, updated_at) + VALUES (?, 'default-room', ?, ?, ?, 1, 0, '{}', 'unknown', ?, ?) + """, (device_id, f"dev-{device_id[:4]}", storage_key, storage_key, now, now)) + await db.commit() + + +# --------------------------------------------------------------------------- +# store.py: device_tool_settings CRUD +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +class TestStoreDeviceToolSettings: + async def test_get_returns_none_when_absent(self, db_env): + from flocks.tool.device.store import get_device_tool_enabled + result = await get_device_tool_enabled("dev-a", "sangfor_af_login") + assert result is None + + async def test_set_and_get_roundtrip(self, db_env): + from flocks.tool.device.store import get_device_tool_enabled, set_device_tool_enabled + dev_id = str(uuid.uuid4()) + await _insert_stub_device(dev_id, "sangfor_af_v8_0_106") + await set_device_tool_enabled(dev_id, "sangfor_af_login", False) + result = await get_device_tool_enabled(dev_id, "sangfor_af_login") + assert result is False + + async def test_set_does_not_affect_other_device(self, db_env): + from flocks.tool.device.store import get_device_tool_enabled, set_device_tool_enabled + dev_a = str(uuid.uuid4()) + dev_b = str(uuid.uuid4()) + await _insert_stub_device(dev_a, "sangfor_af_v8_0_106") + await _insert_stub_device(dev_b, "sangfor_af_v8_0_106") + await set_device_tool_enabled(dev_a, "sangfor_af_login", False) + result_b = await get_device_tool_enabled(dev_b, "sangfor_af_login") + assert result_b is None + + async def test_set_does_not_affect_other_tool(self, db_env): + from flocks.tool.device.store import get_device_tool_enabled, set_device_tool_enabled + dev_id = str(uuid.uuid4()) + await _insert_stub_device(dev_id, "sangfor_af_v8_0_106") + await set_device_tool_enabled(dev_id, "sangfor_af_login", False) + other = await get_device_tool_enabled(dev_id, "sangfor_af_query") + assert other is None + + async def test_delete_returns_true_when_entry_existed(self, db_env): + from flocks.tool.device.store import ( + delete_device_tool_setting, + get_device_tool_enabled, + set_device_tool_enabled, + ) + dev_id = str(uuid.uuid4()) + await _insert_stub_device(dev_id, "sangfor_af_v8_0_106") + await set_device_tool_enabled(dev_id, "sangfor_af_login", False) + removed = await delete_device_tool_setting(dev_id, "sangfor_af_login") + assert removed is True + assert await get_device_tool_enabled(dev_id, "sangfor_af_login") is None + + async def test_delete_returns_false_when_absent(self, db_env): + from flocks.tool.device.store import delete_device_tool_setting + removed = await delete_device_tool_setting("non-existent", "some_tool") + assert removed is False + + async def test_list_returns_all_settings_for_device(self, db_env): + from flocks.tool.device.store import list_device_tool_settings, set_device_tool_enabled + dev_id = str(uuid.uuid4()) + await _insert_stub_device(dev_id, "sangfor_af_v8_0_106") + await set_device_tool_enabled(dev_id, "tool_x", False) + await set_device_tool_enabled(dev_id, "tool_y", False) + settings = await list_device_tool_settings(dev_id) + assert set(settings.keys()) == {"tool_x", "tool_y"} + assert settings["tool_x"] is False + assert settings["tool_y"] is False + + async def test_delete_does_not_affect_other_device(self, db_env): + from flocks.tool.device.store import ( + delete_device_tool_setting, + get_device_tool_enabled, + set_device_tool_enabled, + ) + dev_a = str(uuid.uuid4()) + dev_b = str(uuid.uuid4()) + await _insert_stub_device(dev_a, "sangfor_af_v8_0_106") + await _insert_stub_device(dev_b, "sangfor_af_v8_0_106") + await set_device_tool_enabled(dev_a, "tool_x", False) + await set_device_tool_enabled(dev_b, "tool_x", False) + await delete_device_tool_setting(dev_a, "tool_x") + assert await get_device_tool_enabled(dev_b, "tool_x") is False + + async def test_cascade_delete_on_device_removal(self, db_env): + """Removing the parent device row must cascade to device_tool_settings.""" + from flocks.storage.storage import Storage + from flocks.tool.device.store import get_device_tool_enabled, set_device_tool_enabled + + dev_id = str(uuid.uuid4()) + await _insert_stub_device(dev_id, "sangfor_af_v8_0_106") + await set_device_tool_enabled(dev_id, "sangfor_af_login", False) + + async with Storage.connect(Storage.get_db_path()) as db: + await db.execute( + "DELETE FROM device_integrations WHERE id = ?", (dev_id,) + ) + await db.commit() + + result = await get_device_tool_enabled(dev_id, "sangfor_af_login") + assert result is None + + async def test_global_tool_settings_unaffected(self, db_env): + """device_tool_settings must not touch flocks.json tool_settings.""" + from flocks.config.config_writer import ConfigWriter + from flocks.tool.device.store import set_device_tool_enabled + + dev_id = str(uuid.uuid4()) + await _insert_stub_device(dev_id, "sangfor_af_v8_0_106") + await set_device_tool_enabled(dev_id, "sangfor_af_login", False) + + global_setting = ConfigWriter.get_tool_setting("sangfor_af_login") + assert global_setting is None + + async def test_set_bumps_device_revision(self, db_env): + """Cache-invalidation contract: setting a per-device override must bump + the device_revision so the session runner rebuilds the DeviceAssetContext + section in the system prompt. + """ + from flocks.tool.device.store import device_revision, set_device_tool_enabled + + dev_id = str(uuid.uuid4()) + await _insert_stub_device(dev_id, "sangfor_af_v8_0_106") + before = device_revision() + await set_device_tool_enabled(dev_id, "sangfor_af_login", False) + assert device_revision() > before + + async def test_delete_bumps_device_revision_only_on_real_removal(self, db_env): + """Deleting a non-existent override must NOT bump the revision.""" + from flocks.tool.device.store import ( + delete_device_tool_setting, + device_revision, + set_device_tool_enabled, + ) + + dev_id = str(uuid.uuid4()) + await _insert_stub_device(dev_id, "sangfor_af_v8_0_106") + + # No-op delete: revision unchanged. + rev_a = device_revision() + await delete_device_tool_setting(dev_id, "missing_tool") + assert device_revision() == rev_a + + # Real delete after a set: revision bumps. + await set_device_tool_enabled(dev_id, "sangfor_af_login", False) + rev_b = device_revision() + removed = await delete_device_tool_setting(dev_id, "sangfor_af_login") + assert removed is True + assert device_revision() > rev_b + + async def test_list_all_returns_grouped_by_device(self, db_env): + """Batch helper used by the system-prompt builder to avoid N+1.""" + from flocks.tool.device.store import ( + list_all_device_tool_settings, + set_device_tool_enabled, + ) + + dev_a = str(uuid.uuid4()) + dev_b = str(uuid.uuid4()) + await _insert_stub_device(dev_a, "sangfor_af_v8_0_106") + await _insert_stub_device(dev_b, "sangfor_af_v8_0_106") + await set_device_tool_enabled(dev_a, "tool_x", False) + await set_device_tool_enabled(dev_a, "tool_y", True) + await set_device_tool_enabled(dev_b, "tool_x", False) + + all_settings = await list_all_device_tool_settings() + assert all_settings[dev_a] == {"tool_x": False, "tool_y": True} + assert all_settings[dev_b] == {"tool_x": False} + + async def test_list_all_returns_empty_dict_when_no_overrides(self, db_env): + from flocks.tool.device.store import list_all_device_tool_settings + + result = await list_all_device_tool_settings() + assert result == {} + + +# --------------------------------------------------------------------------- +# ToolRegistry.execute: per-device gate +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +class TestDeviceToolIsolationExecution: + async def _run_tool( + self, + monkeypatch, + db_env, + *, + storage_key: str, + device_id: str, + tool_name: str = "sangfor_af_login", + enabled_in_registry: bool = True, + per_device_enabled: Optional[bool] = None, + ) -> ToolResult: + """Helper: stub the registry + device store, then call execute().""" + tool = _device_tool(tool_name, storage_key, enabled=enabled_in_registry) + monkeypatch.setattr( + "flocks.tool.registry.ToolRegistry.get", lambda _name: tool + ) + + monkeypatch.setattr( + "flocks.tool.device.store.list_devices", + AsyncMock(return_value=[ + SimpleNamespace(id=device_id, storage_key=storage_key, enabled=True), + ]), + ) + + @asynccontextmanager + async def _activate(did: str): + yield True + + monkeypatch.setattr( + "flocks.tool.credential_context.activate_device_credentials", _activate + ) + + # Apply per-device DB setting. + from flocks.tool.device.store import ( + delete_device_tool_setting, + set_device_tool_enabled, + ) + await _insert_stub_device(device_id, storage_key) + if per_device_enabled is False: + await set_device_tool_enabled(device_id, tool_name, False) + elif per_device_enabled is True: + await delete_device_tool_setting(device_id, tool_name) + + return await ToolRegistry.execute(tool_name, device_id=device_id) + + async def test_tool_executes_when_no_per_device_override( + self, monkeypatch, db_env + ): + result = await self._run_tool( + monkeypatch, db_env, + storage_key="sangfor_af_v8_0_106", + device_id=str(uuid.uuid4()), + per_device_enabled=None, + ) + assert result.success is True + + async def test_tool_blocked_by_per_device_disable( + self, monkeypatch, db_env + ): + result = await self._run_tool( + monkeypatch, db_env, + storage_key="sangfor_af_v8_0_106", + device_id=str(uuid.uuid4()), + per_device_enabled=False, + ) + assert result.success is False + assert "已禁用" in (result.error or "") + + async def test_per_device_disable_does_not_affect_other_device( + self, monkeypatch, db_env + ): + """Core regression: disabling tool for dev-a must NOT affect dev-b.""" + from flocks.tool.device.store import set_device_tool_enabled + + dev_a = str(uuid.uuid4()) + dev_b = str(uuid.uuid4()) + storage_key = "sangfor_af_v8_0_106" + + await _insert_stub_device(dev_a, storage_key) + await set_device_tool_enabled(dev_a, "sangfor_af_login", False) + + result = await self._run_tool( + monkeypatch, db_env, + storage_key=storage_key, + device_id=dev_b, + per_device_enabled=None, + ) + assert result.success is True, ( + "Disabling a tool for dev-a must not affect dev-b even if they " + "share the same storage_key (same plugin version)." + ) + + async def test_global_disable_still_blocks_all_devices( + self, monkeypatch, db_env, isolated_registry + ): + """Global tool_settings (enabled=False in registry) must still block ALL devices.""" + tool = _device_tool("sangfor_af_login", "sangfor_af_v8_0_106", enabled=False) + monkeypatch.setattr( + "flocks.tool.registry.ToolRegistry.get", lambda _name: tool + ) + + result = await ToolRegistry.execute("sangfor_af_login", device_id=str(uuid.uuid4())) + assert result.success is False + assert "disabled" in (result.error or "").lower() diff --git a/tests/tool/test_logging_noise.py b/tests/tool/test_logging_noise.py index c3e198b2f..8e5e71391 100644 --- a/tests/tool/test_logging_noise.py +++ b/tests/tool/test_logging_noise.py @@ -121,7 +121,11 @@ def test_tool_registry_api_service_sync_logs_at_debug(monkeypatch) -> None: assert tool.info.enabled is False event, payload = logger.debug.call_args.args assert event == "tool_registry.api_service_sync" - assert payload == {"disabled_tools": 1, "disabled_providers": ["svc"]} + assert payload == { + "disabled_tools": 1, + "disabled_providers": ["svc"], + "restored_tools": 0, + } logger.info.assert_not_called() diff --git a/tests/tool/test_tool_plugin.py b/tests/tool/test_tool_plugin.py index bf05d8952..aa10e2615 100644 --- a/tests/tool/test_tool_plugin.py +++ b/tests/tool/test_tool_plugin.py @@ -13,6 +13,7 @@ from flocks.tool.tool_loader import ( _build_execution_handler, _build_http_handler, + _build_tcp_connector, _extract_response, _json_schema_to_params, _merge_provider_defaults, @@ -1133,6 +1134,80 @@ async def test_response_extract(self): assert result.output == [1, 2] +class TestTcpConnector: + """Regression tests for CLOSE_WAIT socket accumulation under rapid tool calls.""" + + def test_connector_forces_socket_close(self): + """force_close must be True so per-request sockets don't linger in CLOSE_WAIT.""" + captured: Dict[str, Any] = {} + + def _fake_connector(**kwargs): + captured.update(kwargs) + return MagicMock() + + with patch("aiohttp.TCPConnector", side_effect=_fake_connector): + _build_tcp_connector(verify_ssl=False) + + assert captured["force_close"] is True + assert captured["ssl"] is False + + def test_connector_enables_cleanup_closed_only_on_affected_python(self): + """enable_cleanup_closed is only meaningful before the CPython 3.12.7 SSL fix.""" + import sys as _sys + + captured: Dict[str, Any] = {} + + def _fake_connector(**kwargs): + captured.update(kwargs) + return MagicMock() + + with patch("aiohttp.TCPConnector", side_effect=_fake_connector): + _build_tcp_connector(verify_ssl=True) + + if _sys.version_info < (3, 12, 7): + assert captured.get("enable_cleanup_closed") is True + else: + assert "enable_cleanup_closed" not in captured + + @pytest.mark.asyncio + async def test_http_handler_uses_hardened_connector(self): + """The declarative HTTP handler must build its session with the hardened connector.""" + cfg = { + "type": "http", + "method": "GET", + "url": "https://appliance.example.com/api", + "timeout": 10, + } + handler = _build_http_handler(cfg) + + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.json = AsyncMock(return_value={"ok": True}) + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + mock_session = AsyncMock() + mock_session.request = MagicMock(return_value=mock_resp) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + + captured: Dict[str, Any] = {} + + def _fake_connector(**kwargs): + captured.update(kwargs) + return MagicMock() + + ctx = ToolContext(session_id="test", message_id="test") + + with patch("aiohttp.TCPConnector", side_effect=_fake_connector), patch( + "aiohttp.ClientSession", return_value=mock_session + ): + result = await handler(ctx) + + assert result.success is True + assert captured["force_close"] is True + + class TestExecutionHandler: @pytest.mark.asyncio async def test_inline_yaml_execution_loads_but_refuses_to_run_by_default( diff --git a/tests/utils/test_log_compatibility.py b/tests/utils/test_log_compatibility.py index df955bdfe..bcaabd2ac 100644 --- a/tests/utils/test_log_compatibility.py +++ b/tests/utils/test_log_compatibility.py @@ -6,7 +6,7 @@ import time import tempfile from pathlib import Path -from flocks.utils.log import Log, Logger, LogLevel +from flocks.utils.log import Log, Logger, LogLevel, _RotatingTextWriter, rotate_log_file class TestLoggerCompatibility: @@ -225,6 +225,57 @@ def test_object_serialization(self): assert 'data={"nested": "value"}' in output or 'data={"nested":"value"}' in output finally: Log._writer = old_stderr + + def test_large_object_values_are_truncated(self, monkeypatch): + """Test large objects are bounded before being written to logs.""" + logger = Log.create(service="test") + + import io + old_stderr = Log._writer + Log._writer = io.StringIO() + monkeypatch.setenv("FLOCKS_LOG_VALUE_MAX_CHARS", "20") + + try: + logger.info("message", {"data": {"payload": "x" * 100}}) + output = Log._writer.getvalue() + + assert "data=" in output + assert " None: + monkeypatch.setenv("FLOCKS_LOG_DIR", str(tmp_path)) + monkeypatch.setenv("FLOCKS_LOG_MAX_BYTES", "1234") + monkeypatch.setenv("FLOCKS_LOG_BACKUP_COUNT", "2") + + setup_workflow_logging(stream=None) + + logger = logging.getLogger("flocks.workflow") + handlers = [handler for handler in logger.handlers if isinstance(handler, RotatingFileHandler)] + + try: + assert len(handlers) == 1 + assert handlers[0].baseFilename == str(tmp_path / "workflow.log") + assert handlers[0].maxBytes == 1234 + assert handlers[0].backupCount == 2 + finally: + logger.handlers.clear() diff --git a/tests/workflow/test_poller_manager.py b/tests/workflow/test_poller_manager.py new file mode 100644 index 000000000..82fcb86ea --- /dev/null +++ b/tests/workflow/test_poller_manager.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import asyncio +import threading +from typing import Any + +import pytest + +from flocks.workflow import poller_manager +from flocks.workflow.runner import RunWorkflowResult + + +@pytest.mark.asyncio +async def test_restart_disabled_config_reports_stopped(monkeypatch: pytest.MonkeyPatch) -> None: + manager = poller_manager.WorkflowPollerManager() + + async def _fake_read(_key: str) -> dict[str, Any]: + return {"enabled": False} + + monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + + status = await manager.restart_workflow("wf-disabled") + assert status["state"] == "stopped" + assert status["error"] is None + + +@pytest.mark.asyncio +async def test_restart_missing_workflow_reports_failed(monkeypatch: pytest.MonkeyPatch) -> None: + manager = poller_manager.WorkflowPollerManager() + + async def _fake_read(_key: str) -> dict[str, Any]: + return {"enabled": True, "intervalSeconds": 30} + + monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + monkeypatch.setattr(poller_manager, "read_workflow_from_fs", lambda _workflow_id: None) + + status = await manager.restart_workflow("wf-missing") + assert status["state"] == "failed" + assert status["error"] == "workflow_not_found" + + +@pytest.mark.asyncio +async def test_run_once_injects_dynamic_inputs_and_summary(monkeypatch: pytest.MonkeyPatch) -> None: + manager = poller_manager.WorkflowPollerManager() + captured_inputs: dict[str, Any] = {} + + async def _fake_read(_key: str) -> dict[str, Any]: + return { + "enabled": False, + "timeoutSeconds": 9, + "inputs": {"dedup_source_workflow_name": "stream_alert_denoise_gt_fast"}, + } + + def _fake_run_workflow(*, workflow: Any, inputs: dict[str, Any], timeout_s: int, trace: bool, cancel): # noqa: ANN001 + captured_inputs.update(inputs) + assert workflow == {"start": "n1", "nodes": [], "edges": []} + assert timeout_s == 9 + assert trace is False + assert cancel() is False + return RunWorkflowResult( + status="success", + run_id="run-1", + outputs={ + "load_stats": {"record_count": 7}, + "processed_mark_count": 3, + "processed_cache_size_after": 11, + "channel_notify_status": "sent", + }, + ) + + monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + monkeypatch.setattr( + poller_manager, + "read_workflow_from_fs", + lambda _workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}}, + ) + monkeypatch.setattr( + poller_manager, + "create_execution_record", + lambda workflow_id, *, input_params=None, exec_id=None: asyncio.sleep(0, result={ + "id": exec_id or f"exec-{workflow_id}", + "workflowId": workflow_id, + "inputParams": input_params or {}, + "status": "running", + "startedAt": 111, + "executionLog": [], + "currentPhase": "queued", + "currentStepIndex": 0, + }), + ) + monkeypatch.setattr( + poller_manager, + "record_execution_result", + lambda workflow_id, exec_id, exec_data: asyncio.sleep(0), + ) + monkeypatch.setattr(poller_manager, "run_workflow", _fake_run_workflow) + + status = await manager.run_once("wf-run-once") + + assert status["lastStatus"] == "success" + assert status["selectedCount"] == 7 + assert status["processedMarkCount"] == 11 + assert status["channelNotifyStatus"] == "sent" + assert status["state"] == "stopped" + assert captured_inputs["dedup_source_workflow_name"] == "stream_alert_denoise_gt_fast" + assert captured_inputs["input_date"] + assert captured_inputs["_trigger"] == "poller" + assert captured_inputs["_poller_run_id"].startswith("poller-") + + +@pytest.mark.asyncio +async def test_run_once_records_execution_and_normalizes_business_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + manager = poller_manager.WorkflowPollerManager() + created_records: list[dict[str, Any]] = [] + recorded_results: list[dict[str, Any]] = [] + + async def _fake_read(_key: str) -> dict[str, Any]: + return { + "enabled": False, + "timeoutSeconds": 9, + "inputs": {"dedup_source_workflow_name": "stream_alert_denoise_gt_fast"}, + } + + async def _fake_create_execution_record( + workflow_id: str, + *, + input_params: dict[str, Any] | None = None, + exec_id: str | None = None, + ) -> dict[str, Any]: + record = { + "id": exec_id or "exec-1", + "workflowId": workflow_id, + "inputParams": input_params or {}, + "status": "running", + "startedAt": 111, + "executionLog": [], + "currentPhase": "queued", + "currentStepIndex": 0, + } + created_records.append(record) + return dict(record) + + async def _fake_record_execution_result( + workflow_id: str, + exec_id: str, + exec_data: dict[str, Any], + ) -> None: + _ = workflow_id, exec_id + recorded_results.append(dict(exec_data)) + + def _fake_run_workflow(*, workflow: Any, inputs: dict[str, Any], timeout_s: int, trace: bool, cancel): # noqa: ANN001 + assert workflow == {"start": "n1", "nodes": [], "edges": []} + assert timeout_s == 9 + assert trace is False + assert cancel() is False + assert inputs["dedup_source_workflow_name"] == "stream_alert_denoise_gt_fast" + return RunWorkflowResult( + status="SUCCEEDED", + run_id="run-1", + steps=2, + last_node_id="finish", + outputs={ + "workflow_success": False, + "reason": "business rule blocked", + "load_stats": {"record_count": 9}, + }, + ) + + monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + monkeypatch.setattr( + poller_manager, + "read_workflow_from_fs", + lambda _workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}}, + ) + monkeypatch.setattr(poller_manager, "create_execution_record", _fake_create_execution_record) + monkeypatch.setattr(poller_manager, "record_execution_result", _fake_record_execution_result) + monkeypatch.setattr(poller_manager, "run_workflow", _fake_run_workflow) + + status = await manager.run_once("wf-business-failure") + + assert created_records[0]["inputParams"]["_trigger"] == "poller" + assert created_records[0]["inputParams"]["_poller_run_id"].startswith("poller-") + assert recorded_results[0]["status"] == "error" + assert recorded_results[0]["errorMessage"] == "business rule blocked" + assert recorded_results[0]["currentPhase"] == "error" + assert status["lastStatus"] == "error" + assert status["lastError"] == "business rule blocked" + assert status["selectedCount"] == 9 + + +@pytest.mark.asyncio +async def test_no_overlap_skips_when_previous_run_is_still_active( + monkeypatch: pytest.MonkeyPatch, +) -> None: + manager = poller_manager.WorkflowPollerManager() + threading_event = asyncio.Event() + + config = { + "enabled": True, + "intervalSeconds": 1, + "timeoutSeconds": 5, + "noOverlap": True, + "inputs": {}, + } + + def _fake_run_workflow(*, workflow: Any, inputs: dict[str, Any], timeout_s: int, trace: bool, cancel): # noqa: ANN001 + _ = workflow, inputs, timeout_s, trace, cancel + # Keep the run active until the test releases it so a second tick skips. + asyncio.run(asyncio.wait_for(threading_event.wait(), timeout=2.0)) + return RunWorkflowResult(status="success", outputs={"load_stats": {"record_count": 1}}) + + monkeypatch.setattr(poller_manager, "run_workflow", _fake_run_workflow) + monkeypatch.setattr( + poller_manager, + "create_execution_record", + lambda workflow_id, *, input_params=None, exec_id=None: asyncio.sleep(0, result={ + "id": exec_id or f"exec-{workflow_id}", + "workflowId": workflow_id, + "inputParams": input_params or {}, + "status": "running", + "startedAt": 111, + "executionLog": [], + "currentPhase": "queued", + "currentStepIndex": 0, + }), + ) + monkeypatch.setattr( + poller_manager, + "record_execution_result", + lambda workflow_id, exec_id, exec_data: asyncio.sleep(0), + ) + monkeypatch.setattr( + poller_manager, + "read_workflow_from_fs", + lambda _workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}}, + ) + + await manager._schedule_run("wf-overlap", {"start": "n1", "nodes": [], "edges": []}, config) + await asyncio.sleep(0.02) + await manager._schedule_run("wf-overlap", {"start": "n1", "nodes": [], "edges": []}, config) + status = manager.get_status("wf-overlap") + + threading_event.set() + await asyncio.sleep(0.02) + + assert status["lastStatus"] == "skipped" + assert status["lastError"] == "previous_run_still_active" + + +@pytest.mark.asyncio +async def test_stop_workflow_keeps_unfinished_run_tracked_until_thread_exits( + monkeypatch: pytest.MonkeyPatch, +) -> None: + manager = poller_manager.WorkflowPollerManager() + release_run = threading.Event() + + async def _fake_create_execution_record( + workflow_id: str, + *, + input_params: dict[str, Any] | None = None, + exec_id: str | None = None, + ) -> dict[str, Any]: + _ = input_params + return { + "id": exec_id or f"exec-{workflow_id}", + "workflowId": workflow_id, + "status": "running", + "startedAt": 111, + "executionLog": [], + "currentPhase": "queued", + "currentStepIndex": 0, + } + + async def _fake_record_execution_result( + workflow_id: str, + exec_id: str, + exec_data: dict[str, Any], + ) -> None: + _ = workflow_id, exec_id, exec_data + + def _fake_run_workflow(*, workflow: Any, inputs: dict[str, Any], timeout_s: int, trace: bool, cancel): # noqa: ANN001 + _ = workflow, inputs, timeout_s, trace, cancel + release_run.wait(timeout=0.2) + return RunWorkflowResult(status="SUCCEEDED", run_id="run-stop") + + monkeypatch.setattr(poller_manager, "RUN_SHUTDOWN_GRACE_SECONDS", 0.01) + monkeypatch.setattr(poller_manager, "create_execution_record", _fake_create_execution_record) + monkeypatch.setattr(poller_manager, "record_execution_result", _fake_record_execution_result) + monkeypatch.setattr(poller_manager, "run_workflow", _fake_run_workflow) + + await manager._schedule_run( + "wf-stop", + {"start": "n1", "nodes": [], "edges": []}, + {"enabled": True, "intervalSeconds": 1, "timeoutSeconds": 5, "noOverlap": True, "inputs": {}}, + ) + await asyncio.sleep(0.02) + + await manager.stop_workflow("wf-stop") + assert manager.get_status("wf-stop")["activeRuns"] == 1 + + release_run.set() + await asyncio.sleep(0.05) + assert manager.get_status("wf-stop")["activeRuns"] == 0 + + +@pytest.mark.asyncio +async def test_start_all_only_restarts_enabled_configs(monkeypatch: pytest.MonkeyPatch) -> None: + manager = poller_manager.WorkflowPollerManager() + restarted: list[str] = [] + + async def _fake_list_keys(_prefix: str) -> list[str]: + return [ + "workflow_poller_config/wf-enabled", + "workflow_poller_config/wf-disabled", + ] + + async def _fake_read(key: str) -> dict[str, Any]: + return {"enabled": key.endswith("wf-enabled")} + + async def _fake_restart(workflow_id: str) -> dict[str, Any]: + restarted.append(workflow_id) + return {"workflowId": workflow_id, "state": "running"} + + monkeypatch.setattr(poller_manager.Storage, "list_keys", _fake_list_keys) + monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + monkeypatch.setattr(manager, "restart_workflow", _fake_restart) + + await manager.start_all() + assert restarted == ["wf-enabled"] + + +@pytest.mark.asyncio +async def test_restart_workflow_replaces_existing_task(monkeypatch: pytest.MonkeyPatch) -> None: + manager = poller_manager.WorkflowPollerManager() + config = {"enabled": True, "intervalSeconds": 30, "timeoutSeconds": 10, "noOverlap": True, "inputs": {}} + + async def _fake_read(_key: str) -> dict[str, Any]: + return config + + async def _fake_loop(*args, **kwargs) -> None: # noqa: ANN002, ANN003 + await asyncio.sleep(60) + + monkeypatch.setattr(poller_manager.Storage, "read", _fake_read) + monkeypatch.setattr( + poller_manager, + "read_workflow_from_fs", + lambda _workflow_id: {"workflowJson": {"start": "n1", "nodes": [], "edges": []}}, + ) + monkeypatch.setattr(manager, "_poller_loop", _fake_loop) + + first = await manager.restart_workflow("wf-restart") + first_task = manager._tasks["wf-restart"] + second = await manager.restart_workflow("wf-restart") + second_task = manager._tasks["wf-restart"] + + assert first["state"] == "running" + assert second["state"] == "running" + assert first_task is not second_task + assert first_task.cancelled() or first_task.done() + + await manager.stop_workflow("wf-restart") diff --git a/tests/workflow/test_workflow_history_mode.py b/tests/workflow/test_workflow_history_mode.py new file mode 100644 index 000000000..4ebacd36c --- /dev/null +++ b/tests/workflow/test_workflow_history_mode.py @@ -0,0 +1,88 @@ +from flocks.workflow.runner import run_workflow +from flocks.workflow.repl_runtime import PythonExecRuntime + + +def test_run_workflow_summary_history_does_not_retain_large_step_payloads() -> None: + workflow = { + "start": "produce", + "nodes": [ + { + "id": "produce", + "type": "python", + "code": "\n".join( + [ + "outputs['raw_alerts'] = [{'id': i, 'body': 'x' * 1000} for i in range(200)]", + "outputs['count'] = len(outputs['raw_alerts'])", + ] + ), + }, + { + "id": "consume", + "type": "python", + "code": "\n".join( + [ + "alerts = inputs.get('raw_alerts', [])", + "outputs['final_count'] = len(alerts)", + ] + ), + }, + ], + "edges": [{"from": "produce", "to": "consume"}], + } + + result = run_workflow( + workflow=workflow, + history_mode="summary", + ensure_requirements=False, + ) + + assert result.status == "SUCCEEDED" + assert result.outputs == {"final_count": 200} + assert result.history[0]["outputs"]["raw_alerts"] == { + "_type": "list", + "count": 200, + "preview": [ + {"_type": "dict", "keys": ["id", "body"]}, + {"_type": "dict", "keys": ["id", "body"]}, + {"_type": "dict", "keys": ["id", "body"]}, + ], + } + assert result.history[1]["inputs"]["raw_alerts"]["count"] == 200 + + +def test_run_workflow_summary_outputs_do_not_retain_large_final_payloads() -> None: + workflow = { + "start": "final", + "nodes": [ + { + "id": "final", + "type": "python", + "code": "outputs['items'] = [{'body': 'x' * 1000} for _ in range(200)]", + }, + ], + "edges": [], + } + + result = run_workflow( + workflow=workflow, + history_mode="summary", + ensure_requirements=False, + ) + + assert result.status == "SUCCEEDED" + assert result.outputs["items"]["_type"] == "list" + assert result.outputs["items"]["count"] == 200 + + +def test_python_runtime_can_cleanup_node_globals_after_execute() -> None: + runtime = PythonExecRuntime(cleanup_globals_after_execute=True) + + outputs, _stdout = runtime.execute( + "temporary_payload = 'x' * 1000\noutputs['ok'] = True", + {}, + ) + + assert outputs == {"ok": True} + assert "temporary_payload" not in runtime.globals + assert runtime.globals["outputs"] == {"ok": True} + diff --git a/tui/flocks/provider/transform.ts b/tui/flocks/provider/transform.ts index 2f91fc3d7..e719c221a 100644 --- a/tui/flocks/provider/transform.ts +++ b/tui/flocks/provider/transform.ts @@ -285,7 +285,7 @@ export namespace ProviderTransform { if (id.includes("gemini")) return 1.0 if (id.includes("glm-4.6")) return 1.0 if (id.includes("glm-4.7")) return 1.0 - if (id.includes("minimax-m2")) return 1.0 + if (id.includes("minimax")) return 1.0 if (id.includes("kimi-k2")) { if (id.includes("thinking")) return 1.0 return 0.6 @@ -296,7 +296,7 @@ export namespace ProviderTransform { export function topP(model: Provider.Model) { const id = model.id.toLowerCase() if (id.includes("qwen")) return 1 - if (id.includes("minimax-m2")) { + if (id.includes("minimax")) { return 0.95 } if (id.includes("gemini")) return 0.95 @@ -305,7 +305,7 @@ export namespace ProviderTransform { export function topK(model: Provider.Model) { const id = model.id.toLowerCase() - if (id.includes("minimax-m2")) { + if (id.includes("minimax")) { if (id.includes("m2.1")) return 40 return 20 } diff --git a/uv.lock b/uv.lock index 468d452b5..c44801299 100644 --- a/uv.lock +++ b/uv.lock @@ -58,6 +58,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/b7/76175c7cb4eb73d91ad63c34e29fc4f77c9386bba4a65b53ba8e05ee3c39/aiohttp-3.13.3-cp312-cp312-win_amd64.whl", hash = "sha256:e3531d63d3bdfa7e3ac5e9b27b2dd7ec9df3206a98e0b3445fa906f233264c57", size = 455407, upload-time = "2026-01-03T17:30:44.195Z" }, ] +[[package]] +name = "aiokafka" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-timeout" }, + { name = "packaging" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/89/5f/dfc1180fd22d1acdc91949ec36e97199c43742dacb057cb8efed3679ed04/aiokafka-0.14.0.tar.gz", hash = "sha256:8ffdc945798ba4d3d132b705d4244d0a1f493925efb57c637a2ca88ee82794e1", size = 601374, upload-time = "2026-04-29T10:43:03.574Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/9d/3441db94829f9feb802a2f4052df61c0d1a01272accd174c351d7e9e1f6a/aiokafka-0.14.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:284a90d617584d7e42688a181aaa8c2a909d9c658ab9b69c6cf92f4df5c4b320", size = 348458, upload-time = "2026-04-29T10:42:37.243Z" }, + { url = "https://files.pythonhosted.org/packages/a4/10/7297589aac95654596af13301b31da2c9502c80e7e308530ee7a9bd5b9f1/aiokafka-0.14.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b4f211d9e03a1fc83871a37eefcf307bc0943ee99adae25aa39bd1722e70747b", size = 351057, upload-time = "2026-04-29T10:42:38.69Z" }, + { url = "https://files.pythonhosted.org/packages/26/4e/5c0aa8db717fff0ffb8f3e16deece8f98ded6ca17c6a543b6b20cc9a7f84/aiokafka-0.14.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be517b9b9513eba43ba19961dd770a6e26d08325743093feb47182770d235dd9", size = 1142238, upload-time = "2026-04-29T10:42:39.96Z" }, + { url = "https://files.pythonhosted.org/packages/88/78/322f797b9593a4cc8afd647342fa66b9ad732ee55098e5e084188c6202aa/aiokafka-0.14.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:219d2dc66b97b1aaea100697c928024b6a0348b7baa370b824900054bf86916e", size = 1131567, upload-time = "2026-04-29T10:42:41.542Z" }, + { url = "https://files.pythonhosted.org/packages/8d/0a/a45320778385142299a7fc3ae402152ec1f383537130b8aa8e8587742fad/aiokafka-0.14.0-cp312-cp312-win32.whl", hash = "sha256:1086b470f6c452471603a2d9c8d6933739230c75758d777d8d113ff8112bad68", size = 312160, upload-time = "2026-04-29T10:42:42.811Z" }, + { url = "https://files.pythonhosted.org/packages/a3/fb/7802a0ed69200e3e8e8791df06bd6daf9b00523839d045662de4ff061b18/aiokafka-0.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:bcf3a8f6592d73f45965ca0750bfdfccf2555c8625358175c92f75f2cce1261a", size = 331897, upload-time = "2026-04-29T10:42:43.984Z" }, +] + [[package]] name = "aiosignal" version = "1.4.0" @@ -139,6 +158,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl", hash = "sha256:15a3ebc0f43c2d0a50eeafea25e19046c68398e487b9f1f5b517f7c0f40f976a", size = 27047, upload-time = "2025-11-15T16:43:16.109Z" }, ] +[[package]] +name = "async-timeout" +version = "5.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274, upload-time = "2024-11-06T16:41:39.6Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233, upload-time = "2024-11-06T16:41:37.9Z" }, +] + [[package]] name = "asyncssh" version = "2.22.0" @@ -509,11 +537,12 @@ wheels = [ [[package]] name = "flocks" -version = "2026.5.27" +version = "2026.6.3" source = { editable = "." } dependencies = [ { name = "aiofiles" }, { name = "aiohttp" }, + { name = "aiokafka" }, { name = "aiosqlite" }, { name = "anthropic" }, { name = "asyncssh" }, @@ -580,6 +609,7 @@ dev = [ requires-dist = [ { name = "aiofiles", specifier = ">=23.2.1" }, { name = "aiohttp", specifier = ">=3.13.3" }, + { name = "aiokafka", specifier = ">=0.14.0" }, { name = "aiosqlite", specifier = ">=0.19.0" }, { name = "anthropic", specifier = ">=0.86.0" }, { name = "asyncssh", specifier = ">=2.22.0" }, diff --git a/webui/src/api/agent.ts b/webui/src/api/agent.ts index e9434beab..4475825de 100644 --- a/webui/src/api/agent.ts +++ b/webui/src/api/agent.ts @@ -60,6 +60,7 @@ export const agentAPI = { color?: string; mode?: string; model?: { modelID: string; providerID: string }; + delegatable?: boolean; skills?: string[]; tools?: string[]; }) => @@ -72,6 +73,7 @@ export const agentAPI = { temperature?: number; color?: string; model?: { modelID: string; providerID: string }; + delegatable?: boolean; skills?: string[]; tools?: string[]; }) => @@ -80,6 +82,9 @@ export const agentAPI = { updateModel: (name: string, model: { modelID: string; providerID: string } | null, temperature?: number) => client.put(`/api/agent/${name}/model`, { model, temperature }), + setDelegatable: (name: string, delegatable: boolean) => + client.patch(`/api/agent/${name}/delegatable`, { delegatable }), + delete: (name: string) => client.delete(`/api/agent/${name}`), diff --git a/webui/src/api/device.ts b/webui/src/api/device.ts index 357114a98..588bf1a7f 100644 --- a/webui/src/api/device.ts +++ b/webui/src/api/device.ts @@ -89,6 +89,22 @@ export interface DeviceTestRequest { verify_ssl?: boolean; } +// --------------------------------------------------------------------------- +// Per-device tool settings +// --------------------------------------------------------------------------- + +export interface DeviceToolInfo { + name: string; + description: string; + description_cn?: string | null; + /** 全局工具开关(影响所有同版本设备) */ + enabled_global: boolean; + /** 本设备的独立覆盖值;null = 未设置,遵从全局 */ + enabled_device: boolean | null; + /** 最终生效状态 */ + enabled_effective: boolean; +} + export const deviceAPI = { // groups listGroups: () => @@ -121,4 +137,11 @@ export const deviceAPI = { test: (id: string, body?: DeviceTestRequest) => client.post(`/api/devices/${id}/test`, body ?? {}), + + // per-device tool settings + listDeviceTools: (device_id: string) => + client.get(`/api/devices/${device_id}/tools`), + + updateDeviceTool: (device_id: string, tool_name: string, enabled: boolean) => + client.patch(`/api/devices/${device_id}/tools/${tool_name}`, { enabled }), }; diff --git a/webui/src/api/session.test.ts b/webui/src/api/session.test.ts index db8c55596..3303e361e 100644 --- a/webui/src/api/session.test.ts +++ b/webui/src/api/session.test.ts @@ -17,8 +17,10 @@ vi.mock('./client', () => ({ describe('sessionApi message actions', () => { beforeEach(() => { vi.clearAllMocks(); + mockGet.mockResolvedValue({ data: { sessionID: 'session-1', items: [] } }); mockPatch.mockResolvedValue({ data: { ok: true } }); mockPost.mockResolvedValue({ data: { ok: true } }); + mockDelete.mockResolvedValue({ data: { ok: true } }); }); it('updates a message part through the patch endpoint', async () => { @@ -67,4 +69,29 @@ describe('sessionApi message actions', () => { { timeout: 0 }, ); }); + + it('calls prompt queue endpoints', async () => { + const { sessionApi } = await import('./session'); + + await sessionApi.listPromptQueue('session-1'); + await sessionApi.enqueuePrompt('session-1', { + parts: [{ type: 'text', text: 'queued prompt' }], + agent: 'rex', + }); + await sessionApi.updateQueuedPrompt('session-1', 'queue-1', 'edited prompt'); + await sessionApi.removeQueuedPrompt('session-1', 'queue-1'); + await sessionApi.runQueuedPromptNow('session-1', 'queue-2'); + + expect(mockGet).toHaveBeenCalledWith('/api/session/session-1/prompt_queue'); + expect(mockPost).toHaveBeenCalledWith( + '/api/session/session-1/prompt_queue', + { parts: [{ type: 'text', text: 'queued prompt' }], agent: 'rex' }, + ); + expect(mockPatch).toHaveBeenCalledWith( + '/api/session/session-1/prompt_queue/queue-1', + { text: 'edited prompt' }, + ); + expect(mockDelete).toHaveBeenCalledWith('/api/session/session-1/prompt_queue/queue-1'); + expect(mockPost).toHaveBeenCalledWith('/api/session/session-1/prompt_queue/queue-2/run_now'); + }); }); diff --git a/webui/src/api/session.ts b/webui/src/api/session.ts index c81748ed2..26193973e 100644 --- a/webui/src/api/session.ts +++ b/webui/src/api/session.ts @@ -13,6 +13,24 @@ export interface SessionMessagePartPayload { metadata?: Record; } +export interface QueuedPrompt { + id: string; + sessionID: string; + parts: Array>; + agent?: string | null; + model?: Record | null; + variant?: string | null; + messageID?: string | null; + status: 'pending' | 'executing' | string; + createdAt: number; + updatedAt: number; +} + +export interface PromptQueueResponse { + sessionID: string; + items: QueuedPrompt[]; +} + export interface SessionListParams { limit?: number; offset?: number; @@ -117,6 +135,36 @@ export const sessionApi = { return response.data; }, + listPromptQueue: async (sessionId: string): Promise => { + const response = await client.get(`/api/session/${sessionId}/prompt_queue`); + return response.data; + }, + + enqueuePrompt: async (sessionId: string, data: { + parts: Array>; + agent?: string; + model?: Record; + variant?: string; + }) => { + const response = await client.post(`/api/session/${sessionId}/prompt_queue`, data); + return response.data; + }, + + updateQueuedPrompt: async (sessionId: string, queueId: string, text: string) => { + const response = await client.patch(`/api/session/${sessionId}/prompt_queue/${queueId}`, { text }); + return response.data; + }, + + removeQueuedPrompt: async (sessionId: string, queueId: string) => { + const response = await client.delete(`/api/session/${sessionId}/prompt_queue/${queueId}`); + return response.data; + }, + + runQueuedPromptNow: async (sessionId: string, queueId: string) => { + const response = await client.post(`/api/session/${sessionId}/prompt_queue/${queueId}/run_now`); + return response.data; + }, + /** * 更新消息 part */ diff --git a/webui/src/api/tool.ts b/webui/src/api/tool.ts index 7a65771e4..b0299f6b8 100644 --- a/webui/src/api/tool.ts +++ b/webui/src/api/tool.ts @@ -41,8 +41,12 @@ export const toolAPI = { getStatistics: (name: string) => client.get(`/api/tools/${name}/statistics`), - setEnabled: (name: string, enabled: boolean) => - client.patch(`/api/tools/${name}`, { enabled }), + setEnabled: (name: string, enabled: boolean, options?: { device_id?: string }) => + client.patch( + `/api/tools/${name}`, + { enabled }, + options?.device_id ? { params: { device_id: options.device_id } } : undefined, + ), /** * Remove the user-level setting and restore the YAML/registration default diff --git a/webui/src/api/workflow.ts b/webui/src/api/workflow.ts index 2023ab3d8..301e368a8 100644 --- a/webui/src/api/workflow.ts +++ b/webui/src/api/workflow.ts @@ -174,6 +174,61 @@ export interface SyslogListenerStatus { workerCount?: number; } +export interface KafkaConfig { + workflowId?: string; + enabled?: boolean; + inputBroker?: string; + inputTopic?: string; + inputGroupId?: string; + inputKey?: string; + autoOffsetReset?: string; + inputs?: Record; + updatedAt?: number; +} + +/** Runtime state of the Kafka consumer (independent from saved config). */ +export interface KafkaConsumerStatus { + state: 'connecting' | 'running' | 'failed' | 'stopped'; + error?: string | null; + broker?: string; + topic?: string; + groupId?: string; + queueSize?: number; + queueCapacity?: number; + workerCount?: number; +} + +export interface WorkflowPollerConfig { + workflowId?: string; + enabled?: boolean; + intervalSeconds?: number; + timeoutSeconds?: number; + noOverlap?: boolean; + inputs?: Record; + updatedAt?: number; +} + +export interface WorkflowPollerStatus { + workflowId?: string; + state: 'running' | 'stopped' | 'failed'; + error?: string | null; + enabled?: boolean; + intervalSeconds?: number; + timeoutSeconds?: number; + noOverlap?: boolean; + activeRuns?: number; + lastRunAt?: number | null; + lastRunId?: string | null; + lastStatus?: string | null; + lastError?: string | null; + lastDurationMs?: number | null; + selectedCount?: number | null; + processedMarkCount?: number | null; + channelNotifyStatus?: string | null; + kafkaMessageCount?: number | null; + nextRunAt?: number | null; +} + export const workflowAPI = { list: (params?: { category?: string; status?: string; excludeId?: string }) => client.get('/api/workflow', { params }), @@ -249,22 +304,41 @@ export const workflowAPI = { client.get('/api/workflow-services'), saveKafkaConfig: (id: string, config: { + enabled?: boolean; inputBroker?: string; inputTopic?: string; inputGroupId?: string; - outputBroker?: string; - outputTopic?: string; + inputKey?: string; + autoOffsetReset?: string; + inputs?: Record; }) => - client.post<{ ok: boolean }>(`/api/workflow/${id}/kafka-config`, config), + client.post<{ ok: boolean; consumer?: KafkaConsumerStatus }>( + `/api/workflow/${id}/kafka-config`, + config, + ), getKafkaConfig: (id: string) => - client.get<{ - inputBroker?: string; - inputTopic?: string; - inputGroupId?: string; - outputBroker?: string; - outputTopic?: string; - } | null>(`/api/workflow/${id}/kafka-config`), + client.get(`/api/workflow/${id}/kafka-config`), + + getKafkaStatus: (id: string) => + client.get(`/api/workflow/${id}/kafka-status`), + + savePollerConfig: (id: string, config: WorkflowPollerConfig) => + client.post<{ ok: boolean; status?: WorkflowPollerStatus }>( + `/api/workflow/${id}/poller-config`, + config, + ), + + getPollerConfig: (id: string) => + client.get(`/api/workflow/${id}/poller-config`), + + getPollerStatus: (id: string) => + client.get(`/api/workflow/${id}/poller-status`), + + runPollerOnce: (id: string) => + client.post<{ ok: boolean; status?: WorkflowPollerStatus }>( + `/api/workflow/${id}/poller-run-once`, + ), saveSyslogConfig: (id: string, config: { enabled?: boolean; diff --git a/webui/src/components/common/SessionChat.test.ts b/webui/src/components/common/SessionChat.test.ts index f6b9ad1c3..e26218a32 100644 --- a/webui/src/components/common/SessionChat.test.ts +++ b/webui/src/components/common/SessionChat.test.ts @@ -1,10 +1,12 @@ import React from 'react'; -import { render, waitFor } from '@testing-library/react'; +import { render, screen, waitFor } from '@testing-library/react'; +import userEvent from '@testing-library/user-event'; import { beforeEach, describe, expect, it, vi } from 'vitest'; import type { Message } from '@/types'; import { + areChatMessagePartsRenderEqual, buildTodoWriteSummary, dedupeUploadedDocumentAttachments, default as SessionChat, @@ -22,6 +24,14 @@ import { const clientGetMock = vi.fn(); const clientPostMock = vi.fn(); +const sessionApiListPromptQueueMock = vi.fn(); +const sessionApiEnqueuePromptMock = vi.fn(); +const sessionApiUpdateQueuedPromptMock = vi.fn(); +const sessionApiRemoveQueuedPromptMock = vi.fn(); +const sessionApiRunQueuedPromptNowMock = vi.fn(); +const sessionApiUpdateMessagePartMock = vi.fn(); +const sessionApiResendMessageMock = vi.fn(); +const sessionApiRegenerateMessageMock = vi.fn(); const useSessionMessagesMock = vi.fn(); const tMock = (key: string) => ({ 'chat.placeholder': '请输入消息', @@ -30,6 +40,10 @@ const tMock = (key: string) => ({ 'chat.thinking': '思考中...', 'chat.streaming': '继续输出中...', 'chat.compacting': '压缩中...', + 'chat.mention.title': '选择 Agent', + 'chat.mention.navigate': '导航', + 'chat.mention.select': '选择', + 'smartAssistant': '智能助手', }[key] ?? key); const pendingQuestionsHookMock = { pendingQuestions: {}, @@ -49,6 +63,7 @@ const toastMock = { vi.mock('react-i18next', () => ({ useTranslation: () => ({ t: tMock, + i18n: { language: 'zh-CN' }, }), })); @@ -85,14 +100,36 @@ vi.mock('@/api/client', () => ({ getApiBase: () => '', })); +vi.mock('@/api/session', () => ({ + sessionApi: { + listPromptQueue: (...args: unknown[]) => sessionApiListPromptQueueMock(...args), + enqueuePrompt: (...args: unknown[]) => sessionApiEnqueuePromptMock(...args), + updateQueuedPrompt: (...args: unknown[]) => sessionApiUpdateQueuedPromptMock(...args), + removeQueuedPrompt: (...args: unknown[]) => sessionApiRemoveQueuedPromptMock(...args), + runQueuedPromptNow: (...args: unknown[]) => sessionApiRunQueuedPromptNowMock(...args), + updateMessagePart: (...args: unknown[]) => sessionApiUpdateMessagePartMock(...args), + resendMessage: (...args: unknown[]) => sessionApiResendMessageMock(...args), + regenerateMessage: (...args: unknown[]) => sessionApiRegenerateMessageMock(...args), + }, +})); + beforeEach(() => { vi.clearAllMocks(); + localStorage.clear(); Object.defineProperty(window.HTMLElement.prototype, 'scrollIntoView', { configurable: true, value: vi.fn(), }); clientGetMock.mockResolvedValue({ data: {} }); clientPostMock.mockResolvedValue({ data: {} }); + sessionApiListPromptQueueMock.mockResolvedValue({ items: [] }); + sessionApiEnqueuePromptMock.mockResolvedValue({}); + sessionApiUpdateQueuedPromptMock.mockResolvedValue({}); + sessionApiRemoveQueuedPromptMock.mockResolvedValue({}); + sessionApiRunQueuedPromptNowMock.mockResolvedValue({}); + sessionApiUpdateMessagePartMock.mockResolvedValue({}); + sessionApiResendMessageMock.mockResolvedValue({}); + sessionApiRegenerateMessageMock.mockResolvedValue({}); pendingQuestionsHookMock.fetchPendingQuestions.mockResolvedValue(undefined); useSessionMessagesMock.mockReturnValue({ messages: [], @@ -304,6 +341,126 @@ describe('SessionChat standalone thinking indicator', () => { }); }); +describe('SessionChat agent mentions', () => { + const mentionAgents = [ + { + name: 'rex', + description: 'Main orchestrator', + descriptionCn: '主编排 Agent', + mode: 'primary', + permission: [], + options: {}, + skills: [], + tools: [], + }, + { + name: 'explore', + description: 'Explore the codebase', + descriptionCn: '探索代码库', + mode: 'subagent', + native: true, + permission: [], + options: {}, + skills: [], + tools: [], + }, + ]; + + it('shows matching agents when typing @', async () => { + const user = userEvent.setup(); + render(React.createElement(SessionChat, { + sessionId: 'sess-1', + mentionAgents, + })); + + await user.type(screen.getByPlaceholderText('请输入消息'), '@ex'); + + expect(screen.getByText('@explore')).toBeInTheDocument(); + expect(screen.getByText('探索代码库')).toBeInTheDocument(); + }); + + it('routes one message to the mentioned agent without changing the default agent', async () => { + const user = userEvent.setup(); + render(React.createElement(SessionChat, { + sessionId: 'sess-1', + agentName: 'rex', + mentionAgents, + })); + + await user.type(screen.getByPlaceholderText('请输入消息'), '@explore summarize this file{enter}'); + + await waitFor(() => { + expect(clientPostMock).toHaveBeenCalledWith( + '/api/session/sess-1/prompt_async', + expect.objectContaining({ + agent: 'explore', + parts: expect.any(Array), + }), + ); + }); + }); + + it('queues streaming messages to the mentioned agent', async () => { + const user = userEvent.setup(); + render(React.createElement(SessionChat, { + sessionId: 'sess-1', + agentName: 'rex', + mentionAgents, + initialMessage: 'start streaming', + })); + + await waitFor(() => { + expect(clientPostMock).toHaveBeenCalledWith( + '/api/session/sess-1/prompt_async', + expect.objectContaining({ parts: expect.any(Array) }), + ); + }); + + sessionApiEnqueuePromptMock.mockClear(); + await user.type(screen.getByRole('textbox'), '@explore queued message{enter}'); + + await waitFor(() => { + expect(sessionApiEnqueuePromptMock).toHaveBeenCalledWith( + 'sess-1', + expect.objectContaining({ + agent: 'explore', + parts: expect.any(Array), + }), + ); + }); + }); + + it('queues streaming messages to the default agent when no mention is provided', async () => { + const user = userEvent.setup(); + render(React.createElement(SessionChat, { + sessionId: 'sess-1', + agentName: 'rex', + mentionAgents, + initialMessage: 'start streaming', + })); + + await waitFor(() => { + expect(clientPostMock).toHaveBeenCalledWith( + '/api/session/sess-1/prompt_async', + expect.objectContaining({ parts: expect.any(Array) }), + ); + }); + + sessionApiEnqueuePromptMock.mockClear(); + await user.type(screen.getByRole('textbox'), 'queued message{enter}'); + + await waitFor(() => { + expect(sessionApiEnqueuePromptMock).toHaveBeenCalledWith( + 'sess-1', + expect.objectContaining({ + agent: 'rex', + parts: expect.any(Array), + }), + ); + }); + }); +}); + describe('truncateToolDisplayText', () => { it('returns short text unchanged', () => { expect(truncateToolDisplayText('bash')).toBe('bash'); @@ -382,3 +539,73 @@ describe('shouldRefetchFinishedMessage', () => { })).toBe(true); }); }); + +describe('areChatMessagePartsRenderEqual', () => { + it('detects streamed text updates even when a later tool part exists', () => { + const sharedToolPart = { + id: 'tool-1', + type: 'tool', + tool: 'todowrite', + state: { status: 'running', metadata: { step: 1 } }, + } as Message['parts'][number]; + + expect(areChatMessagePartsRenderEqual( + [ + { id: 'text-1', type: 'text', text: '现在生成简化版 wor' } as Message['parts'][number], + sharedToolPart, + ], + [ + { id: 'text-1', type: 'text', text: '现在生成简化版 workflow.json' } as Message['parts'][number], + sharedToolPart, + ], + )).toBe(false); + }); + + it('keeps skipping rerenders when semantically identical parts are recreated', () => { + expect(areChatMessagePartsRenderEqual( + [ + { + id: 'tool-1', + type: 'tool', + tool: 'question', + state: { status: 'completed', metadata: { label: 'done' } }, + } as Message['parts'][number], + ], + [ + { + id: 'tool-1', + type: 'tool', + tool: 'question', + state: { status: 'completed', metadata: { label: 'done' } }, + } as Message['parts'][number], + ], + )).toBe(true); + }); + + it('detects legacy tool payload updates that still drive the UI', () => { + expect(areChatMessagePartsRenderEqual( + [ + { + id: 'tool-call-1', + type: 'toolCall', + toolCall: { + id: 'call-1', + name: 'question', + params: { prompt: 'first' }, + }, + } as Message['parts'][number], + ], + [ + { + id: 'tool-call-1', + type: 'toolCall', + toolCall: { + id: 'call-1', + name: 'question', + params: { prompt: 'updated' }, + }, + } as Message['parts'][number], + ], + )).toBe(false); + }); +}); diff --git a/webui/src/components/common/SessionChat.tsx b/webui/src/components/common/SessionChat.tsx index 0a183f60d..fe648e249 100644 --- a/webui/src/components/common/SessionChat.tsx +++ b/webui/src/components/common/SessionChat.tsx @@ -17,7 +17,7 @@ */ import { useState, useCallback, useRef, useEffect, useMemo, memo } from 'react'; -import { Send, Loader2, ChevronDown, Square, Copy, User, FileText, AlertCircle, X, RefreshCw, Pencil, Save, ImageIcon, Paperclip, ArrowUp, Clock, CheckCircle2, XCircle, Brain } from 'lucide-react'; +import { Send, Loader2, ChevronDown, Square, Copy, User, FileText, AlertCircle, X, RefreshCw, Pencil, Save, ImageIcon, Paperclip, ArrowUp, Clock, CheckCircle2, XCircle, Brain, Trash2, Bot } from 'lucide-react'; import { StreamingMarkdown } from './StreamingMarkdown'; import { useTranslation } from 'react-i18next'; import LoadingSpinner from './LoadingSpinner'; @@ -29,12 +29,14 @@ import { useSessionMessages } from '@/hooks/useSessions'; import { useSSE, type SSEConnectionStatus } from '@/hooks/useSSE'; import { useReasoningToggle } from '@/hooks/useReasoningToggle'; import { usePendingQuestions, type PendingQuestion } from '@/hooks/usePendingQuestions'; -import { sessionApi } from '@/api/session'; +import { sessionApi, type QueuedPrompt } from '@/api/session'; import client, { getApiBase } from '@/api/client'; import { commandAPI, type Command } from '@/api/skill'; +import type { Agent } from '@/api/agent'; import { useToast } from './Toast'; import { workspaceAPI } from '@/api/workspace'; import { formatSmartTime } from '@/utils/time'; +import { getAgentDisplayDescription } from '@/utils/agentDisplay'; import { FILE_INPUT_ACCEPT_IMAGES, batchCompressOptions, @@ -105,6 +107,8 @@ export interface SessionChatProps { onInitialMessageConsumed?: () => void; /** Agent name to include in prompt_async requests */ agentName?: string; + /** Agents available for one-turn @mention routing. */ + mentionAgents?: Agent[]; /** Display configuration (compact, showActions, showTimestamp) */ display?: SessionChatDisplay; /** Custom welcome content when no messages. Can be a render prop receiving setInput. */ @@ -130,7 +134,7 @@ export interface SessionChatProps { * session id) directly without an empty ``async (..) => { await ... }`` * shim. */ - onCreateAndSend?: (text: string, imageParts?: ImagePartData[]) => Promise | unknown; + onCreateAndSend?: (text: string, imageParts?: ImagePartData[], agentOverride?: string) => Promise | unknown; /** Called when the user sends "/new" to create a new session */ onCreateNewSession?: () => Promise | void; /** @@ -156,6 +160,59 @@ interface ComposerAttachment { error?: string; } +type UploadedDocumentAttachmentLike = { + id?: string; + status?: AttachmentStatus; + workspacePath?: string; + isImage?: boolean; +}; + +function isSuccessfulUploadedDocumentAttachment( + attachment: UploadedDocumentAttachmentLike, +): attachment is UploadedDocumentAttachmentLike & { status: 'success'; workspacePath: string; isImage?: false } { + return ( + attachment.status === 'success' + && !attachment.isImage + && typeof attachment.workspacePath === 'string' + && attachment.workspacePath.length > 0 + ); +} + +export function dedupeUploadedDocumentAttachments(items: T[]): T[] { + const latestIndexByPath = new Map(); + + items.forEach((attachment, index) => { + if (isSuccessfulUploadedDocumentAttachment(attachment)) { + latestIndexByPath.set(attachment.workspacePath, index); + } + }); + + return items.filter((attachment, index) => { + if (!isSuccessfulUploadedDocumentAttachment(attachment)) { + return true; + } + return latestIndexByPath.get(attachment.workspacePath) === index; + }); +} + +export function listUploadedDocumentPaths(items: UploadedDocumentAttachmentLike[]): string[] { + const seen = new Set(); + const paths: string[] = []; + + items.forEach((attachment) => { + if (!isSuccessfulUploadedDocumentAttachment(attachment)) { + return; + } + if (seen.has(attachment.workspacePath)) { + return; + } + seen.add(attachment.workspacePath); + paths.push(attachment.workspacePath); + }); + + return paths; +} + // Composer drafts are persisted to ``localStorage`` so navigating away from // the page (e.g. clicking the sidebar to open Agents / Workflows) and coming // back doesn't lose the half-typed message. Keyed per session so two sessions @@ -365,6 +422,82 @@ export function getUserAvatarSpacerClassName(compact: boolean): string { return compact ? 'h-3.5' : 'h-4'; } +function areToolStatesRenderEqual( + prevState?: ToolState, + nextState?: ToolState, +): boolean { + if (prevState === nextState) return true; + if ( + prevState?.status !== nextState?.status || + prevState?.title !== nextState?.title || + prevState?.error !== nextState?.error || + prevState?.time?.start !== nextState?.time?.start || + prevState?.time?.end !== nextState?.time?.end + ) { + return false; + } + + return ( + JSON.stringify(prevState?.input) === JSON.stringify(nextState?.input) + && JSON.stringify(prevState?.output) === JSON.stringify(nextState?.output) + && JSON.stringify(prevState?.metadata) === JSON.stringify(nextState?.metadata) + ); +} + +function areLegacyToolPayloadsRenderEqual( + prevPayload?: MessagePart['toolCall'] | MessagePart['toolResult'], + nextPayload?: MessagePart['toolCall'] | MessagePart['toolResult'], +): boolean { + if (prevPayload === nextPayload) return true; + return JSON.stringify(prevPayload) === JSON.stringify(nextPayload); +} + +export function areChatMessagePartsRenderEqual( + prevParts?: MessagePart[], + nextParts?: MessagePart[], +): boolean { + if (prevParts === nextParts) return true; + if ((prevParts?.length ?? 0) !== (nextParts?.length ?? 0)) return false; + + const total = prevParts?.length ?? 0; + for (let i = 0; i < total; i++) { + const prevPart = prevParts?.[i]; + const nextPart = nextParts?.[i]; + + if (prevPart === nextPart) continue; + if (!prevPart || !nextPart) return false; + + if ( + prevPart.id !== nextPart.id || + prevPart.type !== nextPart.type || + prevPart.text !== nextPart.text || + prevPart.thinking !== nextPart.thinking || + prevPart.synthetic !== nextPart.synthetic || + prevPart.ignored !== nextPart.ignored || + prevPart.tool !== nextPart.tool || + prevPart.callID !== nextPart.callID || + prevPart.mime !== nextPart.mime || + prevPart.filename !== nextPart.filename || + prevPart.url !== nextPart.url || + prevPart.image?.url !== nextPart.image?.url || + prevPart.image?.alt !== nextPart.image?.alt + ) { + return false; + } + + if (!areToolStatesRenderEqual(prevPart.state, nextPart.state)) { + return false; + } + if (!areLegacyToolPayloadsRenderEqual(prevPart.toolCall, nextPart.toolCall)) { + return false; + } + if (!areLegacyToolPayloadsRenderEqual(prevPart.toolResult, nextPart.toolResult)) { + return false; + } + } + + return true; +} // ============================================================================ // Main component @@ -385,40 +518,165 @@ function isAllowedUploadFile(file: File): boolean { return ALLOWED_UPLOAD_EXTENSIONS.has(getFileExtension(file.name)); } -function isUploadedDocumentAttachment( - attachment: T, -): attachment is T & { workspacePath: string } { - return attachment.status === 'success' && !attachment.isImage && Boolean(attachment.workspacePath); +function getQueuedPromptText(item: QueuedPrompt): string { + const textPart = item.parts.find((part) => part.type === 'text' && typeof part.text === 'string'); + return typeof textPart?.text === 'string' ? textPart.text : ''; } -export function dedupeUploadedDocumentAttachments(items: T[]): T[] { - const latestIndexByPath = new Map(); - items.forEach((item, index) => { - if (isUploadedDocumentAttachment(item)) { - latestIndexByPath.set(item.workspacePath, index); - } - }); - return items.filter((item, index) => ( - !isUploadedDocumentAttachment(item) || latestIndexByPath.get(item.workspacePath) === index - )); +interface QueuedPromptPanelProps { + items: QueuedPrompt[]; + expanded: boolean; + editingId: string | null; + editingText: string; + actionId: string | null; + t: ReturnType['t']; + onToggle: () => void; + onEditStart: (item: QueuedPrompt) => void; + onEditChange: (text: string) => void; + onEditCancel: () => void; + onEditSave: (item: QueuedPrompt) => void; + onRemove: (item: QueuedPrompt) => void; + onRunNow: (item: QueuedPrompt) => void; } -export function listUploadedDocumentPaths(items: T[]): string[] { - return dedupeUploadedDocumentAttachments(items) - .filter(isUploadedDocumentAttachment) - .map((item) => item.workspacePath); +function QueuedPromptPanel({ + items, + expanded, + editingId, + editingText, + actionId, + t, + onToggle, + onEditStart, + onEditChange, + onEditCancel, + onEditSave, + onRemove, + onRunNow, +}: QueuedPromptPanelProps) { + if (items.length === 0) return null; + + return ( +
+ + {expanded && ( +
+ {items.map((item) => { + const isEditing = editingId === item.id; + const isBusy = actionId === item.id || item.status === 'executing'; + const text = getQueuedPromptText(item); + return ( +
+
+
+ {isEditing ? ( +