diff --git a/.gitignore b/.gitignore index 9ddf46223d..603c631805 100644 --- a/.gitignore +++ b/.gitignore @@ -264,3 +264,4 @@ docker/office/bisheng/*.gz CLAUDE.md !src/backend/bisheng/telemetry_search/**/*.pyc +.omx/ diff --git a/AGENT.md b/AGENT.md new file mode 100644 index 0000000000..14be34e1eb --- /dev/null +++ b/AGENT.md @@ -0,0 +1,89 @@ +## 基本行为约束 + +### 1. 直接解决问题,禁止废话 +- 不要寒暄 +- 不要重复用户问题 +- 不要输出无信息量的过渡句、总结句、鼓励句、评价句 +- 不要使用“可以”“当然”“没问题”“你这个问题很好”之类的表达 +- 直接进入分析、判断和方案输出 + +### 2. 优先永久性解决方案,禁止临时绕过 +- 默认优先选择可维护、可复用、可扩展、可验证的方案 +- 如果存在“根因修复”和“临时绕过”两种路径,优先输出根因修复 +- 不要把 hack、patch、手工补丁、一次性操作伪装成正式方案 +- 如果只能提供临时方案,必须明确标注这是临时措施,并说明其局限、风险、替代的正式方案 + +### 3. 结合上下文理解用户目的,禁止直接猜测 +- 回答前先判断: + - 用户显式提出的问题是什么 + - 用户真正要达成的目标是什么 + - 当前上下文中有哪些已知约束 +- 不要脱离上下文按字面机械作答 +- 不要凭空补充不存在的前提 +- 当信息不足时: + - 先基于已有信息给出条件化判断 + - 明确指出哪些部分是已知,哪些部分未知 + - 不要把推测当成事实 + +## 决策原则 + +### 根因优先 +输出方案时,优先按以下顺序思考: +1. 问题的根因是什么 +2. 是否可以从结构、流程、接口、数据、权限、配置、架构层面彻底解决 +3. 该方案是否能避免同类问题再次发生 +4. 该方案的维护成本是否可接受 + +### 用户目标优先 +不要只回答表面问题,要判断用户是在: +- 要一个定义 +- 要一个判断 +- 要一个方案 +- 要一个可落地实现 +- 要一个权衡分析 +输出必须匹配用户真实任务层级。 + +### 上下文一致性 +如果用户前文已经给出: +- 业务背景 +- 系统架构 +- 术语定义 +- 限制条件 +- 偏好方案 +则必须沿用这些上下文,不要重新发明一套设定。 + +## 输出要求 + +### 信息组织 +优先输出以下结构之一,按任务类型自动选择: +- 问题分析 → 根因 → 方案 → 风险/边界 +- 目标 → 约束 → 可选方案对比 → 推荐方案 +- 结论 → 依据 → 落地步骤 +- 定义 → 与相近概念的区别 → 在当前场景下的含义 + +### 表达要求 +- 用精确术语,不用空泛表达 +- 能具体就具体,不要泛泛而谈 +- 能落地到机制、规则、数据结构、流程,就不要只停留在概念层 +- 不要为了显得全面而输出无关内容 + +## 禁止行为 +- 禁止为了“显得聪明”而过度延展 +- 禁止把不确定内容说成确定事实 +- 禁止忽略用户已有上下文重新回答 +- 禁止优先给表面 workaround 而不说明正式方案 +- 禁止只给正确但不可执行的抽象建议 + +## 遇到信息不足时 +不要直接猜答案。改为: +1. 先明确当前已知信息 +2. 给出在这些信息下最合理的分析 +3. 标出需要额外信息才可确定的部分 +4. 如果可能,给出分条件方案 + +## 最终标准 +你的回答必须同时满足: +- 对准用户真正目标 +- 尽量一次解决,而不是临时糊住 +- 与已有上下文一致 +- 结论可执行、可维护、可验证 \ No newline at end of file diff --git a/docker/bisheng/entrypoint.sh b/docker/bisheng/entrypoint.sh index 82082fcfb9..d75dfdce57 100644 --- a/docker/bisheng/entrypoint.sh +++ b/docker/bisheng/entrypoint.sh @@ -29,9 +29,14 @@ start_default(){ celery -A bisheng.worker.main worker -l info -c 100 -P threads -Q celery -n celery@%h } +start_min_worker(){ + # 最小化worker进程数,减少资源占用 + celery -A bisheng.worker.main worker -l info -c 100 -P threads -Q knowledge_celery,workflow_celery,celery -n min_worker@%h +} + if [ "$start_mode" = "api" ]; then echo "Starting API server..." - uvicorn bisheng.main:app --host 0.0.0.0 --port 7860 --no-access-log --workers 8 + uvicorn bisheng.main:app --host 0.0.0.0 --port 7860 --no-access-log --workers 1 elif [ "$start_mode" = "knowledge" ]; then echo "Starting Knowledge Celery worker..." start_knowledge @@ -49,14 +54,8 @@ elif [ "$start_mode" = "linsight" ]; then start_linsight elif [ "$start_mode" = "worker" ]; then echo "Starting All worker..." - # 处理知识库相关任务的worker - start_knowledge & - # 处理工作流相关任务的worker - start_workflow & - # 处理linsight相关任务的worker - start_linsight & - # 默认其他任务的执行worker,目前是定时统计埋点数据 - start_default & + # 最小化worker进程数,减少资源占用 + start_min_worker & start_beat echo "All workers started successfully." diff --git a/docker/docker-compose.override.yml b/docker/docker-compose.override.yml new file mode 100644 index 0000000000..7e1b7152f4 --- /dev/null +++ b/docker/docker-compose.override.yml @@ -0,0 +1,8 @@ +services: + backend: + volumes: + - ../src/backend/bisheng:/app/bisheng + + backend_worker: + volumes: + - ../src/backend/bisheng:/app/bisheng diff --git a/docs/bisheng-workflow-mcp.md b/docs/bisheng-workflow-mcp.md new file mode 100644 index 0000000000..4dcca9adbe --- /dev/null +++ b/docs/bisheng-workflow-mcp.md @@ -0,0 +1,828 @@ +# Bisheng Workflow MCP 接入说明 + +## 概述 + +Bisheng 提供一个面向外部 Agent 的 Workflow Authoring MCP 服务,支持: + +- 发现可编辑 workflow +- 读取 workflow manifest / version / graph +- 发现可用 node type +- 读取 node template +- 创建 workflow draft +- 读取 workflow 节点 +- 读取节点可编辑参数 +- 原子编辑 workflow graph +- 编辑节点参数 +- 校验 workflow +- 发布 workflow + +服务地址: + +- MCP: `/mcp` +- MCP token 申请接口: `POST /api/v1/user/mcp_token` + +注意: + +- `/mcp` 只接受 **Bisheng MCP token** +- 不接受普通 Bisheng 登录 token 直接调用 + +--- + +## 服务端配置 + +建议至少配置这些项: + +- `BISHENG_MCP_ALLOWED_ORIGINS` + - 用逗号分隔的 origin 白名单 + - 例如:`https://clawith.example.com,https://clawith-dev.example.com` +- `JWT_SECRET` + - 用于签发和校验 Bisheng token / MCP token +- `redis_url` + - 用于校验 Bisheng 主 session 是否仍然有效 + +说明: + +- 如果未配置 `BISHENG_MCP_ALLOWED_ORIGINS`,本地默认只做 `localhost/127.0.0.1/同 host` 的基础放行 +- 生产环境应显式配置 `BISHENG_MCP_ALLOWED_ORIGINS` + +--- + +## 认证模型 + +### 1. 先有 Bisheng 登录态 + +你需要先在 Bisheng 完成正常登录,拿到普通 Bisheng access token。 + +这个 token 仍然是 Bisheng 主站登录态,不建议直接交给 MCP client 长期使用。 + +### 2. 再换取 MCP token + +使用普通 Bisheng access token 调: + +```http +POST /api/v1/user/mcp_token +Authorization: Bearer +Content-Type: application/json +``` + +请求体: + +```json +{ + "expires_in": 1800 +} +``` + +返回示例: + +```json +{ + "status_code": 200, + "status_message": "SUCCESS", + "data": { + "access_token": "", + "token_type": "Bearer", + "expires_in": 1800, + "scopes": [ + "workflow.read", + "workflow.write", + "workflow.publish" + ], + "audience": "bisheng-workflow-mcp" + } +} +``` + +### 3. 用 MCP token 调 `/mcp` + +```http +Authorization: Bearer +``` + +MCP token 特性: + +- audience 固定为 `bisheng-workflow-mcp` +- token type 固定为 `mcp_access_token` +- 绑定当前 Bisheng 登录 session +- 如果 Bisheng 主 session 失效或被替换,MCP token 也会失效 +- 默认有效期 30 分钟,可通过 `expires_in` 调整,范围 `60-3600` 秒 + +--- + +## Scope + +当前 MCP token 默认带 3 个 scope: + +- `workflow.read` +- `workflow.write` +- `workflow.publish` + +tool 权限要求: + +- `ping`: 任意已认证 MCP token +- `whoami`: 任意已认证 MCP token +- `list_workflows`: `workflow.read` +- `get_workflow`: `workflow.read` +- `get_workflow_versions`: `workflow.read` +- `get_workflow_graph`: `workflow.read` +- `list_node_types`: `workflow.read` +- `get_node_template`: `workflow.read` +- `list_workflow_nodes`: `workflow.read` +- `get_workflow_node_params`: `workflow.read` +- `create_workflow_draft`: `workflow.write` +- `update_workflow_draft`: `workflow.write` +- `update_workflow_node_params`: `workflow.write` +- `add_node`: `workflow.write` +- `remove_node`: `workflow.write` +- `connect_nodes`: `workflow.write` +- `disconnect_edge`: `workflow.write` +- `validate_workflow`: `workflow.write` +- `publish_workflow`: `workflow.publish` + +--- + +## Tool 列表 + +### `ping` + +用于测试 MCP 连通性和当前认证状态。 + +返回: + +```json +{ + "ok": true, + "service": "bisheng-workflow-mcp", + "authenticated": true, + "user_id": 1, + "user_name": "admin", + "scopes": [ + "workflow.read", + "workflow.write", + "workflow.publish" + ] +} +``` + +### `whoami` + +返回当前 MCP token 对应的 Bisheng 用户。 + +### `list_workflows` + +列出当前用户有 authoring 权限的 workflow。 + +返回会带: + +- `flow_id` +- `name` +- `description` +- `status` +- `current_version_id` +- `editable_version_id` +- `draft_revision` +- `schema_version` + +### `get_workflow` + +读取单个 workflow 的 manifest 信息。 + +### `get_workflow_versions` + +列出 workflow 全部版本摘要。 + +返回会带: + +- `version_id` +- `name` +- `description` +- `is_current` +- `is_editable` +- `is_external_draft` +- `original_version_id` +- `draft_revision` +- `schema_version` + +### `get_workflow_graph` + +读取 workflow 当前可编辑版本的标准化 graph。 + +输入: + +```json +{ + "flow_id": "" +} +``` + +可选传 `version_id` 指定版本。 + +返回会带: + +- `flow_id` +- `version_id` +- `draft_revision` +- `schema_version` +- `nodes` +- `edges` + +每个 `node` 会包含: + +- `id` +- `type` +- `name` +- `tab` +- `param_keys` +- `params` + +### `list_node_types` + +列出当前 Workflow Authoring MCP 支持发现的 node type。 + +返回会带: + +- `type` +- `display_name` +- `description` +- `param_keys` +- `dynamic_template` +- `schema_version` + +### `get_node_template` + +读取单个 node type 的标准化 template。 + +输入: + +```json +{ + "node_type": "llm" +} +``` + +返回会带: + +- `node_type` +- `display_name` +- `description` +- `tab` +- `groups` +- `params` +- `dynamic_template` +- `schema_version` + +### `create_workflow_draft` + +创建一个新的 workflow draft。 + +行为说明: + +- `graph_data` 允许为空图 +- 如果初始图缺 `start`,MCP 会自动补一个 `start` 节点并连到入口节点 +- 如果初始图缺 `end`,MCP 会自动补一个 `end` 节点并连到终点节点 +- 对空图,MCP 会直接生成最小合法 scaffold:`start -> end` + +输入: + +```json +{ + "name": "demo-workflow", + "description": "created by agent", + "guide_word": "", + "graph_data": { + "nodes": [], + "edges": [] + } +} +``` + +返回: + +```json +{ + "ok": true, + "flow_id": "", + "version_id": 12, + "status": "draft", + "draft_revision": 1 +} +``` + +### `list_workflow_nodes` + +列出当前 workflow 可编辑版本中的节点摘要。 + +输入: + +```json +{ + "flow_id": "" +} +``` + +返回: + +```json +{ + "ok": true, + "flow_id": "", + "version_id": 12, + "draft_revision": 3, + "nodes": [ + { + "id": "node-1", + "type": "llm", + "name": "LLM", + "param_keys": [ + "system_prompt", + "user_prompt", + "model_id", + "temperature" + ] + } + ] +} +``` + +### `get_workflow_node_params` + +读取单个节点当前可编辑参数。 + +输入: + +```json +{ + "flow_id": "", + "node_id": "" +} +``` + +返回会带: + +- `draft_revision` +- `node_type` +- `node_name` +- `params` + +每个 param 字段会包含: + +- `display_name` +- `group_name` +- `type` +- `required` +- `show` +- `options` +- `scope` +- `placeholder` +- `refresh` +- `value` + +### `update_workflow_node_params` + +编辑单个节点参数。 + +输入: + +```json +{ + "flow_id": "", + "node_id": "", + "updates": { + "temperature": 0.3, + "system_prompt": "new prompt" + }, + "expected_revision": 3 +} +``` + +### `get_condition_node` + +读取条件节点的结构化分支配置。 + +输入: + +```json +{ + "flow_id": "", + "node_id": "condition_1234" +} +``` + +返回: + +```json +{ + "ok": true, + "flow_id": "", + "version_id": 12, + "draft_revision": 4, + "node_id": "condition_1234", + "node_name": "Condition Node", + "condition_cases": [ + { + "id": "case_a", + "operator": "and", + "conditions": [ + { + "id": "rule_1", + "left_var": "score", + "comparison_operation": "greater_than", + "right_value_type": "const", + "right_value": "80", + "variable_key_value": {} + } + ], + "variable_key_value": {} + } + ], + "route_handles": [ + "case_a", + "right_handle" + ], + "outgoing_edges": { + "case_a": [ + { + "edge_id": "edge_1", + "target_node_id": "node-2", + "target_handle": "input" + } + ] + } +} +``` + +说明: + +- `condition_cases[].id` 就是该分支在 graph 里的 `source_handle` +- 默认兜底分支固定是 `right_handle` +- 分支连线仍然通过 `connect_nodes` / `disconnect_edge` 操作 +- 每个 `condition_cases[].id` 都必须有对应的出边,且 `right_handle` 也必须有兜底出边 + +### `update_condition_node` + +更新一个已有条件节点的结构化条件分支配置。 + +输入: + +```json +{ + "flow_id": "", + "node_id": "condition_1234", + "condition_cases": [ + { + "id": "case_a", + "operator": "and", + "conditions": [ + { + "id": "rule_1", + "left_var": "score", + "comparison_operation": "greater_than_or_equal", + "right_value_type": "const", + "right_value": "90", + "variable_key_value": {} + } + ], + "variable_key_value": {} + } + ], + "expected_revision": 4 +} +``` + +返回: + +```json +{ + "ok": true, + "flow_id": "", + "version_id": 12, + "status": "draft", + "draft_revision": 5, + "node_id": "condition_1234" +} +``` + +注意: + +- 如果你修改了 `condition_cases[].id`,必须同步调整对应边的 `source_handle` +- 如果 case id 和出边 handle 不一致,服务端会拒绝保存 + +返回: + +```json +{ + "ok": true, + "flow_id": "", + "version_id": 12, + "status": "draft", + "draft_revision": 4 +} +``` + +### `add_node` + +向当前可编辑 graph 添加一个节点。 + +输入: + +```json +{ + "flow_id": "", + "node_type": "code", + "name": "Code Node", + "position_x": 120, + "position_y": 260, + "initial_params": { + "code": "print('hello')" + }, + "expected_revision": 4 +} +``` + +返回: + +```json +{ + "ok": true, + "flow_id": "", + "version_id": 12, + "status": "draft", + "draft_revision": 5, + "node_id": "code_ab12cd34" +} +``` + +说明: + +- `node_type` 必须来自 `list_node_types` +- `initial_params` 可选,仅能覆盖该 node type 已暴露的可编辑参数 + +### `remove_node` + +从当前可编辑 graph 删除一个节点。 + +输入: + +```json +{ + "flow_id": "", + "node_id": "code_ab12cd34", + "cascade": true, + "expected_revision": 5 +} +``` + +说明: + +- `cascade=true` 时会一并删除与该节点关联的边 +- `cascade=false` 且节点仍有边时会拒绝删除 + +### `connect_nodes` + +在两个现有节点之间新增一条边。 + +输入: + +```json +{ + "flow_id": "", + "source_node_id": "node-1", + "target_node_id": "node-2", + "source_handle": "output", + "target_handle": "input", + "expected_revision": 5 +} +``` + +返回: + +```json +{ + "ok": true, + "flow_id": "", + "version_id": 12, + "status": "draft", + "draft_revision": 6, + "edge_id": "edge_ef56gh78" +} +``` + +说明: + +- 当前会拒绝重复同构边 +- 当前要求显式传 `source_handle` / `target_handle` + +### `disconnect_edge` + +从当前可编辑 graph 删除一条边。 + +优先推荐按 `edge_id` 删除。 + +输入: + +```json +{ + "flow_id": "", + "edge_id": "edge_ef56gh78", + "expected_revision": 6 +} +``` + +也支持按 `(source_node_id, target_node_id, source_handle, target_handle)` 精确匹配删除。 + +### `update_workflow_draft` + +整体替换当前 draft graph。 + +输入: + +```json +{ + "flow_id": "", + "graph_data": { + "nodes": [ + { + "id": "start_1", + "data": { + "id": "start_1", + "type": "start", + "name": "Start", + "group_params": [] + } + } + ], + "edges": [] + }, + "expected_revision": 4 +} +``` + +### `validate_workflow` + +校验 workflow 当前版本。 + +返回新增 `diagnostics` 字段,每个诊断项包含: + +- `code` +- `severity` +- `message` +- `node_id` +- `field_path` +- `suggested_fix` + +### `publish_workflow` + +发布指定 workflow 版本。 + +--- + +## Draft Revision + +Workflow draft 使用乐观并发控制。 + +关键字段: + +- 读接口返回 `draft_revision` +- 写接口要求传 `expected_revision` +- 成功写入后会返回新的 `draft_revision` + +推荐调用顺序: + +1. `list_workflow_nodes` 或 `get_workflow_node_params` +2. 取返回里的 `draft_revision` +3. 调 `update_workflow_node_params` / `update_workflow_draft` +4. 把 `draft_revision` 作为 `expected_revision` 传回去 + +graph 原子编辑接口也遵守同样规则: + +- `add_node` +- `remove_node` +- `connect_nodes` +- `disconnect_edge` + +如果 revision 不匹配,会拒绝写入。 + +拒绝示例: + +```json +{ + "ok": false, + "message": "Workflow draft revision mismatch, expected 2, got 3", + "error_code": 10532 +} +``` + +这用于防止: + +- 多个 Agent 同时编辑同一个 draft +- UI 和 Agent 同时改同一个 workflow + +--- + +## 节点参数暴露规则 + +MCP 不会把节点里所有字段都暴露出去。 + +当前只允许读取和编辑: + +- `show != false` +- 非敏感字段 +- 非 password/file 类型字段 + +默认会屏蔽这些字段: + +- `password` +- `token` +- `secret` +- `api_key` +- `apikey` +- `credential` +- `auth` +- `cookie` + +也就是说: + +- 隐藏字段不会出现在 `param_keys` +- 敏感字段不会出现在 `get_workflow_node_params` +- 敏感字段也不能通过 `update_workflow_node_params` 修改 + +--- + +## 错误语义 + +### HTTP 层 + +- `401` + - 缺少 Bearer token + - MCP token 非法 + - MCP token 过期 + - Bisheng 主 session 失效 +- `403` + - Origin 不允许 + +### Tool 层 + +tool 调用失败时,返回结构化 JSON: + +```json +{ + "ok": false, + "message": "xxx", + "error_code": 10532 +} +``` + +典型错误: + +- `10526`: workflow graph / 节点参数校验失败 +- `10529`: workflow 重名 +- `10532`: draft revision 冲突 + +--- + +## Clawith 配置建议 + +Clawith 里建议这样接: + +- `server_url`: `https:///mcp` +- `transport`: `streamable_http` +- `auth_type`: `bearer` +- `token_source`: 用户级 credential +- `credential_value`: `POST /api/v1/user/mcp_token` 返回的 `access_token` + +不要让模型自己持有普通 Bisheng 登录 token。 + +更合理的流程是: + +1. 用户先在 Clawith 里完成 Bisheng 登录/绑定 +2. Clawith 后端持有普通 Bisheng 登录态 +3. Clawith 后端按需调用 `/api/v1/user/mcp_token` +4. Clawith 用返回的 MCP token 调 `/mcp` + +--- + +## 已验证的本地链路 + +本地已验证: + +1. 普通 Bisheng access token 可成功换取 MCP token +2. MCP token 可成功调用 `ping` +3. MCP token 可成功调用 discovery / authoring tool +4. workflow tool 返回 `draft_revision` +5. 节点参数读取正常 +6. `update_workflow_node_params` 成功写入后会推进 `draft_revision` +7. 旧 `expected_revision` 的重复写入会被 `10532` 拒绝 + +本地测试地址: + +- `http://127.0.0.1:7860/api/v1/user/mcp_token` +- `http://127.0.0.1:7860/mcp` + +--- + +## 注意事项 + +- `/mcp` 和 `/mcp/` 都可访问,但文档统一使用 `/mcp` +- 生产环境建议显式配置 `BISHENG_MCP_ALLOWED_ORIGINS` +- 不建议长期缓存 MCP token,建议按需刷新 +- 普通 Bisheng access token 和 MCP token 不要混用 diff --git a/docs/clawith-bisheng-distributed-workflow.md b/docs/clawith-bisheng-distributed-workflow.md new file mode 100644 index 0000000000..4b164aad44 --- /dev/null +++ b/docs/clawith-bisheng-distributed-workflow.md @@ -0,0 +1,175 @@ +# Clawith × Bisheng:分布式 Agent 协作架构 + +> 2026-03-27 +> 状态:愿景草案。当前分支已落地的是 `Workflow Authoring MCP`,不包含本文里的 `Clawith Node` 运行时闭环。 + +## 核心想法 + +每个人有自己的 **Clawith**(个人 Agent),团队用 **Bisheng** 设计业务流程(Workflow)。 + +Workflow 里的每个节点可以分发给对应的人,由他们的 Clawith 来执行。 + +流程启动时,通过 Clawith 作为入口,驱动整个 Workflow 运转。 + +--- + +## 架构示意 + +``` + 组织层(Bisheng) +┌────────────────────────────────────┐ +│ │ +│ [Node A] → [Node B] → [Node C] │ +│ ↓ ↓ ↓ │ +└─────┼────────────┼──────────┼──────┘ + │ │ │ + ▼ ▼ ▼ + 张三的 李四的 王五的 + Clawith Clawith Clawith + (个人Agent) (个人Agent) (个人Agent) +``` + +--- + +## 节点类型对比:INPUT Node vs Clawith Node + +**现有 Bisheng INPUT Node** +``` +Workflow → 暂停等待 → 同一个发起人输入 → 继续 +``` +- 只能是发起人自己回答 +- 没有 AI 辅助,只是一个表单 + +**Clawith Node(新想法)** +``` +Workflow → 里面有clawith node - 分发node任务给指定人 → 那个人用自己的 AI Agent 完成 → 结果回传 → 继续 +``` + +| 维度 | LLM/Code Node | INPUT Node | Clawith Node | +|---|---|---|---| +| 执行者 | 机器 | 发起人 | 指定人 + 其 AI | +| 速度 | 秒级 | 取决于人 | 取决于人 | +| 多人协作 | ✗ | ✗ | ✓ | +| 人工判断 | ✗ | 简单输入 | 深度处理 | +| AI 辅助执行者 | - | ✗ | ✓ | + +--- + +## 三种方案对比 + +### 方案 A:纯 Clawith(Agent 自组织) + +``` +Agent A 收到任务 → 自己判断需要谁 → 调用 Agent B/C → 汇总结果 +``` + +**优点:** 灵活,动态调整,不需要预定义流程,适合探索性任务 + +**缺点:** +- 不可控:不知道会调用谁,顺序不确定 +- 无审计:事后难以追溯 +- 无治理:无法限制跨部门调用权限 +- 容易死循环:A 调 B,B 调 C,C 又调 A +- 无 SLA:不知道整个流程要多久 + +### 方案 B:Clawith Skill 预定义流程 + +```yaml +# 定义一个 skill:客户工单处理流程 +skill: handle_customer_ticket +steps: + 1. 调用 LLM 分类工单 + 2. if 技术问题 → call_agent("tech_support", task) + 3. if 账单问题 → call_agent("finance", task) + 4. 等待结果 → 汇总 → 回复客户 +``` + +**优点:** 能跑,稳定,适合简单固定流程 + +**缺点:** +- 业务人员改不了流程 — 每次改都要找开发者改代码 +- 看不到全局 — skill 里嵌套调用了 5 个 Agent,哪个卡住了看不到 +- 状态管理是噩梦 — Agent A 调了 B 和 C,B 完成了 C 超时了怎么办? +- 规模不 scale — 3 个 Agent 协作用 skill 可以,15 个呢? + +### 方案 C:Clawith + Bisheng Workflow + +``` +Bisheng 定义流程 → 节点分发到 Clawith → Agent 执行 → 结果回传 → 下一节点 +``` + +**Bisheng 是「导演」**,定义谁做什么、顺序、条件 +**Clawith 是「演员」**,执行具体任务 + +**优点:** 可预测、可审计、可治理、可复用、可优化 + +**缺点:** 需要预先设计流程,灵活性略低 + +### 场景对比 + +| 场景 | 适合方案 | +|---|---| +| 探索性任务(「帮我调研竞品」) | 纯 Clawith,Agent 自行决定 | +| 客户工单处理,保证每个环节走到 | Clawith + Bisheng Workflow | +| 跨部门审批、多人协作 | Clawith + Bisheng Workflow | + +--- + +## Clawith Skill vs Bisheng Workflow + +用 Skill 预定义协作路径能跑,但有天花板: + +| 维度 | Clawith Skill | Bisheng Workflow | +|---|---|---| +| 流程定义 | 代码/prompt 写死 | 可视化拖拽 | +| 修改流程 | 改代码,重新部署 | UI 上拖线 | +| 谁能改 | 开发者 | 业务人员 | +| 并行分支 | 需要自己写 async | 原生支持 fan-out/fan-in | +| 条件路由 | if/else 硬编码 | 条件节点,可视化配置 | +| 执行状态 | 自己记日志 | 全局状态机,可暂停/恢复 | +| 错误处理 | try/catch | 节点级重试、超时、fallback | +| 监控 | 无 | 每个节点执行耗时、成功率 | +| 复用 | 复制粘贴 skill | 流程模板,一键复用 | + +### 适用边界 + +``` +任务复杂度 + │ +高 │ ▓▓▓▓ Bisheng Workflow ▓▓▓▓ 可视化、可管理、可审计 + │ ░░░░░ 灰色地带(两者都行)░░░ +低 │ ████ Clawith Skill ████ 简单、灵活、快速 + └────────────────────────────── +``` + +Skill 永远解决不了「业务人员改不了流程」这个问题。 + +--- + +## 分工边界 + +| | Clawith | Bisheng | +|---|---|---| +| 使用者 | 个人 | 团队/业务 | +| 设计者 | 个人配置 | 业务/研发设计 | +| 适合场景 | 轻量、灵活、快速 | 标准化、可管理、可审计 | +| 核心价值 | 个人效率 | 流程协作 | + +--- + +## 核心价值 + +Bisheng 的真正价值不是「能跑流程」,而是「让非技术人员能设计和管理流程」。 + +Clawith 负责执行,Bisheng 负责编排 —— 两者互补,而非替代。 + +--- + +## 商业价值 + +- **降本**:减少跨部门沟通成本(会议、邮件往返) +- **提效**:异步并行处理,不阻塞流程 +- **可审计**:每个环节有明确责任人和 AI 辅助记录 +- **差异化**:市场上没有「个人 Agent + 组织 Workflow」的产品 + +**Slogan:** 从单机 Agent 到协作网络 diff --git a/src/backend/bisheng/api/services/external_workflow.py b/src/backend/bisheng/api/services/external_workflow.py new file mode 100644 index 0000000000..fbaf1ffc98 --- /dev/null +++ b/src/backend/bisheng/api/services/external_workflow.py @@ -0,0 +1,1433 @@ +import asyncio +import copy +import time +from typing import Optional + +from fastapi import Request +from sqlmodel import select + +from bisheng.api.services.flow import FlowService +from bisheng.api.services.workflow import WorkFlowService +from bisheng.common.dependencies.user_deps import UserPayload +from bisheng.common.errcode.flow import (NotFoundVersionError, WorkflowNameExistsError, + WorkFlowInitError, WorkFlowOnlineEditError, WorkFlowVersionUpdateError) +from bisheng.common.errcode.http_error import NotFoundError, UnAuthorizedError +from bisheng.core.database import get_sync_db_session +from bisheng.database.models.flow import Flow, FlowDao, FlowStatus, FlowType +from bisheng.database.models.flow_version import FlowVersion, FlowVersionDao +from bisheng.database.models.role_access import AccessType +from bisheng.utils import generate_uuid +from bisheng.workflow.common.node import BaseNodeData, NodeType +from bisheng.workflow.edges.edges import EdgeBase +from bisheng.workflow.graph.workflow import Workflow +from bisheng.workflow.authoring.editor_compat import normalize_workflow_editor_graph +from bisheng.workflow.authoring.registry import create_graph_node_payload +from bisheng.workflow.nodes.condition.conidition_case import ConditionCases + + +class ExternalWorkflowService: + _DRAFT_META_KEY = '_external_workflow_meta' + _DRAFT_SOURCE = 'clawith_mcp' + _MAX_EXTERNAL_DRAFT_SCAN = 20 + _WORKFLOW_VALIDATION_MAX_STEPS = 10 + _WORKFLOW_VALIDATION_TIMEOUT_SECONDS = 10 + _CONDITION_NODE_TYPE = 'condition' + _CONDITION_PARAM_KEY = 'condition' + _CONDITION_FALLBACK_HANDLE = 'right_handle' + _DEFAULT_SOURCE_HANDLE = 'right_handle' + _DEFAULT_TARGET_HANDLE = 'left_handle' + _DEFAULT_HORIZONTAL_NODE_GAP = 320 + _NOTE_NODE_TYPE = NodeType.NOTE.value + _EDITOR_FLOW_NODE_TYPE = 'flowNode' + _EDITOR_NOTE_NODE_TYPE = 'noteNode' + _START_NODE_TYPE = NodeType.START.value + _END_NODE_TYPE = NodeType.END.value + _SENSITIVE_KEY_PATTERNS = ( + 'password', + 'token', + 'secret', + 'api_key', + 'apikey', + 'credential', + 'auth', + 'cookie', + ) + _BLOCKED_FIELD_TYPES = {'file', 'password'} + + @staticmethod + def _internal_request() -> Request: + scope = { + 'type': 'http', + 'method': 'POST', + 'path': '/mcp/workflow', + 'headers': [], + 'query_string': b'', + 'client': ('127.0.0.1', 0), + 'server': ('127.0.0.1', 80), + 'scheme': 'http', + 'root_path': '', + 'http_version': '1.1', + } + return Request(scope) + + @staticmethod + def _next_version_name() -> str: + return f'draft-{int(time.time() * 1000)}' + + @classmethod + def _raise_workflow_error(cls, message: str): + raise WorkFlowInitError(msg=message) + + @classmethod + def _assert_workflow_name_available(cls, + login_user: UserPayload, + name: str, + exclude_flow_id: Optional[str] = None): + with get_sync_db_session() as session: + statement = select(Flow).where( + Flow.name == name, + Flow.flow_type == FlowType.WORKFLOW.value, + Flow.user_id == login_user.user_id, + ) + exists = session.exec(statement).first() + if exists and exists.id != exclude_flow_id: + raise WorkflowNameExistsError() + + @classmethod + def _mark_graph_as_draft(cls, graph_data: dict, *, in_place: bool = False) -> dict: + updated_graph = graph_data if in_place else copy.deepcopy(graph_data) + current_meta = updated_graph.get(cls._DRAFT_META_KEY, {}) + if not isinstance(current_meta, dict): + current_meta = {} + updated_graph[cls._DRAFT_META_KEY] = { + 'draft': True, + 'source': cls._DRAFT_SOURCE, + 'revision': int(current_meta.get('revision', 0)) + 1, + 'updated_at': int(time.time() * 1000), + } + return updated_graph + + @classmethod + def _clear_graph_draft_marker(cls, graph_data: dict, *, in_place: bool = False) -> dict: + updated_graph = graph_data if in_place else copy.deepcopy(graph_data) + updated_graph.pop(cls._DRAFT_META_KEY, None) + return updated_graph + + @classmethod + def _is_draft_graph(cls, graph_data: Optional[dict]) -> bool: + if not isinstance(graph_data, dict): + return False + meta = graph_data.get(cls._DRAFT_META_KEY, {}) + return isinstance(meta, dict) and meta.get('draft') is True + + @classmethod + def get_graph_revision(cls, graph_data: Optional[dict]) -> int: + if not isinstance(graph_data, dict): + return 0 + meta = graph_data.get(cls._DRAFT_META_KEY, {}) + if not isinstance(meta, dict): + return 0 + try: + return int(meta.get('revision', 0) or 0) + except Exception: + return 0 + + @classmethod + def _assert_expected_revision(cls, graph_data: Optional[dict], expected_revision: Optional[int]): + if expected_revision is None: + return + current_revision = cls.get_graph_revision(graph_data) + if current_revision != expected_revision: + raise WorkFlowVersionUpdateError( + msg=f'Workflow draft revision mismatch, expected {expected_revision}, got {current_revision}' + ) + + @classmethod + def _validate_graph_structure(cls, graph_data: dict): + if not isinstance(graph_data, dict): + cls._raise_workflow_error('Workflow graph must be a JSON object') + + nodes = graph_data.get('nodes') + edges = graph_data.get('edges') + if not isinstance(nodes, list): + cls._raise_workflow_error('Workflow graph must contain a nodes array') + if not isinstance(edges, list): + cls._raise_workflow_error('Workflow graph must contain an edges array') + if not nodes: + cls._raise_workflow_error('Workflow graph must contain at least one node') + + node_ids: set[str] = set() + for index, raw_node in enumerate(nodes): + if not isinstance(raw_node, dict): + cls._raise_workflow_error(f'Node at index {index} must be an object') + node_data = raw_node.get('data') + if not isinstance(node_data, dict): + cls._raise_workflow_error(f'Node {raw_node.get("id", index)} must contain a data object') + try: + parsed_node = BaseNodeData(**node_data) + except Exception as exc: + cls._raise_workflow_error(f'Invalid node structure for {raw_node.get("id", index)}: {exc}') + if raw_node.get('id') != parsed_node.id: + cls._raise_workflow_error(f'Node id mismatch for {raw_node.get("id", index)}') + if parsed_node.id in node_ids: + cls._raise_workflow_error(f'Duplicate node id: {parsed_node.id}') + node_ids.add(parsed_node.id) + cls._validate_node_group_params(parsed_node.id, node_data.get('group_params')) + + for index, raw_edge in enumerate(edges): + if not isinstance(raw_edge, dict): + cls._raise_workflow_error(f'Edge at index {index} must be an object') + try: + parsed_edge = EdgeBase(**raw_edge) + except Exception as exc: + cls._raise_workflow_error(f'Invalid edge structure at index {index}: {exc}') + if parsed_edge.source not in node_ids: + cls._raise_workflow_error(f'Edge source not found: {parsed_edge.source}') + if parsed_edge.target not in node_ids: + cls._raise_workflow_error(f'Edge target not found: {parsed_edge.target}') + + @classmethod + def _validate_node_group_params(cls, node_id: str, group_params): + if group_params is None: + return + if not isinstance(group_params, list): + cls._raise_workflow_error(f'Node {node_id} group_params must be a list') + seen_keys: set[str] = set() + for group_index, group in enumerate(group_params): + if not isinstance(group, dict): + cls._raise_workflow_error(f'Node {node_id} group {group_index} must be an object') + params = group.get('params', []) + if not isinstance(params, list): + cls._raise_workflow_error(f'Node {node_id} group {group_index} params must be a list') + for param_index, param in enumerate(params): + if not isinstance(param, dict): + cls._raise_workflow_error( + f'Node {node_id} param at group {group_index} index {param_index} must be an object' + ) + key = param.get('key') + if not key or not isinstance(key, str): + cls._raise_workflow_error(f'Node {node_id} has a param without a valid key') + if key in seen_keys: + cls._raise_workflow_error(f'Node {node_id} has duplicate param key: {key}') + seen_keys.add(key) + + @classmethod + def _validate_condition_node_routes(cls, graph_data: dict, node_id: str, condition_cases: list[dict]): + valid_handles = {one['id'] for one in condition_cases} + valid_handles.add(cls._CONDITION_FALLBACK_HANDLE) + outgoing_handles = { + (edge.get('sourceHandle') or '') + for edge in graph_data.get('edges', []) + if edge.get('source') == node_id + } + stale_handles = sorted(handle for handle in outgoing_handles if handle not in valid_handles) + if stale_handles: + cls._raise_workflow_error( + f'Condition node {node_id} has stale route handles: {", ".join(stale_handles)}' + ) + + missing_handles = [one['id'] for one in condition_cases if one['id'] not in outgoing_handles] + if missing_handles: + cls._raise_workflow_error( + f'Condition node {node_id} is missing route edges for: {", ".join(missing_handles)}' + ) + + if cls._CONDITION_FALLBACK_HANDLE not in outgoing_handles: + cls._raise_workflow_error( + f'Condition node {node_id} is missing fallback route edge: {cls._CONDITION_FALLBACK_HANDLE}' + ) + + @classmethod + def _validate_special_node_routes(cls, graph_data: dict): + for raw_node in graph_data.get('nodes', []): + node_id = raw_node.get('id', '') + node_data = raw_node.get('data', {}) + if not isinstance(node_data, dict): + continue + if node_data.get('type') != cls._CONDITION_NODE_TYPE: + continue + condition_field = cls._get_node_param_field(raw_node, cls._CONDITION_PARAM_KEY) + condition_cases = cls._normalize_condition_cases(condition_field.get('value') or []) + cls._validate_condition_node_routes(graph_data, node_id, condition_cases) + + @classmethod + def _validate_workflow_runtime(cls, login_user: UserPayload, graph_data: dict, flow_name: str, flow_id: Optional[str] = None): + try: + Workflow( + workflow_id=flow_id or f'draft_{generate_uuid()}', + workflow_name=flow_name or 'draft-workflow', + user_id=login_user.user_id, + workflow_data=graph_data, + async_mode=False, + max_steps=cls._WORKFLOW_VALIDATION_MAX_STEPS, + timeout=cls._WORKFLOW_VALIDATION_TIMEOUT_SECONDS, + callback=None, + ) + except Exception as exc: + raise WorkFlowInitError(exception=exc, msg=str(exc)) + + @classmethod + def _validate_draft_graph(cls, + login_user: UserPayload, + graph_data: dict, + flow_name: str, + flow_id: Optional[str] = None): + cls._validate_graph_structure(graph_data) + cls._validate_special_node_routes(graph_data) + cls._validate_workflow_runtime(login_user, graph_data, flow_name=flow_name, flow_id=flow_id) + + @classmethod + async def _get_base_version(cls, login_user: UserPayload, flow_id: str, version_id: Optional[int] = None) -> tuple[Flow, FlowVersion]: + flow = await cls._get_workflow_with_write_access(login_user, flow_id) + if version_id is None: + version = cls._get_existing_external_draft_version(flow_id) + if version: + return flow, version + current_version = FlowVersionDao.get_version_by_flow(flow_id) + if not current_version: + raise NotFoundVersionError() + return flow, current_version + version = await cls._get_workflow_version(flow_id, version_id) + return flow, version + + @classmethod + def _get_existing_external_draft_version(cls, flow_id: str) -> Optional[FlowVersion]: + offset = 0 + while True: + with get_sync_db_session() as session: + statement = select(FlowVersion).where( + FlowVersion.flow_id == flow_id, + FlowVersion.is_delete == 0, + ).order_by(FlowVersion.id.desc()).limit(cls._MAX_EXTERNAL_DRAFT_SCAN).offset(offset) + versions = session.exec(statement).all() + for version in versions: + if cls._is_draft_graph(version.data): + return version + if len(versions) < cls._MAX_EXTERNAL_DRAFT_SCAN: + break + offset += cls._MAX_EXTERNAL_DRAFT_SCAN + return None + + @classmethod + def _create_external_draft_version(cls, + login_user: UserPayload, + flow: Flow, + base_version: FlowVersion) -> FlowVersion: + flow_version = FlowVersion( + flow_id=flow.id, + name=cls._next_version_name(), + description=f'Editable draft created by external MCP at {int(time.time())}', + user_id=login_user.user_id, + data=cls._mark_graph_as_draft(base_version.data), + original_version_id=base_version.id, + flow_type=FlowType.WORKFLOW.value, + ) + return FlowVersionDao.create_version(flow_version) + + @classmethod + async def _get_editable_version(cls, + login_user: UserPayload, + flow_id: str, + version_id: Optional[int] = None) -> tuple[Flow, FlowVersion]: + flow, version = await cls._get_base_version(login_user, flow_id, version_id) + if flow.status != FlowStatus.ONLINE.value: + return flow, version + if cls._is_draft_graph(version.data): + return flow, version + existing_draft = cls._get_existing_external_draft_version(flow_id) + if existing_draft: + return flow, existing_draft + return flow, cls._create_external_draft_version(login_user, flow, version) + + @staticmethod + def _find_node(graph_data: dict, node_id: str) -> dict: + for node in graph_data.get('nodes', []): + if node.get('id') == node_id: + return node + raise NotFoundError(msg=f'Workflow node not found: {node_id}') + + @staticmethod + def _find_node_index(graph_data: dict, node_id: str) -> int: + for index, node in enumerate(graph_data.get('nodes', [])): + if node.get('id') == node_id: + return index + raise NotFoundError(msg=f'Workflow node not found: {node_id}') + + @staticmethod + def _find_edge_index(graph_data: dict, edge_id: str) -> int: + for index, edge in enumerate(graph_data.get('edges', [])): + if edge.get('id') == edge_id: + return index + raise NotFoundError(msg=f'Workflow edge not found: {edge_id}') + + @staticmethod + def _get_node_type(node: dict) -> str: + return node.get('data', {}).get('type') or node.get('type', '') + + @classmethod + def _assert_condition_node(cls, node: dict, node_id: str): + if cls._get_node_type(node) != cls._CONDITION_NODE_TYPE: + cls._raise_workflow_error(f'Workflow node {node_id} is not a condition node') + + @staticmethod + def _get_node_template(node: dict) -> dict: + template = node.get('data', {}).get('node', {}).get('template') + if not isinstance(template, dict): + raise NotFoundError(msg='Workflow node template not found') + return template + + @staticmethod + def _get_node_group_params(node: dict) -> list[dict]: + group_params = node.get('data', {}).get('group_params') + if not isinstance(group_params, list): + raise NotFoundError(msg='Workflow node group_params not found') + return group_params + + @staticmethod + def _get_editable_param_keys(template: dict) -> list[str]: + keys = [] + for key, value in template.items(): + if key.startswith('_'): + continue + if isinstance(value, dict): + keys.append(key) + return keys + + @classmethod + def _is_sensitive_param_field(cls, key: str, field: dict) -> bool: + lowered_key = key.lower() + if any(pattern in lowered_key for pattern in cls._SENSITIVE_KEY_PATTERNS): + return True + display_name = str(field.get('display_name') or field.get('label') or '').lower() + if any(pattern in display_name for pattern in cls._SENSITIVE_KEY_PATTERNS): + return True + if field.get('password') is True: + return True + if field.get('type') in cls._BLOCKED_FIELD_TYPES: + return True + return False + + @classmethod + def _is_editable_param_field(cls, key: str, field: dict) -> bool: + if field.get('show') is False: + return False + if cls._is_sensitive_param_field(key, field): + return False + return True + + @classmethod + def _iter_node_param_fields(cls, node: dict): + group_params = node.get('data', {}).get('group_params') + if isinstance(group_params, list): + for group in group_params: + group_name = group.get('name', '') + for field in group.get('params', []) or []: + if isinstance(field, dict) and field.get('key'): + yield group_name, field + return + + template = node.get('data', {}).get('node', {}).get('template') + if isinstance(template, dict): + for key, field in template.items(): + if key.startswith('_') or not isinstance(field, dict): + continue + normalized_field = copy.deepcopy(field) + normalized_field.setdefault('key', key) + normalized_field.setdefault('label', field.get('display_name', key)) + yield '', normalized_field + + @classmethod + def _get_node_param_field(cls, node: dict, key: str) -> dict: + for _group_name, field in cls._iter_node_param_fields(node): + if field.get('key') == key: + return field + raise NotFoundError(msg=f'Workflow node param not found: {key}') + + @classmethod + def _get_node_param_fields(cls, node: dict) -> dict[str, tuple[str, dict]]: + fields = {} + for group_name, field in cls._iter_node_param_fields(node): + if cls._is_editable_param_field(field['key'], field): + fields[field['key']] = (group_name, field) + if not fields: + raise NotFoundError(msg='Workflow node params not found') + return fields + + @classmethod + def _get_editable_node_param_keys(cls, node: dict) -> list[str]: + return list(cls._get_node_param_fields(node).keys()) + + @classmethod + def _coerce_bool(cls, value): + if isinstance(value, bool): + return value + if isinstance(value, str): + lowered = value.strip().lower() + if lowered in {'true', '1', 'yes', 'on'}: + return True + if lowered in {'false', '0', 'no', 'off'}: + return False + cls._raise_workflow_error(f'Expected boolean value, got {value!r}') + + @classmethod + def _coerce_param_value(cls, field: dict, value): + current_value = field.get('value') + field_type = field.get('type') + + if field.get('required') and value in (None, '', []): + cls._raise_workflow_error(f'Param {field.get("key")} is required') + + if isinstance(current_value, bool): + normalized = cls._coerce_bool(value) + elif isinstance(current_value, int) and not isinstance(current_value, bool): + try: + normalized = int(value) + except Exception: + cls._raise_workflow_error(f'Param {field.get("key")} expects an integer') + elif isinstance(current_value, float): + try: + normalized = float(value) + except Exception: + cls._raise_workflow_error(f'Param {field.get("key")} expects a float') + elif isinstance(current_value, list): + if not isinstance(value, list): + cls._raise_workflow_error(f'Param {field.get("key")} expects a list') + normalized = value + elif isinstance(current_value, dict): + if not isinstance(value, dict): + cls._raise_workflow_error(f'Param {field.get("key")} expects an object') + normalized = value + elif field_type in {'slide', 'float'}: + try: + normalized = float(value) + except Exception: + cls._raise_workflow_error(f'Param {field.get("key")} expects a float') + elif field_type in {'switch', 'bool'}: + normalized = cls._coerce_bool(value) + elif field_type in {'bisheng_model'}: + try: + normalized = int(value) + except Exception: + cls._raise_workflow_error(f'Param {field.get("key")} expects a model id integer') + elif field_type in {'var', 'image_prompt', 'user_question'}: + if not isinstance(value, list): + cls._raise_workflow_error(f'Param {field.get("key")} expects a list') + normalized = value + else: + if not isinstance(value, str): + cls._raise_workflow_error(f'Param {field.get("key")} expects a string') + normalized = value + + scope = field.get('scope') + if isinstance(scope, list) and len(scope) == 2 and isinstance(normalized, (int, float)): + if normalized < scope[0] or normalized > scope[1]: + cls._raise_workflow_error( + f'Param {field.get("key")} must be within scope [{scope[0]}, {scope[1]}]' + ) + + options = field.get('options') + if isinstance(options, list) and options: + allowed_values = set() + for option in options: + if isinstance(option, dict): + if 'key' in option: + allowed_values.add(option['key']) + if 'value' in option: + allowed_values.add(option['value']) + else: + allowed_values.add(option) + if allowed_values and normalized not in allowed_values: + cls._raise_workflow_error( + f'Param {field.get("key")} must be one of {sorted(allowed_values)}' + ) + return normalized + + @staticmethod + def _build_param_payload(group_name: str, value: dict) -> dict: + key = value['key'] + return { + 'display_name': value.get('display_name') or value.get('label') or value.get('name') or key, + 'group_name': group_name, + 'type': value.get('type'), + 'required': value.get('required', False), + 'show': value.get('show', True), + 'options': value.get('options'), + 'scope': value.get('scope'), + 'placeholder': value.get('placeholder'), + 'refresh': value.get('refresh', False), + 'value': value.get('value'), + } + + @classmethod + def _next_graph_node_id(cls, node_type: str) -> str: + return f'{node_type}_{generate_uuid()[:8]}' + + @classmethod + def _next_graph_edge_id(cls) -> str: + return f'edge_{generate_uuid()[:8]}' + + @staticmethod + def _node_position_value(node: dict, axis: str) -> float: + position = node.get('position', {}) + if not isinstance(position, dict): + return 0.0 + value = position.get(axis, 0) + try: + return float(value) + except Exception: + return 0.0 + + @classmethod + def _node_type_matches(cls, node: dict, node_type: str) -> bool: + return cls._get_node_type(node) == node_type + + @classmethod + def _find_nodes_by_type(cls, graph_data: dict, node_type: str) -> list[dict]: + return [node for node in graph_data.get('nodes', []) if cls._node_type_matches(node, node_type)] + + @classmethod + def _build_graph_edge_payload(cls, + source_node_id: str, + target_node_id: str, + *, + source_type: str, + target_type: str, + source_handle: str = '', + target_handle: str = '', + edge_id: Optional[str] = None) -> dict: + return { + 'id': edge_id or cls._next_graph_edge_id(), + 'source': source_node_id, + 'sourceHandle': source_handle or cls._DEFAULT_SOURCE_HANDLE, + 'sourceType': source_type, + 'target': target_node_id, + 'targetHandle': target_handle or cls._DEFAULT_TARGET_HANDLE, + 'targetType': target_type, + } + + @classmethod + def _append_graph_edge_if_missing(cls, + graph_data: dict, + *, + source_node: dict, + target_node: dict, + source_handle: str = '', + target_handle: str = '') -> bool: + resolved_source_handle = source_handle or cls._DEFAULT_SOURCE_HANDLE + resolved_target_handle = target_handle or cls._DEFAULT_TARGET_HANDLE + for edge in graph_data.get('edges', []): + if ( + edge.get('source') == source_node.get('id') and + edge.get('target') == target_node.get('id') and + edge.get('sourceHandle') == resolved_source_handle and + edge.get('targetHandle') == resolved_target_handle + ): + return False + graph_data.setdefault('edges', []).append( + cls._build_graph_edge_payload( + source_node.get('id', ''), + target_node.get('id', ''), + source_type=cls._get_node_type(source_node), + target_type=cls._get_node_type(target_node), + source_handle=resolved_source_handle, + target_handle=resolved_target_handle, + ) + ) + return True + + @classmethod + def _get_root_nodes(cls, graph_data: dict) -> list[dict]: + incoming_counts = { + node.get('id', ''): 0 + for node in graph_data.get('nodes', []) + } + for edge in graph_data.get('edges', []): + target_id = edge.get('target') + if target_id in incoming_counts: + incoming_counts[target_id] += 1 + return [ + node for node in graph_data.get('nodes', []) + if cls._get_node_type(node) not in {cls._NOTE_NODE_TYPE, cls._START_NODE_TYPE} + and incoming_counts.get(node.get('id', ''), 0) == 0 + ] + + @classmethod + def _get_terminal_nodes(cls, graph_data: dict) -> list[dict]: + outgoing_counts = { + node.get('id', ''): 0 + for node in graph_data.get('nodes', []) + } + for edge in graph_data.get('edges', []): + source_id = edge.get('source') + if source_id in outgoing_counts: + outgoing_counts[source_id] += 1 + return [ + node for node in graph_data.get('nodes', []) + if cls._get_node_type(node) not in {cls._NOTE_NODE_TYPE, cls._END_NODE_TYPE} + and outgoing_counts.get(node.get('id', ''), 0) == 0 + ] + + @classmethod + def _build_scaffold_node(cls, + node_type: str, + *, + node_id: Optional[str] = None, + position_x: float = 0, + position_y: float = 0) -> dict: + resolved_node_id = node_id or cls._next_graph_node_id(node_type) + node_payload = create_graph_node_payload( + node_type, + node_id=resolved_node_id, + position_x=position_x, + position_y=position_y, + ) + if node_payload is None: + raise NotFoundError(msg=f'Workflow node template not found: {node_type}') + return node_payload + + @classmethod + def _normalize_editor_node_types(cls, graph_data: dict) -> dict: + return normalize_workflow_editor_graph(graph_data, in_place=True) + + @classmethod + def _ensure_create_graph_scaffold(cls, graph_data: dict) -> dict: + if not isinstance(graph_data, dict): + return graph_data + nodes = graph_data.get('nodes') + edges = graph_data.get('edges') + if not isinstance(nodes, list) or not isinstance(edges, list): + return graph_data + + updated_graph = copy.deepcopy(graph_data) + cls._normalize_editor_node_types(updated_graph) + start_nodes = cls._find_nodes_by_type(updated_graph, cls._START_NODE_TYPE) + end_nodes = cls._find_nodes_by_type(updated_graph, cls._END_NODE_TYPE) + + if not updated_graph['nodes']: + start_node = cls._build_scaffold_node(cls._START_NODE_TYPE, position_x=0, position_y=0) + end_node = cls._build_scaffold_node( + cls._END_NODE_TYPE, + position_x=cls._DEFAULT_HORIZONTAL_NODE_GAP, + position_y=0, + ) + updated_graph['nodes'].extend([start_node, end_node]) + cls._append_graph_edge_if_missing(updated_graph, source_node=start_node, target_node=end_node) + return updated_graph + + if not start_nodes: + root_nodes = cls._get_root_nodes(updated_graph) + root_x_positions = [cls._node_position_value(node, 'x') for node in root_nodes] + root_y_positions = [cls._node_position_value(node, 'y') for node in root_nodes] + start_node = cls._build_scaffold_node( + cls._START_NODE_TYPE, + position_x=(min(root_x_positions) - cls._DEFAULT_HORIZONTAL_NODE_GAP) if root_x_positions else 0, + position_y=(sum(root_y_positions) / len(root_y_positions)) if root_y_positions else 0, + ) + updated_graph['nodes'].append(start_node) + start_nodes = [start_node] + for root_node in root_nodes: + cls._append_graph_edge_if_missing(updated_graph, source_node=start_node, target_node=root_node) + + if not end_nodes: + terminal_nodes = cls._get_terminal_nodes(updated_graph) + terminal_x_positions = [cls._node_position_value(node, 'x') for node in terminal_nodes] + terminal_y_positions = [cls._node_position_value(node, 'y') for node in terminal_nodes] + end_node = cls._build_scaffold_node( + cls._END_NODE_TYPE, + position_x=(max(terminal_x_positions) + cls._DEFAULT_HORIZONTAL_NODE_GAP) if terminal_x_positions else cls._DEFAULT_HORIZONTAL_NODE_GAP, + position_y=(sum(terminal_y_positions) / len(terminal_y_positions)) if terminal_y_positions else 0, + ) + updated_graph['nodes'].append(end_node) + end_nodes = [end_node] + if terminal_nodes: + for terminal_node in terminal_nodes: + terminal_type = cls._get_node_type(terminal_node) + if terminal_type == cls._CONDITION_NODE_TYPE: + for condition_case in cls._get_condition_cases(terminal_node): + cls._append_graph_edge_if_missing( + updated_graph, + source_node=terminal_node, + target_node=end_node, + source_handle=condition_case['id'], + ) + cls._append_graph_edge_if_missing( + updated_graph, + source_node=terminal_node, + target_node=end_node, + source_handle=cls._CONDITION_FALLBACK_HANDLE, + ) + else: + cls._append_graph_edge_if_missing(updated_graph, source_node=terminal_node, target_node=end_node) + + return updated_graph + + @classmethod + def _extract_descriptor_param_updates(cls, node_descriptor: dict) -> dict: + params = node_descriptor.get('params') + if not isinstance(params, dict): + return {} + updates = {} + for key, value in params.items(): + if isinstance(value, dict) and 'value' in value: + updates[key] = copy.deepcopy(value.get('value')) + else: + updates[key] = copy.deepcopy(value) + return updates + + @classmethod + def _resolve_descriptor_position(cls, + node_descriptor: dict, + existing_node: Optional[dict], + *, + axis: str, + fallback: float) -> float: + position = node_descriptor.get('position') + if isinstance(position, dict) and axis in position: + try: + return float(position.get(axis)) + except Exception: + cls._raise_workflow_error( + f'Workflow node {node_descriptor.get("id") or "unknown"} has an invalid {axis} position' + ) + if existing_node is not None: + return cls._node_position_value(existing_node, axis) + return fallback + + @classmethod + def _build_graph_node_from_descriptor(cls, + node_descriptor: dict, + *, + existing_node: Optional[dict] = None, + node_index: int = 0) -> dict: + node_id = node_descriptor.get('id') or (existing_node or {}).get('id') + if not node_id: + cls._raise_workflow_error('Workflow node descriptor must contain an id') + + node_type = node_descriptor.get('type') or (cls._get_node_type(existing_node) if existing_node else '') + if not node_type: + cls._raise_workflow_error(f'Workflow node descriptor {node_id} must contain a type') + + name = node_descriptor.get('name') + description = node_descriptor.get('description') + tab = node_descriptor.get('tab') + position_x = cls._resolve_descriptor_position( + node_descriptor, + existing_node, + axis='x', + fallback=float(node_index * cls._DEFAULT_HORIZONTAL_NODE_GAP), + ) + position_y = cls._resolve_descriptor_position( + node_descriptor, + existing_node, + axis='y', + fallback=0.0, + ) + + if existing_node is not None and cls._get_node_type(existing_node) == node_type: + node_payload = copy.deepcopy(existing_node) + else: + node_payload = create_graph_node_payload( + node_type, + node_id=node_id, + name=name or '', + position_x=position_x, + position_y=position_y, + ) + if node_payload is None: + raise NotFoundError(msg=f'Workflow node template not found: {node_type}') + + node_payload['id'] = node_id + node_payload['position'] = {'x': position_x, 'y': position_y} + node_data = node_payload.setdefault('data', {}) + node_data['id'] = node_id + node_data['type'] = node_type + if name is not None: + node_data['name'] = name + else: + node_data.setdefault('name', '') + if description is not None: + node_data['description'] = description + if tab is not None: + node_data['tab'] = copy.deepcopy(tab) + + param_updates = cls._extract_descriptor_param_updates(node_descriptor) + if param_updates: + cls._patch_node_fields(node_payload, param_updates) + + return node_payload + + @classmethod + def _coerce_editor_graph_input(cls, + graph_data: dict, + *, + base_graph: Optional[dict] = None) -> dict: + if not isinstance(graph_data, dict): + return graph_data + + nodes = graph_data.get('nodes') + if not isinstance(nodes, list): + return graph_data + + if all(not isinstance(node, dict) or isinstance(node.get('data'), dict) for node in nodes): + return cls._normalize_editor_node_types(copy.deepcopy(graph_data)) + + base_nodes_by_id = {} + if isinstance(base_graph, dict): + for existing_node in base_graph.get('nodes', []): + if isinstance(existing_node, dict) and existing_node.get('id'): + base_nodes_by_id[existing_node['id']] = existing_node + + updated_graph = copy.deepcopy(graph_data) + rebuilt_nodes = [] + for index, raw_node in enumerate(nodes): + if not isinstance(raw_node, dict): + rebuilt_nodes.append(raw_node) + continue + if isinstance(raw_node.get('data'), dict): + rebuilt_nodes.append(copy.deepcopy(raw_node)) + continue + rebuilt_nodes.append( + cls._build_graph_node_from_descriptor( + raw_node, + existing_node=base_nodes_by_id.get(raw_node.get('id')), + node_index=index, + ) + ) + updated_graph['nodes'] = rebuilt_nodes + return cls._normalize_editor_node_types(updated_graph) + + @classmethod + def _apply_node_initial_params(cls, node_payload: dict, initial_params: Optional[dict]) -> dict: + if not initial_params: + return node_payload + updated_node = copy.deepcopy(node_payload) + cls._patch_node_fields(updated_node, initial_params) + return updated_node + + @classmethod + def _patch_node_fields(cls, node: dict, updates: dict): + if not isinstance(updates, dict) or not updates: + cls._raise_workflow_error('Workflow node updates must be a non-empty JSON object') + fields = cls._get_node_param_fields(node) + missing_keys = [] + for key, value in updates.items(): + field_meta = fields.get(key) + if field_meta is None: + missing_keys.append(key) + continue + _group_name, field = field_meta + field['value'] = cls._coerce_param_value(field, value) + if missing_keys: + available = ', '.join(fields.keys()) + raise NotFoundError(msg=f'Node params not found: {missing_keys}. Available params: {available}') + + @classmethod + def _add_node_to_graph(cls, + graph_data: dict, + node_type: str, + *, + name: str = '', + position_x: float = 0, + position_y: float = 0, + node_id: Optional[str] = None, + initial_params: Optional[dict] = None) -> tuple[dict, str]: + updated_graph = copy.deepcopy(graph_data) + resolved_node_id = node_id or cls._next_graph_node_id(node_type) + if any(node.get('id') == resolved_node_id for node in updated_graph.get('nodes', [])): + cls._raise_workflow_error(f'Duplicate node id: {resolved_node_id}') + + node_payload = create_graph_node_payload( + node_type, + node_id=resolved_node_id, + name=name, + position_x=position_x, + position_y=position_y, + ) + if node_payload is None: + raise NotFoundError(msg=f'Workflow node template not found: {node_type}') + node_payload = cls._apply_node_initial_params(node_payload, initial_params) + updated_graph.setdefault('nodes', []).append(node_payload) + return updated_graph, resolved_node_id + + @classmethod + def _remove_node_from_graph(cls, graph_data: dict, node_id: str, *, cascade: bool = True) -> dict: + updated_graph = copy.deepcopy(graph_data) + node_index = cls._find_node_index(updated_graph, node_id) + related_edges = [ + edge for edge in updated_graph.get('edges', []) + if edge.get('source') == node_id or edge.get('target') == node_id + ] + if related_edges and not cascade: + cls._raise_workflow_error(f'Node {node_id} still has connected edges') + updated_graph['nodes'].pop(node_index) + if related_edges: + updated_graph['edges'] = [ + edge for edge in updated_graph.get('edges', []) + if edge.get('source') != node_id and edge.get('target') != node_id + ] + return updated_graph + + @classmethod + def _connect_nodes_in_graph(cls, + graph_data: dict, + *, + source_node_id: str, + target_node_id: str, + source_handle: str, + target_handle: str, + edge_id: Optional[str] = None) -> tuple[dict, str]: + updated_graph = copy.deepcopy(graph_data) + source_node = cls._find_node(updated_graph, source_node_id) + target_node = cls._find_node(updated_graph, target_node_id) + if source_node_id == target_node_id: + cls._raise_workflow_error('Workflow edge cannot connect a node to itself') + if not source_handle: + cls._raise_workflow_error('Workflow edge source_handle is required') + if not target_handle: + cls._raise_workflow_error('Workflow edge target_handle is required') + + for edge in updated_graph.get('edges', []): + if ( + edge.get('source') == source_node_id and + edge.get('target') == target_node_id and + edge.get('sourceHandle') == source_handle and + edge.get('targetHandle') == target_handle + ): + cls._raise_workflow_error('Duplicate workflow edge') + + resolved_edge_id = edge_id or cls._next_graph_edge_id() + updated_graph.setdefault('edges', []).append({ + 'id': resolved_edge_id, + 'source': source_node_id, + 'sourceHandle': source_handle, + 'sourceType': source_node.get('data', {}).get('type') or source_node.get('type', ''), + 'target': target_node_id, + 'targetHandle': target_handle, + 'targetType': target_node.get('data', {}).get('type') or target_node.get('type', ''), + }) + return updated_graph, resolved_edge_id + + @classmethod + def _disconnect_edge_from_graph(cls, + graph_data: dict, + *, + edge_id: Optional[str] = None, + source_node_id: str = '', + target_node_id: str = '', + source_handle: str = '', + target_handle: str = '') -> tuple[dict, str]: + updated_graph = copy.deepcopy(graph_data) + removed_edge_id = edge_id or '' + if edge_id: + edge_index = cls._find_edge_index(updated_graph, edge_id) + updated_graph['edges'].pop(edge_index) + return updated_graph, edge_id + + matches = [] + for index, edge in enumerate(updated_graph.get('edges', [])): + if source_node_id and edge.get('source') != source_node_id: + continue + if target_node_id and edge.get('target') != target_node_id: + continue + if source_handle and edge.get('sourceHandle') != source_handle: + continue + if target_handle and edge.get('targetHandle') != target_handle: + continue + matches.append((index, edge)) + + if not matches: + raise NotFoundError(msg='Workflow edge not found') + if len(matches) > 1: + cls._raise_workflow_error('Workflow edge selector is ambiguous') + + edge_index, edge = matches[0] + removed_edge_id = edge.get('id', '') + updated_graph['edges'].pop(edge_index) + return updated_graph, removed_edge_id + + @classmethod + def _patch_node_template(cls, graph_data: dict, node_id: str, updates: dict) -> dict: + updated_graph = copy.deepcopy(graph_data) + node = cls._find_node(updated_graph, node_id) + cls._patch_node_fields(node, updates) + return updated_graph + + @classmethod + def _normalize_condition_cases(cls, condition_cases: list[dict]) -> list[dict]: + if not isinstance(condition_cases, list): + cls._raise_workflow_error('Condition node cases must be a list') + normalized_cases = [] + seen_case_ids: set[str] = set() + for index, raw_case in enumerate(condition_cases): + if not isinstance(raw_case, dict): + cls._raise_workflow_error(f'Condition case at index {index} must be an object') + try: + normalized_case = ConditionCases(**raw_case).model_dump() + except Exception as exc: + cls._raise_workflow_error(f'Invalid condition case at index {index}: {exc}') + for condition in normalized_case.get('conditions') or []: + right_value_type = condition.get('right_value_type') + if right_value_type != 'ref': + condition['right_value_type'] = 'input' + case_id = normalized_case['id'] + if case_id in seen_case_ids: + cls._raise_workflow_error(f'Duplicate condition case id: {case_id}') + seen_case_ids.add(case_id) + normalized_cases.append(normalized_case) + return normalized_cases + + @classmethod + def _get_condition_cases(cls, node: dict) -> list[dict]: + condition_field = cls._get_node_param_field(node, cls._CONDITION_PARAM_KEY) + raw_value = condition_field.get('value') or [] + return cls._normalize_condition_cases(raw_value) + + @classmethod + def _update_condition_cases_in_graph(cls, graph_data: dict, node_id: str, condition_cases: list[dict]) -> dict: + updated_graph = copy.deepcopy(graph_data) + node = cls._find_node(updated_graph, node_id) + cls._assert_condition_node(node, node_id) + condition_field = cls._get_node_param_field(node, cls._CONDITION_PARAM_KEY) + condition_field['value'] = cls._normalize_condition_cases(condition_cases) + return updated_graph + + @classmethod + def _get_condition_outgoing_edges(cls, graph_data: dict, node_id: str) -> dict[str, list[dict[str, str]]]: + routes: dict[str, list[dict[str, str]]] = {} + for edge in graph_data.get('edges', []): + if edge.get('source') != node_id: + continue + source_handle = edge.get('sourceHandle') or '' + routes.setdefault(source_handle, []).append({ + 'edge_id': edge.get('id', ''), + 'target_node_id': edge.get('target', ''), + 'target_handle': edge.get('targetHandle', ''), + }) + return routes + + @classmethod + async def _get_workflow_with_write_access(cls, login_user: UserPayload, flow_id: str) -> Flow: + flow = await FlowDao.aget_flow_by_id(flow_id) + if not flow: + raise NotFoundError() + if flow.flow_type != FlowType.WORKFLOW.value: + raise NotFoundError() + if not await login_user.async_access_check(flow.user_id, flow_id, AccessType.WORKFLOW_WRITE): + raise UnAuthorizedError() + return flow + + @classmethod + async def _get_workflow_version(cls, flow_id: str, version_id: int) -> FlowVersion: + version = await FlowVersionDao.aget_version_by_id(version_id) + if not version or version.flow_id != flow_id: + raise NotFoundVersionError() + return version + + @classmethod + def _create_workflow_draft_sync(cls, + login_user: UserPayload, + name: str, + graph_data: dict, + description: Optional[str] = None, + guide_word: Optional[str] = None) -> tuple[Flow, FlowVersion]: + cls._assert_workflow_name_available(login_user, name) + graph_data = cls._coerce_editor_graph_input(graph_data) + graph_data = cls._ensure_create_graph_scaffold(graph_data) + cls._validate_draft_graph(login_user, graph_data, flow_name=name) + + db_flow = Flow( + name=name, + user_id=login_user.user_id, + description=description, + data=graph_data, + guide_word=guide_word, + flow_type=FlowType.WORKFLOW.value, + status=FlowStatus.OFFLINE.value, + ) + db_flow = FlowDao.create_flow(db_flow, FlowType.WORKFLOW.value) + current_version = FlowVersionDao.get_version_by_flow(db_flow.id) + current_version.data = cls._mark_graph_as_draft(current_version.data) + current_version = FlowVersionDao.update_version(current_version) + FlowService.create_flow_hook( + cls._internal_request(), + login_user, + db_flow, + current_version.id, + FlowType.WORKFLOW.value, + ) + return db_flow, current_version + + @classmethod + async def create_workflow_draft(cls, + login_user: UserPayload, + name: str, + graph_data: dict, + description: Optional[str] = None, + guide_word: Optional[str] = None) -> tuple[Flow, FlowVersion]: + return await asyncio.to_thread( + cls._create_workflow_draft_sync, + login_user, + name, + graph_data, + description, + guide_word, + ) + + @classmethod + async def update_workflow_draft(cls, + login_user: UserPayload, + flow_id: str, + graph_data: dict, + name: Optional[str] = None, + description: Optional[str] = None, + guide_word: Optional[str] = None, + expected_revision: Optional[int] = None) -> tuple[Flow, FlowVersion]: + flow, editable_version = await cls._get_editable_version(login_user, flow_id) + cls._assert_expected_revision(editable_version.data, expected_revision) + has_flow_updates = any(value is not None for value in (name, description, guide_word)) + if has_flow_updates and flow.status == FlowStatus.ONLINE.value: + raise WorkFlowOnlineEditError() + if name is not None and name != flow.name: + cls._assert_workflow_name_available(login_user, name, exclude_flow_id=flow.id) + + graph_data = cls._coerce_editor_graph_input(graph_data, base_graph=editable_version.data) + if has_flow_updates: + if name is not None: + flow.name = name + if description is not None: + flow.description = description + if guide_word is not None: + flow.guide_word = guide_word + flow = await FlowDao.aupdate_flow(flow) + await FlowService.update_flow_hook(cls._internal_request(), login_user, flow) + + cls._validate_draft_graph(login_user, graph_data, flow_name=flow.name, flow_id=flow.id) + editable_version.data = cls._mark_graph_as_draft(graph_data, in_place=True) + editable_version = FlowVersionDao.update_version(editable_version) + return flow, editable_version + + @classmethod + async def list_workflow_nodes(cls, + login_user: UserPayload, + flow_id: str, + version_id: Optional[int] = None) -> dict: + _, version = await cls._get_editable_version(login_user, flow_id, version_id) + nodes = [] + for node in version.data.get('nodes', []): + node_type = node.get('data', {}).get('type') or node.get('type', '') + try: + param_keys = cls._get_editable_node_param_keys(node) + except NotFoundError: + param_keys = [] + nodes.append({ + 'id': node.get('id', ''), + 'type': node_type, + 'name': node.get('data', {}).get('name') or node.get('data', {}).get('node', {}).get('display_name', ''), + 'param_keys': param_keys, + }) + return { + 'flow_id': flow_id, + 'version_id': version.id, + 'draft_revision': cls.get_graph_revision(version.data), + 'nodes': nodes, + } + + @classmethod + async def get_workflow_node_params(cls, + login_user: UserPayload, + flow_id: str, + node_id: str, + version_id: Optional[int] = None) -> dict: + _, version = await cls._get_editable_version(login_user, flow_id, version_id) + node = cls._find_node(version.data, node_id) + params = {} + for key, (group_name, value) in cls._get_node_param_fields(node).items(): + params[key] = cls._build_param_payload(group_name, value) + return { + 'flow_id': flow_id, + 'version_id': version.id, + 'draft_revision': cls.get_graph_revision(version.data), + 'node_id': node_id, + 'node_type': node.get('data', {}).get('type') or node.get('type', ''), + 'node_name': node.get('data', {}).get('name') or node.get('data', {}).get('node', {}).get('display_name', ''), + 'params': params, + } + + @classmethod + async def get_condition_node_config(cls, + login_user: UserPayload, + flow_id: str, + node_id: str, + version_id: Optional[int] = None) -> dict: + _, version = await cls._get_editable_version(login_user, flow_id, version_id) + node = cls._find_node(version.data, node_id) + cls._assert_condition_node(node, node_id) + condition_cases = cls._get_condition_cases(node) + route_handles = [one['id'] for one in condition_cases] + if cls._CONDITION_FALLBACK_HANDLE not in route_handles: + route_handles.append(cls._CONDITION_FALLBACK_HANDLE) + return { + 'flow_id': flow_id, + 'version_id': version.id, + 'draft_revision': cls.get_graph_revision(version.data), + 'node_id': node_id, + 'node_name': node.get('data', {}).get('name') or node.get('data', {}).get('node', {}).get('display_name', ''), + 'condition_cases': condition_cases, + 'route_handles': route_handles, + 'outgoing_edges': cls._get_condition_outgoing_edges(version.data, node_id), + } + + @classmethod + async def update_workflow_node_params(cls, + login_user: UserPayload, + flow_id: str, + node_id: str, + updates: dict, + version_id: Optional[int] = None, + expected_revision: Optional[int] = None) -> tuple[Flow, FlowVersion]: + flow, editable_version = await cls._get_editable_version(login_user, flow_id, version_id) + cls._assert_expected_revision(editable_version.data, expected_revision) + graph_data = cls._patch_node_template(editable_version.data, node_id, updates) + cls._validate_draft_graph(login_user, graph_data, flow_name=flow.name, flow_id=flow.id) + editable_version.data = cls._mark_graph_as_draft(graph_data, in_place=True) + editable_version = FlowVersionDao.update_version(editable_version) + return flow, editable_version + + @classmethod + async def update_condition_node(cls, + login_user: UserPayload, + flow_id: str, + node_id: str, + condition_cases: list[dict], + version_id: Optional[int] = None, + expected_revision: Optional[int] = None) -> tuple[Flow, FlowVersion]: + flow, editable_version = await cls._get_editable_version(login_user, flow_id, version_id) + cls._assert_expected_revision(editable_version.data, expected_revision) + graph_data = cls._update_condition_cases_in_graph(editable_version.data, node_id, condition_cases) + cls._validate_draft_graph(login_user, graph_data, flow_name=flow.name, flow_id=flow.id) + editable_version.data = cls._mark_graph_as_draft(graph_data, in_place=True) + editable_version = FlowVersionDao.update_version(editable_version) + return flow, editable_version + + @classmethod + async def add_workflow_node(cls, + login_user: UserPayload, + flow_id: str, + node_type: str, + name: str = '', + position_x: float = 0, + position_y: float = 0, + initial_params: Optional[dict] = None, + version_id: Optional[int] = None, + expected_revision: Optional[int] = None) -> tuple[Flow, FlowVersion, str]: + flow, editable_version = await cls._get_editable_version(login_user, flow_id, version_id) + cls._assert_expected_revision(editable_version.data, expected_revision) + graph_data, node_id = cls._add_node_to_graph( + editable_version.data, + node_type, + name=name, + position_x=position_x, + position_y=position_y, + initial_params=initial_params, + ) + cls._validate_draft_graph(login_user, graph_data, flow_name=flow.name, flow_id=flow.id) + editable_version.data = cls._mark_graph_as_draft(graph_data, in_place=True) + editable_version = FlowVersionDao.update_version(editable_version) + return flow, editable_version, node_id + + @classmethod + async def remove_workflow_node(cls, + login_user: UserPayload, + flow_id: str, + node_id: str, + *, + cascade: bool = True, + version_id: Optional[int] = None, + expected_revision: Optional[int] = None) -> tuple[Flow, FlowVersion]: + flow, editable_version = await cls._get_editable_version(login_user, flow_id, version_id) + cls._assert_expected_revision(editable_version.data, expected_revision) + graph_data = cls._remove_node_from_graph(editable_version.data, node_id, cascade=cascade) + cls._validate_draft_graph(login_user, graph_data, flow_name=flow.name, flow_id=flow.id) + editable_version.data = cls._mark_graph_as_draft(graph_data, in_place=True) + editable_version = FlowVersionDao.update_version(editable_version) + return flow, editable_version + + @classmethod + async def connect_workflow_nodes(cls, + login_user: UserPayload, + flow_id: str, + source_node_id: str, + target_node_id: str, + source_handle: str, + target_handle: str, + *, + version_id: Optional[int] = None, + expected_revision: Optional[int] = None) -> tuple[Flow, FlowVersion, str]: + flow, editable_version = await cls._get_editable_version(login_user, flow_id, version_id) + cls._assert_expected_revision(editable_version.data, expected_revision) + graph_data, edge_id = cls._connect_nodes_in_graph( + editable_version.data, + source_node_id=source_node_id, + target_node_id=target_node_id, + source_handle=source_handle, + target_handle=target_handle, + ) + cls._validate_draft_graph(login_user, graph_data, flow_name=flow.name, flow_id=flow.id) + editable_version.data = cls._mark_graph_as_draft(graph_data, in_place=True) + editable_version = FlowVersionDao.update_version(editable_version) + return flow, editable_version, edge_id + + @classmethod + async def disconnect_workflow_edge(cls, + login_user: UserPayload, + flow_id: str, + *, + edge_id: str = '', + source_node_id: str = '', + target_node_id: str = '', + source_handle: str = '', + target_handle: str = '', + version_id: Optional[int] = None, + expected_revision: Optional[int] = None) -> tuple[Flow, FlowVersion, str]: + flow, editable_version = await cls._get_editable_version(login_user, flow_id, version_id) + cls._assert_expected_revision(editable_version.data, expected_revision) + graph_data, removed_edge_id = cls._disconnect_edge_from_graph( + editable_version.data, + edge_id=edge_id or None, + source_node_id=source_node_id, + target_node_id=target_node_id, + source_handle=source_handle, + target_handle=target_handle, + ) + cls._validate_draft_graph(login_user, graph_data, flow_name=flow.name, flow_id=flow.id) + editable_version.data = cls._mark_graph_as_draft(graph_data, in_place=True) + editable_version = FlowVersionDao.update_version(editable_version) + return flow, editable_version, removed_edge_id + + @classmethod + async def validate_workflow(cls, login_user: UserPayload, flow_id: str, version_id: int) -> tuple[Flow, FlowVersion]: + flow = await cls._get_workflow_with_write_access(login_user, flow_id) + version = await cls._get_workflow_version(flow_id, version_id) + try: + cls._validate_draft_graph(login_user, version.data, flow_name=flow.name, flow_id=flow_id) + except Exception as exc: + if isinstance(exc, WorkFlowInitError): + raise exc + raise WorkFlowInitError(exception=exc, msg=str(exc)) + return flow, version + + @classmethod + async def publish_workflow(cls, login_user: UserPayload, flow_id: str, version_id: int) -> tuple[Flow, FlowVersion]: + flow, version = await cls.validate_workflow(login_user, flow_id, version_id) + draft_data = copy.deepcopy(version.data) + version.data = cls._clear_graph_draft_marker(version.data) + version = FlowVersionDao.update_version(version) + try: + await WorkFlowService.update_flow_status(login_user, flow_id, version_id, FlowStatus.ONLINE.value) + except Exception: + version.data = draft_data + FlowVersionDao.update_version(version) + raise + flow = await FlowDao.aget_flow_by_id(flow_id) + version = await cls._get_workflow_version(flow_id, version_id) + return flow, version diff --git a/src/backend/bisheng/api/services/flow.py b/src/backend/bisheng/api/services/flow.py index 1108c441a7..8439a62217 100644 --- a/src/backend/bisheng/api/services/flow.py +++ b/src/backend/bisheng/api/services/flow.py @@ -33,10 +33,15 @@ from bisheng.user.domain.models.user import UserDao from bisheng.user.domain.models.user_role import UserRoleDao from bisheng.utils import get_request_ip +from bisheng.workflow.authoring.editor_compat import normalize_workflow_editor_graph class FlowService(BaseService): + @staticmethod + def _normalize_workflow_editor_graph(graph_data: Optional[dict]) -> Optional[dict]: + return normalize_workflow_editor_graph(graph_data) + @classmethod def get_version_list_by_flow(cls, user: UserPayload, flow_id: str) -> UnifiedResponseModel[List[FlowVersionRead]]: """ @@ -233,6 +238,8 @@ async def get_one_flow(cls, login_user: UserPayload, flow_id: str, share_link: U raise UnAuthorizedError() flow_info.logo = await cls.get_logo_share_link_async(flow_info.logo) + if flow_info.flow_type == FlowType.WORKFLOW.value: + flow_info.data = cls._normalize_workflow_editor_graph(flow_info.data) return resp_200(data=flow_info) diff --git a/src/backend/bisheng/api/services/workflow_authoring.py b/src/backend/bisheng/api/services/workflow_authoring.py new file mode 100644 index 0000000000..d4e5124d9f --- /dev/null +++ b/src/backend/bisheng/api/services/workflow_authoring.py @@ -0,0 +1,219 @@ +import asyncio +import re +from typing import Optional + +from bisheng.common.dependencies.user_deps import UserPayload +from bisheng.common.errcode.base import BaseErrorCode +from bisheng.common.errcode.http_error import NotFoundError +from bisheng.database.models.flow import Flow, FlowDao, FlowStatus, FlowType +from bisheng.database.models.flow_version import FlowVersion, FlowVersionDao +from bisheng.database.models.role_access import RoleAccessDao, AccessType +from bisheng.user.domain.models.user_role import UserRoleDao +from bisheng.workflow.authoring.contract import ( + ValidationDiagnostic, + ValidationSeverity, + WorkflowGraphDescriptor, + WorkflowGraphNodeDescriptor, + WorkflowManifest, + WorkflowParamMetadata, + WorkflowVersionSummary, +) +from bisheng.workflow.authoring.registry import get_node_template_descriptor, list_node_type_descriptors +from bisheng.workflow.authoring.registry import normalize_tab_descriptor +from bisheng.api.services.external_workflow import ExternalWorkflowService + + +class WorkflowAuthoringService: + + @staticmethod + def _status_to_name(status: Optional[int]) -> str: + if status == FlowStatus.ONLINE.value: + return 'online' + return 'offline' + + @classmethod + def _editable_version_without_side_effect(cls, flow_id: str) -> tuple[Optional[FlowVersion], Optional[FlowVersion]]: + current_version = FlowVersionDao.get_version_by_flow(flow_id) + editable_version = ExternalWorkflowService._get_existing_external_draft_version(flow_id) or current_version + return current_version, editable_version + + @classmethod + def _build_manifest(cls, flow: Flow) -> WorkflowManifest: + current_version, editable_version = cls._editable_version_without_side_effect(flow.id) + draft_revision = ExternalWorkflowService.get_graph_revision(editable_version.data if editable_version else None) + return WorkflowManifest( + flow_id=flow.id, + name=flow.name, + description=flow.description, + status=cls._status_to_name(flow.status), + current_version_id=current_version.id if current_version else None, + editable_version_id=editable_version.id if editable_version else None, + draft_revision=draft_revision, + ) + + @classmethod + async def _list_candidate_workflows(cls, login_user: UserPayload) -> list[Flow]: + if login_user.is_admin(): + return FlowDao.get_flows( + login_user.user_id, + 'admin', + '', + None, + None, + 0, + 0, + FlowType.WORKFLOW.value, + ) + user_role = UserRoleDao.get_user_roles(login_user.user_id) + role_ids = [role.role_id for role in user_role] + role_access = RoleAccessDao.get_role_access_batch(role_ids, [AccessType.WORKFLOW, AccessType.WORKFLOW_WRITE]) + extra_ids = [access.third_id for access in role_access] if role_access else [] + return FlowDao.get_flows( + login_user.user_id, + extra_ids, + '', + None, + None, + 0, + 0, + FlowType.WORKFLOW.value, + ) + + @classmethod + async def list_workflows(cls, login_user: UserPayload) -> list[WorkflowManifest]: + candidates = await cls._list_candidate_workflows(login_user) + permissions = await asyncio.gather(*[ + login_user.async_access_check(flow.user_id, flow.id, AccessType.WORKFLOW_WRITE) for flow in candidates + ]) + return [cls._build_manifest(flow) for flow, allowed in zip(candidates, permissions) if allowed] + + @classmethod + async def get_workflow(cls, login_user: UserPayload, flow_id: str) -> WorkflowManifest: + flow = await ExternalWorkflowService._get_workflow_with_write_access(login_user, flow_id) + return cls._build_manifest(flow) + + @classmethod + async def get_workflow_versions(cls, login_user: UserPayload, flow_id: str) -> list[WorkflowVersionSummary]: + flow = await ExternalWorkflowService._get_workflow_with_write_access(login_user, flow_id) + current_version, editable_version = cls._editable_version_without_side_effect(flow.id) + versions = FlowVersionDao.get_list_by_flow(flow.id) + detailed_versions = { + version.id: FlowVersionDao.get_version_by_id(version.id) + for version in versions + } + return [ + WorkflowVersionSummary( + version_id=version.id, + name=version.name, + description=version.description, + is_current=version.is_current == 1, + is_editable=editable_version is not None and version.id == editable_version.id, + is_external_draft=detailed_versions[version.id] is not None + and ExternalWorkflowService._is_draft_graph(detailed_versions[version.id].data), + original_version_id=detailed_versions[version.id].original_version_id + if detailed_versions[version.id] else None, + draft_revision=ExternalWorkflowService.get_graph_revision( + detailed_versions[version.id].data if detailed_versions[version.id] else None + ), + create_time=version.create_time, + update_time=version.update_time, + ) + for version in versions + ] + + @classmethod + def _normalize_node_params(cls, node: dict) -> dict[str, WorkflowParamMetadata]: + try: + fields = ExternalWorkflowService._get_node_param_fields(node) + except NotFoundError: + return {} + return { + key: WorkflowParamMetadata(**ExternalWorkflowService._build_param_payload(group_name, value)) + for key, (group_name, value) in fields.items() + } + + @classmethod + def _normalize_graph_node(cls, node: dict) -> WorkflowGraphNodeDescriptor: + node_data = node.get('data', {}) + node_type = node_data.get('type') or node.get('type', '') + tab = node_data.get('tab') + node_template = get_node_template_descriptor(node_type) + params = cls._normalize_node_params(node) + return WorkflowGraphNodeDescriptor( + id=node.get('id', ''), + type=node_type, + name=node_data.get('name') or node_data.get('node', {}).get('display_name', ''), + description=node_data.get('description'), + tab=normalize_tab_descriptor(tab) if tab else (node_template.tab if node_template else None), + param_keys=list(params.keys()), + params=params, + ) + + @classmethod + async def get_workflow_graph(cls, + login_user: UserPayload, + flow_id: str, + version_id: Optional[int] = None) -> WorkflowGraphDescriptor: + _, version = await ExternalWorkflowService._get_editable_version(login_user, flow_id, version_id) + graph_data = version.data or {} + return WorkflowGraphDescriptor( + flow_id=flow_id, + version_id=version.id, + draft_revision=ExternalWorkflowService.get_graph_revision(graph_data), + nodes=[cls._normalize_graph_node(node) for node in graph_data.get('nodes', [])], + edges=list(graph_data.get('edges', [])), + ) + + @staticmethod + def list_node_types(): + return list_node_type_descriptors() + + @staticmethod + def get_node_template(node_type: str): + template = get_node_template_descriptor(node_type) + if template is None: + raise NotFoundError(msg=f'Workflow node template not found: {node_type}') + return template + + @staticmethod + def diagnostics_from_exception(exc: Exception) -> list[ValidationDiagnostic]: + code = '' + if isinstance(exc, BaseErrorCode): + code = str(exc.code) + message = exc.message + else: + message = str(exc) or exc.__class__.__name__ + node_id = None + field_path = None + suggested_fix = None + + node_match = re.search(r'Node ([\w-]+)', message) + if node_match: + node_id = node_match.group(1) + elif 'Workflow node not found:' in message: + node_id = message.rsplit(':', 1)[-1].strip() + + param_match = re.search(r'Param ([\w.\-]+)', message) + if param_match: + field_path = param_match.group(1) + suggested_fix = 'Review the parameter value and node template requirements.' + elif 'Edge source not found:' in message: + field_path = 'edges.source' + suggested_fix = 'Ensure the edge source points to an existing node id.' + elif 'Edge target not found:' in message: + field_path = 'edges.target' + suggested_fix = 'Ensure the edge target points to an existing node id.' + elif 'must contain at least one node' in message: + field_path = 'nodes' + suggested_fix = 'Add at least one workflow node before validating or publishing.' + + return [ + ValidationDiagnostic( + code=code, + severity=ValidationSeverity.ERROR, + message=message, + node_id=node_id, + field_path=field_path, + suggested_fix=suggested_fix, + ) + ] diff --git a/src/backend/bisheng/main.py b/src/backend/bisheng/main.py index 31c2e70786..934f3c449d 100644 --- a/src/backend/bisheng/main.py +++ b/src/backend/bisheng/main.py @@ -1,4 +1,4 @@ -from contextlib import asynccontextmanager +from contextlib import AsyncExitStack, asynccontextmanager from fastapi import FastAPI, HTTPException, Request, status from fastapi.exceptions import RequestValidationError @@ -13,6 +13,7 @@ from bisheng.common.services.config_service import settings from bisheng.core.context import initialize_app_context, close_app_context from bisheng.core.logger import set_logger_config +from bisheng.mcp_server import get_workflow_mcp_asgi_app, get_workflow_mcp_server from bisheng.services.utils import initialize_services, teardown_services from bisheng.utils.http_middleware import CustomMiddleware, WebSocketLoggingMiddleware from bisheng.utils.threadpool import thread_pool @@ -54,11 +55,17 @@ async def lifespan(app: FastAPI): await initialize_app_context(config=settings) initialize_services() await init_default_data() - # LangfuseInstance.update() - yield - teardown_services() - thread_pool.tear_down() - await close_app_context() + try: + async with AsyncExitStack() as stack: + workflow_mcp_server = get_workflow_mcp_server() + if workflow_mcp_server is not None: + await stack.enter_async_context(workflow_mcp_server.session_manager.run()) + # LangfuseInstance.update() + yield + finally: + teardown_services() + thread_pool.tear_down() + await close_app_context() def create_app(): @@ -95,6 +102,9 @@ def authjwt_exception_handler(request: Request, exc: AuthJWTException): app.include_router(router) app.include_router(router_rpc) + workflow_mcp_app = get_workflow_mcp_asgi_app() + if workflow_mcp_app is not None: + app.mount('/mcp', workflow_mcp_app) if settings.debug: import tracemalloc tracemalloc.start() diff --git a/src/backend/bisheng/mcp_server/__init__.py b/src/backend/bisheng/mcp_server/__init__.py new file mode 100644 index 0000000000..c299a7ba91 --- /dev/null +++ b/src/backend/bisheng/mcp_server/__init__.py @@ -0,0 +1,3 @@ +from bisheng.mcp_server.workflow import get_workflow_mcp_asgi_app, get_workflow_mcp_server + +__all__ = ['get_workflow_mcp_asgi_app', 'get_workflow_mcp_server'] diff --git a/src/backend/bisheng/mcp_server/auth.py b/src/backend/bisheng/mcp_server/auth.py new file mode 100644 index 0000000000..14683c8cd0 --- /dev/null +++ b/src/backend/bisheng/mcp_server/auth.py @@ -0,0 +1,371 @@ +import hashlib +import json +import os +from contextvars import ContextVar, Token +from datetime import datetime, timezone +from typing import Optional +from urllib.parse import urlparse + +import jwt +from loguru import logger +from starlette.responses import JSONResponse + +from bisheng.common.dependencies.user_deps import UserPayload +from bisheng.common.errcode.http_error import UnAuthorizedError +from bisheng.common.exceptions.auth import JWTDecodeError +from bisheng.core.cache.redis_manager import get_redis_client +from bisheng.user.domain.services.auth import LoginUser +from bisheng.user.domain.services.auth import AuthJwt +from bisheng.utils import generate_uuid +from bisheng.utils.constants import USER_CURRENT_SESSION + +_current_access_token: ContextVar[str | None] = ContextVar('mcp_access_token', default=None) +_current_login_user: ContextVar[UserPayload | None] = ContextVar('mcp_login_user', default=None) +_current_token_scopes: ContextVar[tuple[str, ...]] = ContextVar('mcp_token_scopes', default=tuple()) + +_MCP_BEARER_REALM = 'bisheng-mcp' +_LOCAL_ORIGIN_HOSTS = {'127.0.0.1', 'localhost'} +_MCP_AUDIENCE = 'bisheng-workflow-mcp' +_MCP_ISSUER = 'bisheng-mcp' +_MCP_TOKEN_TYPE = 'mcp_access_token' +_MCP_DEFAULT_SCOPES = ('workflow.read', 'workflow.write', 'workflow.publish') +_MCP_MAX_EXPIRES_IN = 60 * 60 +_MCP_DEFAULT_EXPIRES_IN = 30 * 60 +_MCP_ALLOWED_ORIGINS = tuple( + one.strip() for one in os.getenv('BISHENG_MCP_ALLOWED_ORIGINS', '').split(',') if one.strip() +) + + +def get_current_access_token() -> str | None: + return _current_access_token.get() + + +def get_current_login_user() -> UserPayload | None: + return _current_login_user.get() + + +def get_current_token_scopes() -> tuple[str, ...]: + return _current_token_scopes.get() + + +def _get_headers(scope) -> dict[str, str]: + headers = {} + for key, value in scope.get('headers', []): + headers[key.decode('latin1').lower()] = value.decode('latin1') + return headers + + +def _parse_bearer_token(auth_header: str) -> Optional[str]: + if not auth_header: + return None + auth_header = auth_header.strip() + if not auth_header.lower().startswith('bearer '): + return None + token = auth_header.split(' ', 1)[1].strip() + return token or None + + +def _extract_hostname(value: str) -> Optional[str]: + if not value: + return None + parsed = urlparse(value if '://' in value else f'//{value}') + return parsed.hostname.lower() if parsed.hostname else None + + +def _is_allowed_origin(origin: str, host: str) -> bool: + if not origin: + return True + if _MCP_ALLOWED_ORIGINS: + return origin in _MCP_ALLOWED_ORIGINS + origin_host = _extract_hostname(origin) + host_name = _extract_hostname(host) + if not origin_host or not host_name: + return False + if origin_host == host_name: + return True + if origin_host in _LOCAL_ORIGIN_HOSTS and host_name in _LOCAL_ORIGIN_HOSTS: + return True + return False + + +def _auth_header(error: str, description: str) -> str: + description = description.replace('"', "'") + return ( + f'Bearer realm="{_MCP_BEARER_REALM}", ' + f'error="{error}", ' + f'error_description="{description}"' + ) + + +def _error_response(status_code: int, + message: str, + *, + error: str, + extra_headers: Optional[dict[str, str]] = None) -> JSONResponse: + headers = {'Cache-Control': 'no-store'} + if extra_headers: + headers.update(extra_headers) + return JSONResponse( + status_code=status_code, + headers=headers, + content={ + 'ok': False, + 'error': error, + 'message': message, + }, + ) + + +def get_request_bisheng_access_token(request) -> Optional[str]: + token = _parse_bearer_token(request.headers.get('authorization', '')) + if token: + return token + return request.cookies.get('access_token_cookie') + + +def _hash_session_token(token: str) -> str: + return hashlib.sha256(token.encode('utf-8')).hexdigest() + + +def hash_bisheng_session_token(token: str) -> str: + return _hash_session_token(token) + + +def _normalize_scopes(scopes: Optional[list[str] | tuple[str, ...]]) -> tuple[str, ...]: + if not scopes: + return _MCP_DEFAULT_SCOPES + normalized = [] + for scope in scopes: + if scope in _MCP_DEFAULT_SCOPES and scope not in normalized: + normalized.append(scope) + return tuple(normalized or _MCP_DEFAULT_SCOPES) + + +def normalize_mcp_scopes(scopes: Optional[list[str] | tuple[str, ...] | str]) -> tuple[str, ...]: + if isinstance(scopes, str): + scopes = [scope.strip() for scope in scopes.replace(',', ' ').split() if scope.strip()] + return _normalize_scopes(scopes) + + +async def _assert_parent_session_valid(user_id: int, parent_session_hash: str): + redis_client = await get_redis_client() + current_session = await redis_client.aget(USER_CURRENT_SESSION.format(user_id)) + if not current_session: + raise JWTDecodeError(status_code=401, message='Bisheng session expired') + if _hash_session_token(current_session) != parent_session_hash: + raise JWTDecodeError(status_code=401, message='Bisheng session has been replaced') + + +async def resolve_login_user_from_bisheng_access_token(token: str) -> UserPayload: + subject = AuthJwt().decode_jwt_token(token) + await _assert_parent_session_valid(subject['user_id'], _hash_session_token(token)) + return await UserPayload.init_login_user( + user_id=subject['user_id'], + user_name=subject['user_name'], + ) + + +def _create_mcp_access_token_from_session_hash(login_user: LoginUser, + parent_session_hash: str, + *, + scopes: Optional[list[str] | tuple[str, ...]] = None, + expires_in: int = _MCP_DEFAULT_EXPIRES_IN) -> tuple[str, dict]: + now = int(datetime.now(timezone.utc).timestamp()) + expires_in = max(60, min(int(expires_in), _MCP_MAX_EXPIRES_IN)) + normalized_scopes = list(_normalize_scopes(scopes)) + claims = { + 'sub': str(login_user.user_id), + 'user_id': login_user.user_id, + 'user_name': login_user.user_name, + 'iss': _MCP_ISSUER, + 'aud': _MCP_AUDIENCE, + 'iat': now, + 'exp': now + expires_in, + 'jti': generate_uuid(), + 'token_type': _MCP_TOKEN_TYPE, + 'scope': normalized_scopes, + 'parent_session_hash': parent_session_hash, + } + token = jwt.encode(claims, AuthJwt().jwt_secret, algorithm='HS256') + return token, { + 'access_token': token, + 'token_type': 'Bearer', + 'expires_in': expires_in, + 'scopes': normalized_scopes, + 'audience': _MCP_AUDIENCE, + } + + +def create_mcp_access_token(login_user: LoginUser, + parent_access_token: str, + *, + scopes: Optional[list[str] | tuple[str, ...]] = None, + expires_in: int = _MCP_DEFAULT_EXPIRES_IN) -> tuple[str, dict]: + return _create_mcp_access_token_from_session_hash( + login_user, + _hash_session_token(parent_access_token), + scopes=scopes, + expires_in=expires_in, + ) + + +def create_mcp_access_token_from_session_hash(login_user: LoginUser, + parent_session_hash: str, + *, + scopes: Optional[list[str] | tuple[str, ...]] = None, + expires_in: int = _MCP_DEFAULT_EXPIRES_IN) -> tuple[str, dict]: + return _create_mcp_access_token_from_session_hash( + login_user, + parent_session_hash, + scopes=scopes, + expires_in=expires_in, + ) + + +async def _validate_mcp_access_token(token: str) -> tuple[UserPayload, tuple[str, ...]]: + try: + payload = jwt.decode( + token, + AuthJwt().jwt_secret, + audience=_MCP_AUDIENCE, + issuer=_MCP_ISSUER, + algorithms=['HS256'], + ) + except Exception as exc: + raise JWTDecodeError(status_code=401, message=str(exc)) + + if payload.get('token_type') != _MCP_TOKEN_TYPE: + raise JWTDecodeError(status_code=401, message='Unsupported MCP token type') + + user_id = payload.get('user_id') + user_name = payload.get('user_name') + if not user_id or not user_name: + try: + subject = json.loads(payload.get('sub') or '{}') + except Exception as exc: + raise JWTDecodeError(status_code=401, message=f'Invalid MCP token subject: {exc}') + if not isinstance(subject, dict): + raise JWTDecodeError(status_code=401, message='Invalid MCP token subject') + user_id = user_id or subject.get('user_id') + user_name = user_name or subject.get('user_name') + parent_session_hash = payload.get('parent_session_hash') + if not user_id or not user_name or not parent_session_hash: + raise JWTDecodeError(status_code=401, message='MCP token payload is incomplete') + + await _assert_parent_session_valid(user_id, parent_session_hash) + login_user = await UserPayload.init_login_user(user_id=user_id, user_name=user_name) + return login_user, tuple(_normalize_scopes(payload.get('scope'))) + + +def require_mcp_scopes(*required_scopes: str): + current_scopes = set(get_current_token_scopes()) + missing_scopes = [scope for scope in required_scopes if scope not in current_scopes] + if missing_scopes: + raise UnAuthorizedError(msg=f'MCP token missing required scopes: {", ".join(missing_scopes)}') + + +class McpAuthorizationMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + token_ref: Token | None = None + user_ref: Token | None = None + scope_ref: Token | None = None + scope_type = scope.get('type') + if scope_type not in {'http', 'websocket'}: + await self.app(scope, receive, send) + return + + headers = _get_headers(scope) + method = scope.get('method', '').upper() + origin = headers.get('origin') + host = headers.get('host', '') + + if scope_type == 'http' and method == 'OPTIONS': + await self.app(scope, receive, send) + return + + if origin and not _is_allowed_origin(origin, host): + if scope_type == 'http': + await _error_response( + 403, + 'Origin is not allowed for Bisheng MCP', + error='forbidden_origin', + )(scope, receive, send) + else: + await send({ + 'type': 'websocket.close', + 'code': 4403, + 'reason': 'Origin is not allowed for Bisheng MCP', + }) + return + + bearer_token = _parse_bearer_token(headers.get('authorization', '')) + if not bearer_token: + if scope_type == 'http': + await _error_response( + 401, + 'Missing Bisheng bearer token for MCP request', + error='invalid_request', + extra_headers={ + 'WWW-Authenticate': _auth_header( + 'invalid_request', + 'Missing Bearer token', + ) + }, + )(scope, receive, send) + else: + await send({ + 'type': 'websocket.close', + 'code': 4401, + 'reason': 'Missing Bearer token', + }) + return + + try: + login_user, token_scopes = await _validate_mcp_access_token(bearer_token) + except JWTDecodeError as exc: + if scope_type == 'http': + await _error_response( + 401, + exc.message, + error='invalid_token', + extra_headers={ + 'WWW-Authenticate': _auth_header('invalid_token', exc.message) + }, + )(scope, receive, send) + else: + await send({ + 'type': 'websocket.close', + 'code': 4401, + 'reason': exc.message[:120], + }) + return + + token_ref = _current_access_token.set(bearer_token) + user_ref = _current_login_user.set(login_user) + scope_ref = _current_token_scopes.set(token_scopes) + try: + await self.app(scope, receive, send) + except Exception: + logger.exception('Unhandled exception in MCP authorization middleware') + raise + finally: + if scope_ref is not None: + _current_token_scopes.reset(scope_ref) + if user_ref is not None: + _current_login_user.reset(user_ref) + if token_ref is not None: + _current_access_token.reset(token_ref) + + +async def get_login_user_from_mcp_token() -> UserPayload: + login_user = get_current_login_user() + if login_user is not None: + return login_user + token = get_current_access_token() + if not token: + raise UnAuthorizedError(msg='Missing Bisheng bearer token for MCP request') + login_user, _ = await _validate_mcp_access_token(token) + return login_user diff --git a/src/backend/bisheng/mcp_server/device_flow.py b/src/backend/bisheng/mcp_server/device_flow.py new file mode 100644 index 0000000000..b8db11dde1 --- /dev/null +++ b/src/backend/bisheng/mcp_server/device_flow.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import time +from typing import Optional + +from pydantic import BaseModel, Field + +from bisheng.utils import generate_uuid + +MCP_DEVICE_FLOW_TTL_MAX = 15 * 60 +MCP_DEVICE_FLOW_TTL_DEFAULT = 10 * 60 +MCP_DEVICE_FLOW_INTERVAL_DEFAULT = 5 +MCP_DEVICE_FLOW_INTERVAL_MAX = 30 +MCP_DEVICE_REDIS_PREFIX = 'mcp:device' + + +class McpDeviceSession(BaseModel): + device_code: str + user_code: str + client_id: str + client_name: str = '' + scopes: list[str] = Field(default_factory=list) + status: str = 'pending' + expires_at: int + interval: int = MCP_DEVICE_FLOW_INTERVAL_DEFAULT + created_at: int = Field(default_factory=lambda: int(time.time())) + updated_at: int = Field(default_factory=lambda: int(time.time())) + last_poll_at: int = 0 + user_id: Optional[int] = None + user_name: str = '' + parent_session_hash: str = '' + denied_reason: str = '' + + @property + def expired(self) -> bool: + return self.expires_at <= int(time.time()) + + @property + def expires_in(self) -> int: + return max(0, self.expires_at - int(time.time())) + + +def normalize_device_flow_ttl(expires_in: Optional[int]) -> int: + if expires_in is None: + return MCP_DEVICE_FLOW_TTL_DEFAULT + return max(60, min(int(expires_in), MCP_DEVICE_FLOW_TTL_MAX)) + + +def normalize_device_flow_interval(interval: Optional[int]) -> int: + if interval is None: + return MCP_DEVICE_FLOW_INTERVAL_DEFAULT + return max(1, min(int(interval), MCP_DEVICE_FLOW_INTERVAL_MAX)) + + +def device_code_key(device_code: str) -> str: + return f'{MCP_DEVICE_REDIS_PREFIX}:code:{device_code}' + + +def user_code_key(user_code: str) -> str: + return f'{MCP_DEVICE_REDIS_PREFIX}:user:{user_code}' + + +def generate_device_code() -> str: + return generate_uuid().replace('-', '') + + +def generate_user_code() -> str: + raw = generate_uuid().replace('-', '').upper() + return f'{raw[:4]}-{raw[4:8]}' + + +async def save_device_session(redis_client, session: McpDeviceSession): + expiration = max(1, session.expires_in) + await redis_client.aset(device_code_key(session.device_code), session.model_dump(), expiration=expiration) + await redis_client.aset(user_code_key(session.user_code), session.device_code, expiration=expiration) + + +async def load_device_session_by_device_code(redis_client, device_code: str) -> Optional[McpDeviceSession]: + payload = await redis_client.aget(device_code_key(device_code)) + if not payload: + return None + if isinstance(payload, McpDeviceSession): + return payload + return McpDeviceSession.model_validate(payload) + + +async def load_device_session_by_user_code(redis_client, user_code: str) -> Optional[McpDeviceSession]: + device_code = await redis_client.aget(user_code_key(user_code)) + if not device_code: + return None + return await load_device_session_by_device_code(redis_client, device_code) + + +async def delete_device_session(redis_client, session: McpDeviceSession): + await redis_client.adelete(device_code_key(session.device_code)) + await redis_client.adelete(user_code_key(session.user_code)) diff --git a/src/backend/bisheng/mcp_server/workflow.py b/src/backend/bisheng/mcp_server/workflow.py new file mode 100644 index 0000000000..e28617934b --- /dev/null +++ b/src/backend/bisheng/mcp_server/workflow.py @@ -0,0 +1,668 @@ +from typing import Optional + +from loguru import logger +from pydantic import BaseModel, Field + +from bisheng.api.services.external_workflow import ExternalWorkflowService +from bisheng.api.services.workflow_authoring import WorkflowAuthoringService +from bisheng.api.v1.schemas import GraphData +from bisheng.common.errcode.base import BaseErrorCode +from bisheng.mcp_server.auth import McpAuthorizationMiddleware, get_current_token_scopes, get_login_user_from_mcp_token, require_mcp_scopes +from bisheng.workflow.authoring import ( + NodeTemplateDescriptor, + NodeTypeDescriptor, + ValidationDiagnostic, + WorkflowGraphDescriptor, + WorkflowManifest, + WorkflowVersionSummary, +) + +try: + from mcp.server.fastmcp import FastMCP +except ModuleNotFoundError: + FastMCP = None + + +class WorkflowMutationResult(BaseModel): + ok: bool = True + message: str = 'SUCCESS' + flow_id: Optional[str] = None + version_id: Optional[int] = None + status: Optional[str] = None + draft_revision: Optional[int] = None + node_id: Optional[str] = None + edge_id: Optional[str] = None + error_code: Optional[int] = None + + +class WorkflowValidationResult(BaseModel): + ok: bool = True + valid: bool = True + flow_id: Optional[str] = None + version_id: Optional[int] = None + draft_revision: Optional[int] = None + errors: list[str] = Field(default_factory=list) + warnings: list[str] = Field(default_factory=list) + diagnostics: list[ValidationDiagnostic] = Field(default_factory=list) + error_code: Optional[int] = None + + +class WorkflowNodeSummary(BaseModel): + id: str + type: str = '' + name: str = '' + param_keys: list[str] = Field(default_factory=list) + + +class WorkflowNodeListResult(BaseModel): + ok: bool = True + flow_id: Optional[str] = None + version_id: Optional[int] = None + draft_revision: Optional[int] = None + nodes: list[WorkflowNodeSummary] = Field(default_factory=list) + message: str = 'SUCCESS' + error_code: Optional[int] = None + + +class WorkflowNodeParamField(BaseModel): + display_name: str = '' + group_name: str = '' + type: Optional[str] = None + required: bool = False + show: bool = False + options: Optional[object] = None + scope: Optional[object] = None + placeholder: Optional[str] = None + refresh: bool = False + value: Optional[object] = None + + +class WorkflowNodeParamsResult(BaseModel): + ok: bool = True + flow_id: Optional[str] = None + version_id: Optional[int] = None + draft_revision: Optional[int] = None + node_id: Optional[str] = None + node_type: str = '' + node_name: str = '' + params: dict[str, WorkflowNodeParamField] = Field(default_factory=dict) + message: str = 'SUCCESS' + error_code: Optional[int] = None + + +class ConditionNodeResult(BaseModel): + ok: bool = True + flow_id: Optional[str] = None + version_id: Optional[int] = None + draft_revision: Optional[int] = None + node_id: Optional[str] = None + node_name: str = '' + condition_cases: list[dict[str, object]] = Field(default_factory=list) + route_handles: list[str] = Field(default_factory=list) + outgoing_edges: dict[str, list[dict[str, str]]] = Field(default_factory=dict) + message: str = 'SUCCESS' + error_code: Optional[int] = None + + +class WorkflowConnectionResult(BaseModel): + ok: bool = True + service: str = 'bisheng-workflow-mcp' + authenticated: bool = True + user_id: Optional[int] = None + user_name: str = '' + scopes: list[str] = Field(default_factory=list) + message: str = 'SUCCESS' + error_code: Optional[int] = None + + +class WorkflowManifestResult(BaseModel): + ok: bool = True + workflow: Optional[WorkflowManifest] = None + message: str = 'SUCCESS' + error_code: Optional[int] = None + + +class WorkflowManifestListResult(BaseModel): + ok: bool = True + workflows: list[WorkflowManifest] = Field(default_factory=list) + message: str = 'SUCCESS' + error_code: Optional[int] = None + + +class WorkflowVersionListResult(BaseModel): + ok: bool = True + flow_id: Optional[str] = None + versions: list[WorkflowVersionSummary] = Field(default_factory=list) + message: str = 'SUCCESS' + error_code: Optional[int] = None + + +class WorkflowGraphResult(BaseModel): + ok: bool = True + graph: Optional[WorkflowGraphDescriptor] = None + message: str = 'SUCCESS' + error_code: Optional[int] = None + + +class NodeTypeListResult(BaseModel): + ok: bool = True + node_types: list[NodeTypeDescriptor] = Field(default_factory=list) + message: str = 'SUCCESS' + error_code: Optional[int] = None + + +class NodeTemplateResult(BaseModel): + ok: bool = True + template: Optional[NodeTemplateDescriptor] = None + message: str = 'SUCCESS' + error_code: Optional[int] = None + + +def _mutation_error(exc: Exception) -> WorkflowMutationResult: + return _error_result(WorkflowMutationResult, exc) + + +def _validation_error(exc: Exception) -> WorkflowValidationResult: + message, error_code = _error_details(exc) + return WorkflowValidationResult( + ok=False, + valid=False, + errors=[message], + diagnostics=WorkflowAuthoringService.diagnostics_from_exception(exc), + error_code=error_code, + ) + + +def _node_list_error(exc: Exception) -> WorkflowNodeListResult: + return _error_result(WorkflowNodeListResult, exc) + + +def _node_params_error(exc: Exception) -> WorkflowNodeParamsResult: + return _error_result(WorkflowNodeParamsResult, exc) + + +def _condition_node_error(exc: Exception) -> ConditionNodeResult: + return _error_result(ConditionNodeResult, exc) + + +def _workflow_manifest_error(exc: Exception) -> WorkflowManifestResult: + return _error_result(WorkflowManifestResult, exc) + + +def _workflow_manifest_list_error(exc: Exception) -> WorkflowManifestListResult: + return _error_result(WorkflowManifestListResult, exc) + + +def _workflow_version_list_error(exc: Exception) -> WorkflowVersionListResult: + return _error_result(WorkflowVersionListResult, exc) + + +def _workflow_graph_error(exc: Exception) -> WorkflowGraphResult: + return _error_result(WorkflowGraphResult, exc) + + +def _node_type_list_error(exc: Exception) -> NodeTypeListResult: + return _error_result(NodeTypeListResult, exc) + + +def _node_template_result_error(exc: Exception) -> NodeTemplateResult: + return _error_result(NodeTemplateResult, exc) + + +def _error_details(exc: Exception) -> tuple[str, Optional[int]]: + if isinstance(exc, BaseErrorCode): + return exc.message, exc.code + return str(exc) or exc.__class__.__name__, None + + +def _error_result(result_cls, exc: Exception, **kwargs): + message, error_code = _error_details(exc) + return result_cls(ok=False, message=message, error_code=error_code, **kwargs) + + +def _connection_error(exc: Exception) -> WorkflowConnectionResult: + message, error_code = _error_details(exc) + return WorkflowConnectionResult( + ok=False, + authenticated=False, + message=message, + error_code=error_code, + ) + + +def _log_tool_failure(tool_name: str, exc: Exception): + if isinstance(exc, BaseErrorCode): + logger.warning(f'{tool_name} failed: {exc.message} (code={exc.code})') + return + logger.exception(f'{tool_name} failed') + + +def create_workflow_mcp_server(): + if FastMCP is None: + logger.warning('mcp dependency is not available, workflow MCP server will not be mounted') + return None + + mcp = FastMCP( + 'Bisheng Workflow MCP', + instructions='Create, validate, and publish Bisheng workflows as drafts first.', + json_response=True, + ) + try: + mcp.settings.streamable_http_path = '/' + except Exception: + pass + + def _connection_result(login_user, *, include_scopes: bool = True) -> WorkflowConnectionResult: + return WorkflowConnectionResult( + user_id=login_user.user_id, + user_name=login_user.user_name, + scopes=list(get_current_token_scopes()) if include_scopes else [], + ) + + async def _run_authenticated_tool(tool_name: str, + operation, + *, + scope: Optional[str] = None, + error_builder=None): + try: + login_user = await get_login_user_from_mcp_token() + if scope: + require_mcp_scopes(scope) + return await operation(login_user) + except Exception as exc: + _log_tool_failure(tool_name, exc) + return error_builder(exc) + + @mcp.tool() + async def ping() -> WorkflowConnectionResult: + """Verify the MCP connection and return the current authenticated Bisheng user.""" + async def _op(login_user): + return _connection_result(login_user, include_scopes=False) + + return await _run_authenticated_tool('ping', _op, error_builder=_connection_error) + + @mcp.tool() + async def whoami() -> WorkflowConnectionResult: + """Return the authenticated Bisheng user behind the current MCP bearer token.""" + async def _op(login_user): + return _connection_result(login_user, include_scopes=True) + + return await _run_authenticated_tool('whoami', _op, error_builder=_connection_error) + + @mcp.tool() + async def list_workflows() -> WorkflowManifestListResult: + """List workflows the current MCP user can author.""" + async def _op(login_user): + workflows = await WorkflowAuthoringService.list_workflows(login_user) + return WorkflowManifestListResult(workflows=workflows) + return await _run_authenticated_tool( + 'list_workflows', + _op, + scope='workflow.read', + error_builder=_workflow_manifest_list_error, + ) + + @mcp.tool() + async def get_workflow(flow_id: str) -> WorkflowManifestResult: + """Return manifest information for one authorable workflow.""" + async def _op(login_user): + workflow = await WorkflowAuthoringService.get_workflow(login_user, flow_id) + return WorkflowManifestResult(workflow=workflow) + return await _run_authenticated_tool('get_workflow', _op, scope='workflow.read', + error_builder=_workflow_manifest_error) + + @mcp.tool() + async def get_workflow_versions(flow_id: str) -> WorkflowVersionListResult: + """List versions for one authorable workflow.""" + async def _op(login_user): + versions = await WorkflowAuthoringService.get_workflow_versions(login_user, flow_id) + return WorkflowVersionListResult(flow_id=flow_id, versions=versions) + return await _run_authenticated_tool('get_workflow_versions', _op, scope='workflow.read', + error_builder=_workflow_version_list_error) + + @mcp.tool() + async def get_workflow_graph(flow_id: str, version_id: int = 0) -> WorkflowGraphResult: + """Return the normalized editable graph for a workflow.""" + async def _op(login_user): + graph = await WorkflowAuthoringService.get_workflow_graph( + login_user=login_user, + flow_id=flow_id, + version_id=version_id or None, + ) + return WorkflowGraphResult(graph=graph) + return await _run_authenticated_tool('get_workflow_graph', _op, scope='workflow.read', + error_builder=_workflow_graph_error) + + @mcp.tool() + async def list_node_types() -> NodeTypeListResult: + """List supported workflow node types for authoring.""" + async def _op(_login_user): + return NodeTypeListResult(node_types=WorkflowAuthoringService.list_node_types()) + return await _run_authenticated_tool('list_node_types', _op, scope='workflow.read', + error_builder=_node_type_list_error) + + @mcp.tool() + async def get_node_template(node_type: str) -> NodeTemplateResult: + """Return the normalized template metadata for one workflow node type.""" + async def _op(_login_user): + return NodeTemplateResult(template=WorkflowAuthoringService.get_node_template(node_type)) + return await _run_authenticated_tool('get_node_template', _op, scope='workflow.read', + error_builder=_node_template_result_error) + + @mcp.tool() + async def list_workflow_nodes(flow_id: str, version_id: int = 0) -> WorkflowNodeListResult: + """List nodes in a workflow version so the caller can choose a node_id to edit.""" + async def _op(login_user): + data = await ExternalWorkflowService.list_workflow_nodes( + login_user=login_user, + flow_id=flow_id, + version_id=version_id or None, + ) + return WorkflowNodeListResult(**data) + return await _run_authenticated_tool('list_workflow_nodes', _op, scope='workflow.read', + error_builder=_node_list_error) + + @mcp.tool() + async def get_workflow_node_params(flow_id: str, + node_id: str, + version_id: int = 0) -> WorkflowNodeParamsResult: + """Read the current editable params of a workflow node.""" + async def _op(login_user): + data = await ExternalWorkflowService.get_workflow_node_params( + login_user=login_user, + flow_id=flow_id, + node_id=node_id, + version_id=version_id or None, + ) + return WorkflowNodeParamsResult(**data) + return await _run_authenticated_tool('get_workflow_node_params', _op, scope='workflow.read', + error_builder=_node_params_error) + + @mcp.tool() + async def get_condition_node(flow_id: str, + node_id: str, + version_id: int = 0) -> ConditionNodeResult: + """Read the structured routing configuration of one condition node.""" + async def _op(login_user): + data = await ExternalWorkflowService.get_condition_node_config( + login_user=login_user, + flow_id=flow_id, + node_id=node_id, + version_id=version_id or None, + ) + return ConditionNodeResult(**data) + return await _run_authenticated_tool('get_condition_node', _op, scope='workflow.read', + error_builder=_condition_node_error) + + @mcp.tool() + async def create_workflow_draft(name: str, + graph_data: GraphData, + description: str = '', + guide_word: str = '') -> WorkflowMutationResult: + """Create a new Bisheng workflow draft for the current logged-in user.""" + async def _op(login_user): + flow, version = await ExternalWorkflowService.create_workflow_draft( + login_user=login_user, + name=name, + graph_data=graph_data.model_dump(), + description=description or None, + guide_word=guide_word or None, + ) + return WorkflowMutationResult( + flow_id=flow.id, + version_id=version.id, + status='draft', + draft_revision=ExternalWorkflowService.get_graph_revision(version.data), + ) + return await _run_authenticated_tool('create_workflow_draft', _op, scope='workflow.write', + error_builder=_mutation_error) + + @mcp.tool() + async def update_workflow_draft(flow_id: str, + graph_data: GraphData, + name: str = '', + description: str = '', + guide_word: str = '', + expected_revision: int = 0) -> WorkflowMutationResult: + """Create a new draft version for an existing Bisheng workflow.""" + async def _op(login_user): + _, version = await ExternalWorkflowService.update_workflow_draft( + login_user=login_user, + flow_id=flow_id, + graph_data=graph_data.model_dump(), + name=name or None, + description=description or None, + guide_word=guide_word or None, + expected_revision=expected_revision if expected_revision >= 0 else None, + ) + return WorkflowMutationResult( + flow_id=flow_id, + version_id=version.id, + status='draft', + draft_revision=ExternalWorkflowService.get_graph_revision(version.data), + ) + return await _run_authenticated_tool('update_workflow_draft', _op, scope='workflow.write', + error_builder=_mutation_error) + + @mcp.tool() + async def add_node(flow_id: str, + node_type: str, + name: str = '', + position_x: float = 0, + position_y: float = 0, + initial_params: Optional[dict] = None, + version_id: int = 0, + expected_revision: int = 0) -> WorkflowMutationResult: + """Add one node to the editable workflow graph.""" + async def _op(login_user): + flow, version, node_id = await ExternalWorkflowService.add_workflow_node( + login_user=login_user, + flow_id=flow_id, + node_type=node_type, + name=name, + position_x=position_x, + position_y=position_y, + initial_params=initial_params, + version_id=version_id or None, + expected_revision=expected_revision or None, + ) + return WorkflowMutationResult( + flow_id=flow.id, + version_id=version.id, + status='draft', + draft_revision=ExternalWorkflowService.get_graph_revision(version.data), + node_id=node_id, + ) + return await _run_authenticated_tool('add_node', _op, scope='workflow.write', + error_builder=_mutation_error) + + @mcp.tool() + async def remove_node(flow_id: str, + node_id: str, + cascade: bool = True, + version_id: int = 0, + expected_revision: int = 0) -> WorkflowMutationResult: + """Remove one node from the editable workflow graph.""" + async def _op(login_user): + flow, version = await ExternalWorkflowService.remove_workflow_node( + login_user=login_user, + flow_id=flow_id, + node_id=node_id, + cascade=cascade, + version_id=version_id or None, + expected_revision=expected_revision or None, + ) + return WorkflowMutationResult( + flow_id=flow.id, + version_id=version.id, + status='draft', + draft_revision=ExternalWorkflowService.get_graph_revision(version.data), + node_id=node_id, + ) + return await _run_authenticated_tool('remove_node', _op, scope='workflow.write', + error_builder=_mutation_error) + + @mcp.tool() + async def connect_nodes(flow_id: str, + source_node_id: str, + target_node_id: str, + source_handle: str, + target_handle: str, + version_id: int = 0, + expected_revision: int = 0) -> WorkflowMutationResult: + """Connect two existing workflow nodes with one edge.""" + async def _op(login_user): + flow, version, edge_id = await ExternalWorkflowService.connect_workflow_nodes( + login_user=login_user, + flow_id=flow_id, + source_node_id=source_node_id, + target_node_id=target_node_id, + source_handle=source_handle, + target_handle=target_handle, + version_id=version_id or None, + expected_revision=expected_revision or None, + ) + return WorkflowMutationResult( + flow_id=flow.id, + version_id=version.id, + status='draft', + draft_revision=ExternalWorkflowService.get_graph_revision(version.data), + edge_id=edge_id, + ) + return await _run_authenticated_tool('connect_nodes', _op, scope='workflow.write', + error_builder=_mutation_error) + + @mcp.tool() + async def disconnect_edge(flow_id: str, + edge_id: str = '', + source_node_id: str = '', + target_node_id: str = '', + source_handle: str = '', + target_handle: str = '', + version_id: int = 0, + expected_revision: int = 0) -> WorkflowMutationResult: + """Disconnect one edge from the editable workflow graph.""" + async def _op(login_user): + flow, version, removed_edge_id = await ExternalWorkflowService.disconnect_workflow_edge( + login_user=login_user, + flow_id=flow_id, + edge_id=edge_id, + source_node_id=source_node_id, + target_node_id=target_node_id, + source_handle=source_handle, + target_handle=target_handle, + version_id=version_id or None, + expected_revision=expected_revision or None, + ) + return WorkflowMutationResult( + flow_id=flow.id, + version_id=version.id, + status='draft', + draft_revision=ExternalWorkflowService.get_graph_revision(version.data), + edge_id=removed_edge_id, + ) + return await _run_authenticated_tool('disconnect_edge', _op, scope='workflow.write', + error_builder=_mutation_error) + + @mcp.tool() + async def update_workflow_node_params(flow_id: str, + node_id: str, + updates: dict[str, object], + version_id: int = 0, + expected_revision: int = 0) -> WorkflowMutationResult: + """Create a new draft version by patching one node's params, such as prompt or temperature.""" + async def _op(login_user): + _, version = await ExternalWorkflowService.update_workflow_node_params( + login_user=login_user, + flow_id=flow_id, + node_id=node_id, + updates=updates, + version_id=version_id or None, + expected_revision=expected_revision if expected_revision >= 0 else None, + ) + return WorkflowMutationResult( + flow_id=flow_id, + version_id=version.id, + status='draft', + draft_revision=ExternalWorkflowService.get_graph_revision(version.data), + ) + return await _run_authenticated_tool('update_workflow_node_params', _op, scope='workflow.write', + error_builder=_mutation_error) + + @mcp.tool() + async def update_condition_node(flow_id: str, + node_id: str, + condition_cases: list[dict[str, object]], + version_id: int = 0, + expected_revision: int = 0) -> WorkflowMutationResult: + """Update the structured condition cases of one existing condition node.""" + async def _op(login_user): + flow, version = await ExternalWorkflowService.update_condition_node( + login_user=login_user, + flow_id=flow_id, + node_id=node_id, + condition_cases=condition_cases, + version_id=version_id or None, + expected_revision=expected_revision if expected_revision >= 0 else None, + ) + return WorkflowMutationResult( + flow_id=flow.id, + version_id=version.id, + status='draft', + draft_revision=ExternalWorkflowService.get_graph_revision(version.data), + node_id=node_id, + ) + return await _run_authenticated_tool('update_condition_node', _op, scope='workflow.write', + error_builder=_mutation_error) + + @mcp.tool() + async def validate_workflow(flow_id: str, version_id: int) -> WorkflowValidationResult: + """Validate a Bisheng workflow version without publishing it.""" + async def _op(login_user): + flow, version = await ExternalWorkflowService.validate_workflow( + login_user=login_user, + flow_id=flow_id, + version_id=version_id, + ) + return WorkflowValidationResult( + flow_id=flow.id, + version_id=version.id, + valid=True, + draft_revision=ExternalWorkflowService.get_graph_revision(version.data), + diagnostics=[], + ) + return await _run_authenticated_tool('validate_workflow', _op, scope='workflow.write', + error_builder=_validation_error) + + @mcp.tool() + async def publish_workflow(flow_id: str, version_id: int) -> WorkflowMutationResult: + """Publish a validated Bisheng workflow version.""" + async def _op(login_user): + flow, version = await ExternalWorkflowService.publish_workflow( + login_user=login_user, + flow_id=flow_id, + version_id=version_id, + ) + return WorkflowMutationResult( + flow_id=flow.id, + version_id=version.id, + status='published', + draft_revision=ExternalWorkflowService.get_graph_revision(version.data), + ) + return await _run_authenticated_tool('publish_workflow', _op, scope='workflow.publish', + error_builder=_mutation_error) + + return mcp + + +_workflow_mcp_server = create_workflow_mcp_server() + + +def get_workflow_mcp_server(): + return _workflow_mcp_server + + +def get_workflow_mcp_asgi_app(): + if _workflow_mcp_server is None: + return None + return McpAuthorizationMiddleware(_workflow_mcp_server.streamable_http_app()) diff --git a/src/backend/bisheng/user/api/user.py b/src/backend/bisheng/user/api/user.py index b482598ead..1eb79693ad 100644 --- a/src/backend/bisheng/user/api/user.py +++ b/src/backend/bisheng/user/api/user.py @@ -1,8 +1,11 @@ import hashlib +import html import random +import time from base64 import b64encode from datetime import datetime from io import BytesIO +from types import SimpleNamespace from typing import Annotated, Dict, List, Optional import rsa @@ -10,7 +13,9 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Body, Request from fastapi.security import OAuth2PasswordBearer from loguru import logger +from pydantic import BaseModel, Field from sqlmodel import select +from starlette.responses import HTMLResponse, JSONResponse from bisheng.api.services.audit_log import AuditLogService from bisheng.api.v1.schemas import resp_200, CreateUserReq @@ -25,6 +30,25 @@ from bisheng.database.models.role import Role, RoleCreate, RoleDao, RoleUpdate from bisheng.database.models.role_access import RoleRefresh, RoleAccessDao, AccessType from bisheng.database.models.user_group import UserGroupDao +from bisheng.mcp_server.auth import ( + create_mcp_access_token, + create_mcp_access_token_from_session_hash, + get_request_bisheng_access_token, + hash_bisheng_session_token, + normalize_mcp_scopes, + resolve_login_user_from_bisheng_access_token, +) +from bisheng.mcp_server.device_flow import ( + McpDeviceSession, + delete_device_session, + generate_device_code, + generate_user_code, + load_device_session_by_device_code, + load_device_session_by_user_code, + normalize_device_flow_interval, + normalize_device_flow_ttl, + save_device_session, +) from bisheng.utils import generate_uuid from bisheng.utils import get_request_ip from bisheng.utils.constants import CAPTCHA_PREFIX, RSA_KEY, USER_PASSWORD_ERROR, USER_CURRENT_SESSION @@ -43,6 +67,94 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl='token') +class McpTokenCreateRequest(BaseModel): + expires_in: int = Field(default=1800, ge=60, le=3600, description='MCP token ttl in seconds') + + +class McpDeviceAuthorizeRequest(BaseModel): + client_id: str = Field(min_length=1, max_length=200, description='Public MCP client id') + client_name: str = Field(default='', max_length=200, description='Display name shown on the approval page') + scope: str = Field(default='', description='Space separated MCP scopes') + expires_in: int = Field(default=600, ge=60, le=900, description='Device code ttl in seconds') + interval: int = Field(default=5, ge=1, le=30, description='Polling interval in seconds') + + +class McpDeviceTokenRequest(BaseModel): + grant_type: str = Field(default='urn:ietf:params:oauth:grant-type:device_code') + device_code: str = Field(min_length=1) + client_id: str = Field(default='', max_length=200) + + +def _mcp_device_token_error(error: str, description: str, status_code: int = 400): + return JSONResponse( + status_code=status_code, + content={ + 'error': error, + 'error_description': description, + }, + ) + + +def _mcp_device_verify_html(*, + user_code: str, + title: str, + message: str, + status: str, + client_name: str = '', + scopes: Optional[list[str]] = None, + can_approve: bool = False): + escaped_title = html.escape(title) + escaped_message = html.escape(message) + escaped_user_code = html.escape(user_code) + escaped_client_name = html.escape(client_name or 'MCP Client') + scopes_html = ''.join( + f'
  • {html.escape(scope)}
  • ' for scope in (scopes or []) + ) or '
  • workflow.read workflow.write workflow.publish
  • ' + form_html = '' + if can_approve: + form_html = f""" +
    + + + +
    + """ + body = f""" + + + + + {escaped_title} + + + +
    +
    {html.escape(status)}
    +

    {escaped_title}

    +

    {escaped_message}

    +

    User code: {escaped_user_code}

    +

    Client: {escaped_client_name}

    +

    Requested scopes:

    +
      {scopes_html}
    + {form_html} +
    + +""" + return HTMLResponse(body) + + @router.post('/user/regist') async def regist(*, user: UserCreate): # Captcha Verification @@ -109,6 +221,260 @@ async def login(*, request: Request, user: UserLogin, auth_jwt: AuthJwt = Depend return await UserService.user_login(request, user=user, auth_jwt=auth_jwt) +@router.post('/user/mcp_token') +async def create_workflow_mcp_token(request: Request, + body: McpTokenCreateRequest = Body(default=None)): + body = body or McpTokenCreateRequest() + access_token = get_request_bisheng_access_token(request) + if not access_token: + return UnAuthorizedError.return_resp(msg='Bisheng login required before issuing MCP token') + login_user = await resolve_login_user_from_bisheng_access_token(access_token) + _, token_payload = create_mcp_access_token( + login_user, + access_token, + expires_in=body.expires_in, + ) + return resp_200(token_payload) + + +@router.post('/user/mcp/device/authorize') +async def create_mcp_device_authorization(request: Request, + body: McpDeviceAuthorizeRequest): + redis_client = await get_redis_client() + scopes = list(normalize_mcp_scopes(body.scope)) + session = McpDeviceSession( + device_code=generate_device_code(), + user_code=generate_user_code(), + client_id=body.client_id, + client_name=body.client_name.strip(), + scopes=scopes, + expires_at=int(time.time()) + normalize_device_flow_ttl(body.expires_in), + interval=normalize_device_flow_interval(body.interval), + ) + await save_device_session(redis_client, session) + verification_uri = str(request.url_for('mcp_device_verify_page')) + return resp_200(data={ + 'device_code': session.device_code, + 'user_code': session.user_code, + 'verification_uri': verification_uri, + 'verification_uri_complete': f'{verification_uri}?user_code={session.user_code}', + 'expires_in': session.expires_in, + 'interval': session.interval, + 'scope': ' '.join(session.scopes), + 'scopes': session.scopes, + }) + + +@router.get('/user/mcp/device/verify', name='mcp_device_verify_page') +async def mcp_device_verify_page(request: Request, user_code: str = Query(default='')): + normalized_user_code = user_code.strip().upper() + if not normalized_user_code: + return _mcp_device_verify_html( + user_code='', + title='Missing user code', + message='Open this page with the device flow user code from your MCP client.', + status='invalid_request', + ) + + redis_client = await get_redis_client() + session = await load_device_session_by_user_code(redis_client, normalized_user_code) + if session is None: + return _mcp_device_verify_html( + user_code=normalized_user_code, + title='Device code not found', + message='This device authorization request does not exist or has already been completed.', + status='invalid_grant', + ) + if session.expired: + await delete_device_session(redis_client, session) + return _mcp_device_verify_html( + user_code=session.user_code, + title='Device code expired', + message='Start a new device authorization request from your MCP client.', + status='expired_token', + client_name=session.client_name, + scopes=session.scopes, + ) + + access_token = get_request_bisheng_access_token(request) + if not access_token: + return _mcp_device_verify_html( + user_code=session.user_code, + title='Login required', + message='Log in to Bisheng in this browser first, then refresh this page to approve the MCP client.', + status='login_required', + client_name=session.client_name, + scopes=session.scopes, + ) + + try: + login_user = await resolve_login_user_from_bisheng_access_token(access_token) + except Exception: + return _mcp_device_verify_html( + user_code=session.user_code, + title='Login required', + message='Your Bisheng session is invalid or expired. Log in again, then refresh this page.', + status='login_required', + client_name=session.client_name, + scopes=session.scopes, + ) + + if session.status == 'approved': + message = f'This request has already been approved for {login_user.user_name}. Return to the MCP client.' + return _mcp_device_verify_html( + user_code=session.user_code, + title='Already approved', + message=message, + status='approved', + client_name=session.client_name, + scopes=session.scopes, + ) + + if session.status == 'denied': + return _mcp_device_verify_html( + user_code=session.user_code, + title='Request denied', + message='This device authorization request has already been denied.', + status='access_denied', + client_name=session.client_name, + scopes=session.scopes, + ) + + return _mcp_device_verify_html( + user_code=session.user_code, + title='Approve MCP client', + message=f'You are signed in as {login_user.user_name}. Approve this MCP client to access Bisheng workflows.', + status='pending', + client_name=session.client_name, + scopes=session.scopes, + can_approve=True, + ) + + +@router.post('/user/mcp/device/verify') +async def approve_mcp_device_authorization(request: Request): + form = await request.form() + user_code = str(form.get('user_code', '')).strip().upper() + action = str(form.get('action', 'approve')).strip().lower() + redis_client = await get_redis_client() + session = await load_device_session_by_user_code(redis_client, user_code) + if session is None: + return _mcp_device_verify_html( + user_code=user_code, + title='Device code not found', + message='This device authorization request does not exist or has already been completed.', + status='invalid_grant', + ) + if session.expired: + await delete_device_session(redis_client, session) + return _mcp_device_verify_html( + user_code=session.user_code, + title='Device code expired', + message='Start a new device authorization request from your MCP client.', + status='expired_token', + client_name=session.client_name, + scopes=session.scopes, + ) + + access_token = get_request_bisheng_access_token(request) + if not access_token: + return _mcp_device_verify_html( + user_code=session.user_code, + title='Login required', + message='Log in to Bisheng in this browser first, then retry approval.', + status='login_required', + client_name=session.client_name, + scopes=session.scopes, + ) + try: + login_user = await resolve_login_user_from_bisheng_access_token(access_token) + except Exception: + return _mcp_device_verify_html( + user_code=session.user_code, + title='Login required', + message='Your Bisheng session is invalid or expired. Log in again, then retry approval.', + status='login_required', + client_name=session.client_name, + scopes=session.scopes, + ) + + session.updated_at = int(time.time()) + if action == 'deny': + session.status = 'denied' + session.denied_reason = f'denied by {login_user.user_name}' + await save_device_session(redis_client, session) + return _mcp_device_verify_html( + user_code=session.user_code, + title='Request denied', + message='The MCP client authorization request has been denied.', + status='access_denied', + client_name=session.client_name, + scopes=session.scopes, + ) + + session.status = 'approved' + session.user_id = login_user.user_id + session.user_name = login_user.user_name + session.parent_session_hash = hash_bisheng_session_token(access_token) + await save_device_session(redis_client, session) + return _mcp_device_verify_html( + user_code=session.user_code, + title='Request approved', + message='Return to the MCP client. It can now exchange the device code for an MCP access token.', + status='approved', + client_name=session.client_name, + scopes=session.scopes, + ) + + +@router.post('/user/mcp/device/token') +async def issue_mcp_device_token(body: McpDeviceTokenRequest): + if body.grant_type != 'urn:ietf:params:oauth:grant-type:device_code': + return _mcp_device_token_error('unsupported_grant_type', 'Only device_code grant_type is supported.') + + redis_client = await get_redis_client() + session = await load_device_session_by_device_code(redis_client, body.device_code) + if session is None: + return _mcp_device_token_error('invalid_grant', 'The device_code is invalid or has already been consumed.') + if body.client_id and body.client_id != session.client_id: + return _mcp_device_token_error('invalid_client', 'The client_id does not match this device authorization request.') + if session.expired: + await delete_device_session(redis_client, session) + return _mcp_device_token_error('expired_token', 'The device authorization request has expired.') + + now = int(time.time()) + if session.last_poll_at and now - session.last_poll_at < session.interval: + session.last_poll_at = now + session.updated_at = now + await save_device_session(redis_client, session) + return _mcp_device_token_error('slow_down', f'Poll no faster than every {session.interval} seconds.') + + session.last_poll_at = now + session.updated_at = now + await save_device_session(redis_client, session) + + if session.status == 'pending': + return _mcp_device_token_error('authorization_pending', 'The end user has not approved this device yet.') + if session.status == 'denied': + await delete_device_session(redis_client, session) + description = session.denied_reason or 'The end user denied the device authorization request.' + return _mcp_device_token_error('access_denied', description) + if session.status != 'approved' or not session.user_id or not session.parent_session_hash: + await delete_device_session(redis_client, session) + return _mcp_device_token_error('invalid_grant', 'The device authorization request is in an invalid state.') + + login_user = SimpleNamespace(user_id=session.user_id, user_name=session.user_name) + _, token_payload = create_mcp_access_token_from_session_hash( + login_user, + session.parent_session_hash, + scopes=session.scopes, + ) + await delete_device_session(redis_client, session) + token_payload['scope'] = ' '.join(session.scopes) + token_payload['mcp_url'] = '/mcp' + return JSONResponse(content=token_payload) + + @router.get('/user/admin') async def get_admins(login_user: LoginUser = Depends(LoginUser.get_login_user)): """ diff --git a/src/backend/bisheng/workflow/authoring/__init__.py b/src/backend/bisheng/workflow/authoring/__init__.py new file mode 100644 index 0000000000..2ce28ae081 --- /dev/null +++ b/src/backend/bisheng/workflow/authoring/__init__.py @@ -0,0 +1,39 @@ +from bisheng.workflow.authoring.contract import ( + WORKFLOW_AUTHORING_SCHEMA_VERSION, + NodeTemplateDescriptor, + NodeTypeDescriptor, + ValidationDiagnostic, + ValidationSeverity, + WorkflowGraphDescriptor, + WorkflowGraphNodeDescriptor, + WorkflowManifest, + WorkflowParamGroupDescriptor, + WorkflowParamMetadata, + WorkflowTabDescriptor, + WorkflowTabOption, + WorkflowVersionSummary, +) +from bisheng.workflow.authoring.registry import ( + get_node_template_descriptor, + list_node_type_descriptors, + normalize_tab_descriptor, +) + +__all__ = [ + 'WORKFLOW_AUTHORING_SCHEMA_VERSION', + 'NodeTemplateDescriptor', + 'NodeTypeDescriptor', + 'ValidationDiagnostic', + 'ValidationSeverity', + 'WorkflowGraphDescriptor', + 'WorkflowGraphNodeDescriptor', + 'WorkflowManifest', + 'WorkflowParamGroupDescriptor', + 'WorkflowParamMetadata', + 'WorkflowTabDescriptor', + 'WorkflowTabOption', + 'WorkflowVersionSummary', + 'get_node_template_descriptor', + 'list_node_type_descriptors', + 'normalize_tab_descriptor', +] diff --git a/src/backend/bisheng/workflow/authoring/contract.py b/src/backend/bisheng/workflow/authoring/contract.py new file mode 100644 index 0000000000..e95d721d42 --- /dev/null +++ b/src/backend/bisheng/workflow/authoring/contract.py @@ -0,0 +1,116 @@ +from datetime import datetime +from enum import Enum +from typing import Optional, Any + +from pydantic import BaseModel, Field + + +WORKFLOW_AUTHORING_SCHEMA_VERSION = 'workflow-authoring.v1' + + +class ValidationSeverity(str, Enum): + ERROR = 'error' + WARNING = 'warning' + + +class WorkflowTabOption(BaseModel): + key: str = '' + label: str = '' + help: Optional[str] = None + + +class WorkflowTabDescriptor(BaseModel): + value: Optional[str] = None + options: list[WorkflowTabOption] = Field(default_factory=list) + + +class WorkflowParamMetadata(BaseModel): + display_name: str = '' + group_name: str = '' + type: Optional[str] = None + required: bool = False + show: bool = True + options: Optional[Any] = None + scope: Optional[Any] = None + placeholder: Optional[str] = None + refresh: bool = False + value: Optional[Any] = None + + +class WorkflowParamGroupDescriptor(BaseModel): + name: str = '' + group_key: Optional[str] = None + param_keys: list[str] = Field(default_factory=list) + + +class WorkflowManifest(BaseModel): + flow_id: str + name: str = '' + description: Optional[str] = None + status: str = 'offline' + current_version_id: Optional[int] = None + editable_version_id: Optional[int] = None + draft_revision: int = 0 + schema_version: str = WORKFLOW_AUTHORING_SCHEMA_VERSION + + +class WorkflowVersionSummary(BaseModel): + version_id: int + name: str = '' + description: Optional[str] = None + is_current: bool = False + is_editable: bool = False + is_external_draft: bool = False + original_version_id: Optional[int] = None + draft_revision: int = 0 + schema_version: str = WORKFLOW_AUTHORING_SCHEMA_VERSION + create_time: Optional[datetime] = None + update_time: Optional[datetime] = None + + +class WorkflowGraphNodeDescriptor(BaseModel): + id: str + type: str = '' + name: str = '' + description: Optional[str] = None + tab: Optional[WorkflowTabDescriptor] = None + param_keys: list[str] = Field(default_factory=list) + params: dict[str, WorkflowParamMetadata] = Field(default_factory=dict) + + +class WorkflowGraphDescriptor(BaseModel): + flow_id: str + version_id: int + draft_revision: int = 0 + schema_version: str = WORKFLOW_AUTHORING_SCHEMA_VERSION + nodes: list[WorkflowGraphNodeDescriptor] = Field(default_factory=list) + edges: list[dict[str, Any]] = Field(default_factory=list) + + +class NodeTypeDescriptor(BaseModel): + type: str + display_name: str = '' + description: str = '' + param_keys: list[str] = Field(default_factory=list) + dynamic_template: bool = False + schema_version: str = WORKFLOW_AUTHORING_SCHEMA_VERSION + + +class NodeTemplateDescriptor(BaseModel): + node_type: str + display_name: str = '' + description: str = '' + tab: Optional[WorkflowTabDescriptor] = None + groups: list[WorkflowParamGroupDescriptor] = Field(default_factory=list) + params: dict[str, WorkflowParamMetadata] = Field(default_factory=dict) + dynamic_template: bool = False + schema_version: str = WORKFLOW_AUTHORING_SCHEMA_VERSION + + +class ValidationDiagnostic(BaseModel): + code: str = '' + severity: ValidationSeverity = ValidationSeverity.ERROR + message: str + node_id: Optional[str] = None + field_path: Optional[str] = None + suggested_fix: Optional[str] = None diff --git a/src/backend/bisheng/workflow/authoring/editor_compat.py b/src/backend/bisheng/workflow/authoring/editor_compat.py new file mode 100644 index 0000000000..78b8d4b118 --- /dev/null +++ b/src/backend/bisheng/workflow/authoring/editor_compat.py @@ -0,0 +1,164 @@ +import copy +from typing import Any, Optional + + +_EDITOR_FLOW_NODE_TYPE = 'flowNode' +_EDITOR_NOTE_NODE_TYPE = 'noteNode' +_CONDITION_NODE_TYPE = 'condition' +_CONDITION_PARAM_KEY = 'condition' +_EDITOR_CONDITION_RIGHT_VALUE_TYPES = {'input', 'ref'} + + +def _iter_variable_refs(node_data: dict): + node_id = str(node_data.get('id', '') or '') + if not node_id: + return + + node_name = str(node_data.get('name', '') or node_id) + for group in node_data.get('group_params', []) or []: + if not isinstance(group, dict): + continue + for param in group.get('params', []) or []: + if not isinstance(param, dict): + continue + + param_key = param.get('key') + if not isinstance(param_key, str) or not param_key: + continue + + param_global = param.get('global') + param_value = param.get('value') + param_label = str(param.get('label') or param_key) + + if isinstance(param_global, str) and param_global.startswith('code:') and isinstance(param_value, list): + for item in param_value: + if not isinstance(item, dict): + continue + item_key = item.get('value') or item.get('key') + if not item_key: + continue + item_label = str(item.get('label') or item.get('key') or item_key) + yield f'{node_id}.{item_key}', f'{node_name}/{item_label}' + continue + + if param_global in {'key', 'self'}: + yield f'{node_id}.{param_key}', f'{node_name}/{param_label}' + continue + + if param_global == 'item:form_input' and isinstance(param_value, list): + for item in param_value: + if not isinstance(item, dict): + continue + base_label = str(item.get('label') or item.get('key') or '') + for sub_key_name in ('key', 'file_content', 'file_path', 'image_file'): + sub_key = item.get(sub_key_name) + if not sub_key: + continue + suffix = base_label or str(sub_key) + yield f'{node_id}.{sub_key}', f'{node_name}/{suffix}' + + +def _build_variable_label_map(nodes: list[dict]) -> dict[str, str]: + variable_labels: dict[str, str] = {} + for node in nodes: + if not isinstance(node, dict): + continue + node_data = node.get('data') + if not isinstance(node_data, dict): + continue + for ref, label in _iter_variable_refs(node_data): + variable_labels[ref] = label + return variable_labels + + +def _normalize_editor_condition_cases(condition_cases: Any) -> Any: + if not isinstance(condition_cases, list): + return condition_cases + + normalized_cases = [] + for raw_case in condition_cases: + if not isinstance(raw_case, dict): + normalized_cases.append(raw_case) + continue + + case = copy.deepcopy(raw_case) + conditions = case.get('conditions') + if isinstance(conditions, list): + normalized_conditions = [] + for raw_condition in conditions: + if not isinstance(raw_condition, dict): + normalized_conditions.append(raw_condition) + continue + + condition = copy.deepcopy(raw_condition) + condition.setdefault('left_label', '') + condition.setdefault('right_label', '') + if condition.get('left_var') is None: + condition['left_var'] = '' + if condition.get('right_value') is None: + condition['right_value'] = '' + if condition.get('comparison_operation') is None: + condition['comparison_operation'] = '' + + right_value_type = condition.get('right_value_type') + if right_value_type not in _EDITOR_CONDITION_RIGHT_VALUE_TYPES: + condition['right_value_type'] = 'ref' if right_value_type == 'ref' else 'input' + + normalized_conditions.append(condition) + case['conditions'] = normalized_conditions + + normalized_cases.append(case) + + return normalized_cases + + +def normalize_workflow_editor_graph(graph_data: Optional[dict], *, in_place: bool = False) -> Optional[dict]: + if not isinstance(graph_data, dict): + return graph_data + + nodes = graph_data.get('nodes') + if not isinstance(nodes, list): + return graph_data + + normalized_graph = graph_data if in_place else copy.deepcopy(graph_data) + variable_labels = _build_variable_label_map(normalized_graph.get('nodes', [])) + for node in normalized_graph.get('nodes', []): + if not isinstance(node, dict): + continue + + node_data = node.get('data') + if not isinstance(node_data, dict): + continue + + node_type = node_data.get('type') + if node_type: + if node.get('type') not in {_EDITOR_FLOW_NODE_TYPE, _EDITOR_NOTE_NODE_TYPE}: + node['type'] = _EDITOR_NOTE_NODE_TYPE if node_type == 'note' else _EDITOR_FLOW_NODE_TYPE + + if node_type == _CONDITION_NODE_TYPE: + for group in node_data.get('group_params', []): + if not isinstance(group, dict): + continue + for param in group.get('params', []): + if not isinstance(param, dict): + continue + if param.get('key') == _CONDITION_PARAM_KEY: + param['value'] = _normalize_editor_condition_cases(param.get('value') or []) + for case in param['value'] or []: + if not isinstance(case, dict): + continue + for condition in case.get('conditions') or []: + if not isinstance(condition, dict): + continue + if not condition.get('left_label'): + condition['left_label'] = variable_labels.get(condition.get('left_var') or '', '') + if ( + condition.get('right_value_type') == 'ref' + and not condition.get('right_label') + ): + condition['right_label'] = variable_labels.get( + condition.get('right_value') or '', + '', + ) + + return normalized_graph diff --git a/src/backend/bisheng/workflow/authoring/registry.py b/src/backend/bisheng/workflow/authoring/registry.py new file mode 100644 index 0000000000..d84b6d5c50 --- /dev/null +++ b/src/backend/bisheng/workflow/authoring/registry.py @@ -0,0 +1,1021 @@ +import copy +from typing import Optional + +from bisheng.workflow.authoring.contract import ( + NodeTemplateDescriptor, + NodeTypeDescriptor, + WorkflowParamGroupDescriptor, + WorkflowParamMetadata, + WorkflowTabDescriptor, + WorkflowTabOption, +) + + +_NODE_DISPLAY_NAMES = { + 'start': 'Start', + 'input': 'Input', + 'output': 'Output', + 'llm': 'LLM', + 'agent': 'Agent', + 'qa_retriever': 'QA Retriever', + 'rag': 'RAG', + 'knowledge_retriever': 'Knowledge Retriever', + 'report': 'Report', + 'code': 'Code', + 'condition': 'Condition', + 'end': 'End', + 'tool': 'Tool', +} + +_NODE_DESCRIPTIONS = { + 'tool': 'Dynamic tool node. Param schema depends on the selected tool.', +} + +_WORKFLOW_NODE_TEMPLATES = [ + { + 'id': 'start_xxx', + 'name': '', + 'description': '', + 'type': 'start', + 'v': '3', + 'group_params': [ + { + 'name': '开场引导', + 'params': [ + { + 'key': 'guide_word', + 'label': 'Guide Word', + 'value': '', + 'type': 'textarea', + 'placeholder': '', + }, + { + 'key': 'guide_question', + 'label': 'Guide Questions', + 'value': [], + 'type': 'input_list', + 'placeholder': '', + 'help': '', + }, + ], + }, + { + 'name': '全局变量', + 'params': [ + { + 'key': 'user_info', + 'global': 'key', + 'label': 'User Info', + 'type': 'var', + 'value': '', + }, + { + 'key': 'current_time', + 'global': 'key', + 'label': 'Current Time', + 'type': 'var', + 'value': '', + }, + { + 'key': 'chat_history', + 'global': 'key', + 'label': 'Chat History', + 'type': 'chat_history_num', + 'value': 10, + }, + { + 'key': 'preset_question', + 'label': 'Preset Questions', + 'global': 'item:input_list', + 'type': 'input_list', + 'value': [], + 'placeholder': '', + 'help': '', + }, + { + 'key': 'custom_variables', + 'label': 'Custom Variables', + 'global': 'item:input_list', + 'type': 'global_var', + 'value': [], + 'help': '', + }, + ], + }, + ], + }, + { + 'id': 'input_xxx', + 'name': '', + 'description': '', + 'type': 'input', + 'v': '3', + 'tab': { + 'value': 'dialog_input', + 'options': [ + { + 'label': 'Dialog Input', + 'key': 'dialog_input', + 'help': '', + }, + { + 'label': 'Form Input', + 'key': 'form_input', + 'help': '', + }, + ], + }, + 'group_params': [ + { + 'name': '接收文本', + 'params': [ + { + 'key': 'user_input', + 'global': 'key', + 'label': 'User Input', + 'type': 'var', + 'tab': 'dialog_input', + }, + ], + }, + { + 'name': '', + 'groupKey': 'inputfile', + 'params': [ + { + 'groupTitle': True, + 'key': 'user_input_file', + 'tab': 'dialog_input', + 'value': True, + }, + { + 'key': 'file_parse_mode', + 'type': 'select_parsemode', + 'tab': 'dialog_input', + 'value': 'extract_text', + }, + { + 'key': 'dialog_files_content', + 'global': 'key', + 'label': 'Dialog Files Content', + 'type': 'var', + 'tab': 'dialog_input', + }, + { + 'key': 'dialog_files_content_size', + 'label': 'Dialog Files Content Size', + 'type': 'char_number', + 'min': 0, + 'value': 15000, + 'tab': 'dialog_input', + }, + { + 'key': 'dialog_file_accept', + 'label': 'Dialog File Accept', + 'type': 'select_fileaccept', + 'value': 'all', + 'tab': 'dialog_input', + }, + { + 'key': 'dialog_image_files', + 'global': 'key', + 'label': 'Dialog Image Files', + 'type': 'var', + 'tab': 'dialog_input', + 'help': '', + }, + { + 'key': 'dialog_file_paths', + 'global': 'key', + 'label': 'Dialog File Paths', + 'type': 'var', + 'tab': 'dialog_input', + 'help': '', + }, + ], + }, + { + 'name': '', + 'groupKey': 'custom', + 'params': [ + { + 'groupTitle': True, + 'key': 'recommended_questions_flag', + 'label': 'Recommended Questions Flag', + 'hidden': True, + 'tab': 'dialog_input', + 'help': '', + 'value': False, + }, + { + 'key': 'recommended_llm', + 'label': 'Recommended LLM', + 'type': 'bisheng_model', + 'tab': 'dialog_input', + 'value': '', + 'placeholder': '', + 'required': True, + }, + { + 'key': 'recommended_system_prompt', + 'label': 'Recommended System Prompt', + 'tab': 'dialog_input', + 'type': 'var_textarea', + 'value': '', + 'required': True, + }, + { + 'key': 'recommended_history_num', + 'label': 'Recommended History Num', + 'type': 'slide', + 'tab': 'dialog_input', + 'help': '', + 'scope': [1, 10], + 'step': 1, + 'value': 2, + }, + ], + }, + { + 'name': '', + 'params': [ + { + 'key': 'form_input', + 'global': 'item:form_input', + 'label': 'Form Input', + 'type': 'form', + 'tab': 'form_input', + 'value': [], + }, + ], + }, + ], + }, + { + 'id': 'output_xxx', + 'name': '', + 'description': '', + 'type': 'output', + 'v': '2', + 'group_params': [ + { + 'params': [ + { + 'key': 'message', + 'label': 'Message', + 'global': 'key', + 'type': 'var_textarea_file', + 'required': True, + 'placeholder': '', + 'value': { + 'msg': '', + 'files': [], + }, + }, + { + 'key': 'output_result', + 'label': 'Output Result', + 'global': 'value.type=input', + 'type': 'output_form', + 'required': True, + 'value': { + 'type': '', + 'value': '', + }, + 'options': [], + }, + ], + }, + ], + }, + { + 'id': 'llm_xxx', + 'name': '', + 'description': '', + 'type': 'llm', + 'v': '2', + 'tab': { + 'value': 'single', + 'options': [ + {'label': 'Single', 'key': 'single'}, + {'label': 'Batch', 'key': 'batch'}, + ], + }, + 'group_params': [ + { + 'params': [ + { + 'key': 'batch_variable', + 'label': 'Batch Variable', + 'global': 'self', + 'type': 'user_question', + 'test': 'var', + 'value': [], + 'required': True, + 'linkage': 'output', + 'placeholder': '', + 'help': '', + 'tab': 'batch', + }, + ], + }, + { + 'name': '模型设置', + 'params': [ + { + 'key': 'model_id', + 'label': 'Model ID', + 'type': 'bisheng_model', + 'value': '', + 'required': True, + 'placeholder': '', + }, + { + 'key': 'temperature', + 'label': 'Temperature', + 'type': 'slide', + 'scope': [0, 2], + 'step': 0.1, + 'value': 0.7, + }, + ], + }, + { + 'name': '提示词', + 'params': [ + { + 'key': 'system_prompt', + 'label': 'System Prompt', + 'type': 'var_textarea', + 'test': 'var', + 'value': '', + }, + { + 'key': 'user_prompt', + 'label': 'User Prompt', + 'type': 'var_textarea', + 'test': 'var', + 'value': '', + 'required': True, + }, + { + 'key': 'image_prompt', + 'label': 'Image Prompt', + 'type': 'image_prompt', + 'value': [], + 'help': '', + }, + ], + }, + { + 'name': '输出', + 'params': [ + { + 'key': 'output_user', + 'label': 'Output To User', + 'type': 'switch', + 'help': '', + 'value': True, + }, + { + 'key': 'output', + 'global': 'code:value.map(el => ({ label: el.label, value: el.key }))', + 'label': 'Output Variable', + 'help': '', + 'type': 'var', + 'value': [], + }, + ], + }, + ], + }, + { + 'id': 'agent_xxx', + 'name': '', + 'description': '', + 'type': 'agent', + 'v': '2', + 'tab': { + 'value': 'single', + 'options': [ + {'label': 'Single', 'key': 'single'}, + {'label': 'Batch', 'key': 'batch'}, + ], + }, + 'group_params': [ + { + 'params': [ + { + 'key': 'batch_variable', + 'label': 'Batch Variable', + 'required': True, + 'type': 'user_question', + 'test': 'var', + 'global': 'self', + 'value': [], + 'linkage': 'output', + 'placeholder': '', + 'tab': 'batch', + 'help': '', + }, + ], + }, + { + 'name': '模型设置', + 'params': [ + { + 'key': 'model_id', + 'label': 'Model ID', + 'type': 'agent_model', + 'required': True, + 'value': '', + 'placeholder': '', + }, + { + 'key': 'temperature', + 'label': 'Temperature', + 'type': 'slide', + 'scope': [0, 2], + 'step': 0.1, + 'value': 0.7, + }, + ], + }, + { + 'name': '提示词', + 'params': [ + { + 'key': 'system_prompt', + 'label': 'System Prompt', + 'type': 'var_textarea', + 'test': 'var', + 'value': '', + 'placeholder': '', + 'required': True, + }, + { + 'key': 'user_prompt', + 'label': 'User Prompt', + 'type': 'var_textarea', + 'test': 'var', + 'value': '', + 'placeholder': '', + 'required': True, + }, + { + 'key': 'chat_history_flag', + 'label': 'Chat History', + 'type': 'slide_switch', + 'scope': [0, 100], + 'step': 1, + 'value': { + 'flag': True, + 'value': 50, + }, + 'help': '', + }, + { + 'key': 'image_prompt', + 'label': 'Image Prompt', + 'type': 'image_prompt', + 'value': '', + 'help': '', + }, + ], + }, + { + 'name': '知识库', + 'params': [ + { + 'key': 'knowledge_id', + 'label': 'Knowledge ID', + 'type': 'knowledge_select_multi', + 'placeholder': '', + 'value': { + 'type': 'knowledge', + 'value': [], + }, + }, + ], + }, + { + 'name': '数据库', + 'params': [ + { + 'key': 'sql_agent', + 'type': 'sql_config', + 'value': { + 'open': False, + 'db_address': '', + 'db_name': '', + 'db_username': '', + 'db_password': '', + }, + }, + ], + }, + { + 'name': '工具', + 'params': [ + { + 'key': 'tool_list', + 'label': 'Tool List', + 'type': 'add_tool', + 'value': [], + }, + ], + }, + { + 'name': '输出', + 'params': [ + { + 'key': 'output_user', + 'label': 'Output To User', + 'type': 'switch', + 'help': '', + 'value': True, + }, + { + 'key': 'output', + 'global': 'code:value.map(el => ({ label: el.label, value: el.key }))', + 'label': 'Output Variable', + 'type': 'var', + 'help': '', + 'value': [], + }, + ], + }, + ], + }, + { + 'id': 'qa_retriever_xxx', + 'name': '', + 'description': '', + 'type': 'qa_retriever', + 'v': '1', + 'group_params': [ + { + 'name': '检索设置', + 'params': [ + { + 'key': 'user_question', + 'label': 'User Question', + 'type': 'var_select', + 'test': 'var', + 'value': '', + 'required': True, + 'placeholder': '', + }, + { + 'key': 'qa_knowledge_id', + 'label': 'QA Knowledge ID', + 'type': 'qa_select_multi', + 'value': [], + 'required': True, + 'placeholder': '', + }, + { + 'key': 'score', + 'label': 'Score', + 'type': 'slide', + 'value': 0.8, + 'scope': [0.01, 0.99], + 'step': 0.01, + 'help': '', + }, + ], + }, + { + 'name': '输出', + 'params': [ + { + 'key': 'retrieved_result', + 'label': 'Retrieved Result', + 'type': 'var', + 'global': 'key', + 'value': '', + }, + ], + }, + ], + }, + { + 'id': 'rag_xxx', + 'name': '', + 'description': '', + 'type': 'rag', + 'v': '2', + 'group_params': [ + { + 'name': '知识库检索设置', + 'params': [ + { + 'key': 'user_question', + 'label': 'User Question', + 'global': 'self=user_prompt', + 'type': 'user_question', + 'test': 'var', + 'help': '', + 'linkage': 'output_user_input', + 'value': [], + 'placeholder': '', + 'required': True, + }, + { + 'key': 'knowledge', + 'label': 'Knowledge', + 'type': 'knowledge_select_multi', + 'placeholder': '', + 'value': { + 'type': 'knowledge', + 'value': [], + }, + 'required': True, + }, + { + 'key': 'metadata_filter', + 'label': 'Metadata Filter', + 'type': 'metadata_filter', + 'value': {}, + }, + { + 'key': 'advanced_retrieval_switch', + 'label': 'Advanced Retrieval Switch', + 'type': 'search_switch', + 'value': {}, + }, + { + 'key': 'retrieved_result', + 'label': 'Retrieved Result', + 'type': 'var', + 'global': 'self=user_prompt', + }, + ], + }, + { + 'name': 'AI回复生成设置', + 'params': [ + { + 'key': 'system_prompt', + 'label': 'System Prompt', + 'type': 'var_textarea', + 'value': '', + 'required': True, + }, + { + 'key': 'user_prompt', + 'label': 'User Prompt', + 'type': 'var_textarea', + 'value': '', + 'test': 'var', + 'required': True, + }, + { + 'key': 'model_id', + 'label': 'Model ID', + 'type': 'bisheng_model', + 'value': '', + 'required': True, + 'placeholder': '', + }, + { + 'key': 'temperature', + 'label': 'Temperature', + 'type': 'slide', + 'scope': [0, 2], + 'step': 0.1, + 'value': 0.7, + }, + ], + }, + { + 'name': '输出', + 'params': [ + { + 'key': 'output_user', + 'label': 'Output To User', + 'type': 'switch', + 'value': True, + 'help': '', + }, + { + 'key': 'output_user_input', + 'label': 'Output User Input', + 'type': 'var', + 'help': '', + 'global': 'code:value.map(el => ({ label: el.label, value: el.key }))', + 'value': [], + }, + ], + }, + ], + }, + { + 'id': 'knowledge_retriever_xxx', + 'name': '', + 'description': '', + 'type': 'knowledge_retriever', + 'v': '1', + 'group_params': [ + { + 'name': '知识库检索设置', + 'params': [ + { + 'key': 'user_question', + 'label': 'User Question', + 'global': 'self=user_prompt', + 'type': 'user_question', + 'test': 'var', + 'help': '', + 'linkage': 'retrieved_result', + 'value': [], + 'placeholder': '', + 'required': True, + }, + { + 'key': 'knowledge', + 'label': 'Knowledge', + 'type': 'knowledge_select_multi', + 'placeholder': '', + 'value': { + 'type': 'knowledge', + 'value': [], + }, + 'required': True, + }, + { + 'key': 'metadata_filter', + 'label': 'Metadata Filter', + 'type': 'metadata_filter', + 'value': {}, + }, + { + 'key': 'advanced_retrieval_switch', + 'label': 'Advanced Retrieval Switch', + 'type': 'search_switch', + 'value': {}, + }, + ], + }, + { + 'name': '输出', + 'params': [ + { + 'key': 'retrieved_result', + 'label': 'Retrieved Result', + 'type': 'var', + 'global': 'code:value.map(el => ({ label: el.label, value: el.key }))', + 'value': [], + }, + ], + }, + ], + }, + { + 'id': 'report_xxx', + 'name': '', + 'description': '', + 'type': 'report', + 'v': '1', + 'group_params': [ + { + 'params': [ + { + 'key': 'report_info', + 'label': 'Report Info', + 'placeholder': '', + 'required': True, + 'type': 'report', + 'value': {}, + }, + ], + }, + ], + }, + { + 'id': 'code_xxx', + 'name': '', + 'description': '', + 'type': 'code', + 'v': '1', + 'group_params': [ + { + 'name': '入参', + 'params': [ + { + 'key': 'code_input', + 'type': 'code_input', + 'test': 'input', + 'required': True, + 'value': [ + {'key': 'arg1', 'type': 'input', 'label': '', 'value': ''}, + {'key': 'arg2', 'type': 'input', 'label': '', 'value': ''}, + ], + }, + ], + }, + { + 'name': '执行代码', + 'params': [ + { + 'key': 'code', + 'type': 'code', + 'required': True, + 'value': "def main(arg1: str, arg2: str) -> dict: \n return {'result1': arg1, 'result2': arg2}", + }, + ], + }, + { + 'name': '出参', + 'params': [ + { + 'key': 'code_output', + 'type': 'code_output', + 'global': 'code:value.map(el => ({ label: el.key, value: el.key }))', + 'required': True, + 'value': [ + {'key': 'result1', 'type': 'str'}, + {'key': 'result2', 'type': 'str'}, + ], + }, + ], + }, + ], + }, + { + 'id': 'condition_xxx', + 'name': '', + 'description': '', + 'type': 'condition', + 'v': '1', + 'group_params': [ + { + 'params': [ + { + 'key': 'condition', + 'label': '', + 'type': 'condition', + 'value': [], + }, + ], + }, + ], + }, + { + 'id': 'end_xxx', + 'name': '', + 'description': '', + 'type': 'end', + 'v': '1', + 'group_params': [], + }, + { + 'id': 'tool_xxx', + 'name': '', + 'description': '', + 'type': 'tool', + 'v': '1', + 'group_params': [], + 'dynamic_template': True, + }, +] + +_TEMPLATE_MAP = {item['type']: item for item in _WORKFLOW_NODE_TEMPLATES} + + +def _normalize_tab(tab: Optional[dict]) -> Optional[WorkflowTabDescriptor]: + if not isinstance(tab, dict): + return None + options = [] + for option in tab.get('options', []) or []: + if not isinstance(option, dict): + continue + options.append( + WorkflowTabOption( + key=str(option.get('key', '')), + label=str(option.get('label', '')), + help=option.get('help'), + ) + ) + return WorkflowTabDescriptor(value=tab.get('value'), options=options) + + +def normalize_tab_descriptor(tab: Optional[dict]) -> Optional[WorkflowTabDescriptor]: + return _normalize_tab(tab) + + +def _normalize_param(field: dict, group_name: str) -> WorkflowParamMetadata: + key = field.get('key', '') + return WorkflowParamMetadata( + display_name=field.get('display_name') or field.get('label') or field.get('name') or key, + group_name=group_name, + type=field.get('type'), + required=field.get('required', False), + show=field.get('show', True) and not field.get('hidden', False), + options=field.get('options'), + scope=field.get('scope'), + placeholder=field.get('placeholder'), + refresh=field.get('refresh', False), + value=field.get('value'), + ) + + +def _iter_group_fields(node_template: dict): + for group in node_template.get('group_params', []) or []: + if not isinstance(group, dict): + continue + group_name = group.get('name', '') + group_key = group.get('groupKey') + param_keys = [] + fields = {} + for field in group.get('params', []) or []: + if not isinstance(field, dict): + continue + key = field.get('key') + if not key or field.get('groupTitle') is True: + continue + metadata = _normalize_param(field, group_name) + if not metadata.show: + continue + param_keys.append(key) + fields[key] = metadata + yield WorkflowParamGroupDescriptor(name=group_name, group_key=group_key, param_keys=param_keys), fields + + +def _display_name_for(node_type: str) -> str: + return _NODE_DISPLAY_NAMES.get(node_type, node_type.replace('_', ' ').title()) + + +def _description_for(node_type: str) -> str: + return _NODE_DESCRIPTIONS.get(node_type, '') + + +def get_node_template_descriptor(node_type: str) -> Optional[NodeTemplateDescriptor]: + template = _TEMPLATE_MAP.get(node_type) + if template is None: + return None + groups = [] + params = {} + for group, fields in _iter_group_fields(copy.deepcopy(template)): + groups.append(group) + params.update(fields) + return NodeTemplateDescriptor( + node_type=node_type, + display_name=_display_name_for(node_type), + description=_description_for(node_type), + tab=_normalize_tab(template.get('tab')), + groups=groups, + params=params, + dynamic_template=template.get('dynamic_template', False), + ) + + +def list_node_type_descriptors() -> list[NodeTypeDescriptor]: + descriptors = [] + for node_type in _TEMPLATE_MAP: + template = get_node_template_descriptor(node_type) + if template is None: + continue + descriptors.append( + NodeTypeDescriptor( + type=node_type, + display_name=template.display_name, + description=template.description, + param_keys=list(template.params.keys()), + dynamic_template=template.dynamic_template, + ) + ) + return descriptors + + +def get_node_template_payload(node_type: str) -> Optional[dict]: + template = _TEMPLATE_MAP.get(node_type) + if template is None: + return None + return copy.deepcopy(template) + + +def create_graph_node_payload(node_type: str, + *, + node_id: str, + name: str = '', + position_x: float = 0, + position_y: float = 0) -> Optional[dict]: + template = get_node_template_payload(node_type) + if template is None: + return None + + template['id'] = node_id + template['name'] = name or template.get('name') or _display_name_for(node_type) + template.setdefault('description', '') + template.setdefault('v', '1') + + return { + 'id': node_id, + 'type': 'noteNode' if node_type == 'note' else 'flowNode', + 'position': { + 'x': position_x, + 'y': position_y, + }, + 'data': template, + } diff --git a/src/backend/bisheng/workflow/nodes/condition/conidition_case.py b/src/backend/bisheng/workflow/nodes/condition/conidition_case.py index 49a50f78dd..700036bbd2 100644 --- a/src/backend/bisheng/workflow/nodes/condition/conidition_case.py +++ b/src/backend/bisheng/workflow/nodes/condition/conidition_case.py @@ -11,9 +11,11 @@ class ConditionOne(BaseModel): id: str = Field(..., description='Unique id for condition') left_var: str = Field(..., description='Left variable') + left_label: str = Field('', description='Left variable label for editor display') comparison_operation: str = Field(..., description='Compare type') right_value_type: str = Field(..., description='Right value type') right_value: str = Field(..., description='Right value') + right_label: str = Field('', description='Right variable label for editor display') variable_key_value: Dict = Field(default={}, description='variable key value') def evaluate(self, node_instance: BaseNode) -> bool: diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml index f7a8a4a397..b9dfef6514 100644 --- a/src/backend/pyproject.toml +++ b/src/backend/pyproject.toml @@ -84,3 +84,8 @@ dependencies = [ "json-repair>=0.55.0", "arxiv>=2.4.0" ] + +[dependency-groups] +dev = [ + "pytest>=9.0.3", +] diff --git a/src/backend/test/conftest.py b/src/backend/test/conftest.py new file mode 100644 index 0000000000..a499f54597 --- /dev/null +++ b/src/backend/test/conftest.py @@ -0,0 +1,5 @@ +import os +from pathlib import Path + + +os.environ.setdefault('config', str(Path(__file__).with_name('test_config.yaml').resolve())) diff --git a/src/backend/test/test_config.yaml b/src/backend/test/test_config.yaml new file mode 100644 index 0000000000..0967ef424b --- /dev/null +++ b/src/backend/test/test_config.yaml @@ -0,0 +1 @@ +{} diff --git a/src/backend/test/test_external_workflow_service.py b/src/backend/test/test_external_workflow_service.py new file mode 100644 index 0000000000..26e995a6ed --- /dev/null +++ b/src/backend/test/test_external_workflow_service.py @@ -0,0 +1,1250 @@ +from copy import deepcopy +from types import SimpleNamespace +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, patch + +from bisheng.api.services.flow import FlowService +from bisheng.api.services.external_workflow import ExternalWorkflowService +from bisheng.api.services.workflow import WorkFlowService +from bisheng.common.errcode.flow import WorkFlowInitError, WorkFlowVersionUpdateError, WorkflowNameExistsError +from bisheng.database.models.flow import FlowDao, FlowStatus +from bisheng.database.models.flow_version import FlowVersionDao + + +def make_graph(): + return { + 'nodes': [{ + 'id': 'node-1', + 'data': { + 'id': 'node-1', + 'type': 'llm', + 'name': 'LLM Node', + 'group_params': [{ + 'name': 'model', + 'params': [{ + 'key': 'temperature', + 'type': 'slide', + 'required': True, + 'scope': [0, 1], + 'placeholder': '0.0 ~ 1.0', + 'refresh': True, + 'options': [{'key': 0.3, 'value': 0.3}, {'key': 0.7, 'value': 0.7}], + 'value': 0.7, + }, { + 'key': 'system_prompt', + 'type': 'var_textarea', + 'value': 'hello', + }, { + 'key': 'openai_api_key', + 'type': 'str', + 'password': True, + 'value': 'secret', + }, { + 'key': 'hidden_internal', + 'type': 'str', + 'show': False, + 'value': 'hidden', + }], + }], + }, + }], + 'edges': [], + } + + +def make_condition_graph(): + return { + 'nodes': [{ + 'id': 'condition-1', + 'data': { + 'id': 'condition-1', + 'type': 'condition', + 'name': 'Condition Node', + 'group_params': [{ + 'params': [{ + 'key': 'condition', + 'type': 'condition', + 'value': [{ + 'id': 'case_a', + 'operator': 'and', + 'conditions': [{ + 'id': 'rule_1', + 'left_var': 'score', + 'comparison_operation': 'greater_than', + 'right_value_type': 'const', + 'right_value': '80', + 'variable_key_value': {}, + }], + 'variable_key_value': {}, + }], + }], + }], + }, + }, { + 'id': 'node-2', + 'data': { + 'id': 'node-2', + 'type': 'output', + 'name': 'Output Node', + 'group_params': [], + }, + }], + 'edges': [{ + 'id': 'edge-1', + 'source': 'condition-1', + 'sourceHandle': 'case_a', + 'target': 'node-2', + 'targetHandle': 'input', + }, { + 'id': 'edge-2', + 'source': 'condition-1', + 'sourceHandle': 'right_handle', + 'target': 'node-2', + 'targetHandle': 'input', + }], + } + + +def make_large_graph(): + nodes = [ + { + 'id': 'input-1', + 'data': { + 'id': 'input-1', + 'type': 'input', + 'name': 'Input Node', + 'group_params': [], + }, + }, + { + 'id': 'llm-1', + 'data': { + 'id': 'llm-1', + 'type': 'llm', + 'name': 'Planner', + 'group_params': [{ + 'name': 'model', + 'params': [{ + 'key': 'temperature', + 'type': 'slide', + 'required': True, + 'scope': [0, 1], + 'value': 0.7, + }], + }], + }, + }, + { + 'id': 'condition-1', + 'data': { + 'id': 'condition-1', + 'type': 'condition', + 'name': 'Route', + 'group_params': [{ + 'params': [{ + 'key': 'condition', + 'type': 'condition', + 'value': [{ + 'id': 'case_a', + 'operator': 'and', + 'conditions': [{ + 'id': 'rule_1', + 'left_var': 'score', + 'comparison_operation': 'greater_than', + 'right_value_type': 'const', + 'right_value': '80', + 'variable_key_value': {}, + }], + 'variable_key_value': {}, + }], + }], + }], + }, + }, + ] + for index, node_type in enumerate( + ['tool', 'tool', 'code', 'agent', 'tool', 'tool', 'output', 'output', 'output'], + start=1, + ): + node_id = f'node-{index + 3}' + nodes.append({ + 'id': node_id, + 'data': { + 'id': node_id, + 'type': node_type, + 'name': f'{node_type}-{index}', + 'group_params': [], + }, + }) + + edges = [ + {'id': 'edge-1', 'source': 'input-1', 'sourceHandle': 'output', 'target': 'llm-1', 'targetHandle': 'input'}, + {'id': 'edge-2', 'source': 'llm-1', 'sourceHandle': 'output', 'target': 'condition-1', 'targetHandle': 'input'}, + {'id': 'edge-3', 'source': 'condition-1', 'sourceHandle': 'case_a', 'target': 'node-4', 'targetHandle': 'input'}, + {'id': 'edge-4', 'source': 'condition-1', 'sourceHandle': 'right_handle', 'target': 'node-5', 'targetHandle': 'input'}, + {'id': 'edge-5', 'source': 'node-4', 'sourceHandle': 'output', 'target': 'node-6', 'targetHandle': 'input'}, + {'id': 'edge-6', 'source': 'node-5', 'sourceHandle': 'output', 'target': 'node-7', 'targetHandle': 'input'}, + {'id': 'edge-7', 'source': 'node-6', 'sourceHandle': 'output', 'target': 'node-8', 'targetHandle': 'input'}, + {'id': 'edge-8', 'source': 'node-7', 'sourceHandle': 'output', 'target': 'node-9', 'targetHandle': 'input'}, + {'id': 'edge-9', 'source': 'node-8', 'sourceHandle': 'output', 'target': 'node-10', 'targetHandle': 'input'}, + {'id': 'edge-10', 'source': 'node-9', 'sourceHandle': 'output', 'target': 'node-11', 'targetHandle': 'input'}, + {'id': 'edge-11', 'source': 'node-10', 'sourceHandle': 'output', 'target': 'node-12', 'targetHandle': 'input'}, + ] + return {'nodes': nodes, 'edges': edges} + + +class TestExternalWorkflowService(IsolatedAsyncioTestCase): + def test_ensure_create_graph_scaffold_builds_minimal_start_end_graph(self): + graph = ExternalWorkflowService._ensure_create_graph_scaffold({'nodes': [], 'edges': []}) + + self.assertEqual([node['data']['type'] for node in graph['nodes']], ['start', 'end']) + self.assertEqual([node['type'] for node in graph['nodes']], ['flowNode', 'flowNode']) + self.assertEqual(len(graph['edges']), 1) + self.assertEqual(graph['edges'][0]['source'], graph['nodes'][0]['id']) + self.assertEqual(graph['edges'][0]['target'], graph['nodes'][1]['id']) + self.assertEqual(graph['edges'][0]['sourceHandle'], 'right_handle') + self.assertEqual(graph['edges'][0]['targetHandle'], 'left_handle') + + def test_ensure_create_graph_scaffold_wraps_initial_node_with_start_and_end(self): + graph = ExternalWorkflowService._ensure_create_graph_scaffold({ + 'nodes': [{ + 'id': 'input-1', + 'position': {'x': 240, 'y': 32}, + 'data': { + 'id': 'input-1', + 'type': 'input', + 'name': 'Input Node', + 'group_params': [], + }, + }], + 'edges': [], + }) + + node_types = {node['id']: node['data']['type'] for node in graph['nodes']} + start_id = next(node_id for node_id, node_type in node_types.items() if node_type == 'start') + end_id = next(node_id for node_id, node_type in node_types.items() if node_type == 'end') + + self.assertEqual(len(graph['nodes']), 3) + self.assertEqual([node['type'] for node in graph['nodes']], ['flowNode', 'flowNode', 'flowNode']) + self.assertEqual(len(graph['edges']), 2) + self.assertTrue(any(edge['source'] == start_id and edge['target'] == 'input-1' for edge in graph['edges'])) + self.assertTrue(any(edge['source'] == 'input-1' and edge['target'] == end_id for edge in graph['edges'])) + + def test_ensure_create_graph_scaffold_adds_condition_routes_to_end(self): + graph = ExternalWorkflowService._ensure_create_graph_scaffold({ + 'nodes': [{ + 'id': 'condition-1', + 'position': {'x': 120, 'y': 48}, + 'data': { + 'id': 'condition-1', + 'type': 'condition', + 'name': 'Condition Node', + 'group_params': [{ + 'params': [{ + 'key': 'condition', + 'type': 'condition', + 'value': [{ + 'id': 'case_a', + 'operator': 'and', + 'conditions': [], + 'variable_key_value': {}, + }], + }], + }], + }, + }], + 'edges': [], + }) + + node_types = {node['id']: node['data']['type'] for node in graph['nodes']} + start_id = next(node_id for node_id, node_type in node_types.items() if node_type == 'start') + end_id = next(node_id for node_id, node_type in node_types.items() if node_type == 'end') + + self.assertEqual([node['type'] for node in graph['nodes']], ['flowNode', 'flowNode', 'flowNode']) + self.assertTrue(any(edge['source'] == start_id and edge['target'] == 'condition-1' for edge in graph['edges'])) + self.assertTrue( + any( + edge['source'] == 'condition-1' and edge['target'] == end_id and edge['sourceHandle'] == 'case_a' + for edge in graph['edges'] + ) + ) + self.assertTrue( + any( + edge['source'] == 'condition-1' and edge['target'] == end_id and edge['sourceHandle'] == 'right_handle' + for edge in graph['edges'] + ) + ) + + def test_create_workflow_draft_sync_accepts_normalized_graph_descriptor(self): + captured = {} + + def fake_validate(login_user, graph_data, flow_name, flow_id=None): + captured['graph_data'] = deepcopy(graph_data) + + def fake_create_flow(flow_info, flow_type): + flow_info.id = 'flow-1' + return flow_info + + def fake_get_current_version(flow_id): + return SimpleNamespace(id=11, data=deepcopy(captured['graph_data'])) + + with patch.object(ExternalWorkflowService, '_assert_workflow_name_available'), \ + patch.object(ExternalWorkflowService, '_validate_draft_graph', side_effect=fake_validate), \ + patch.object(FlowDao, 'create_flow', side_effect=fake_create_flow), \ + patch.object(FlowVersionDao, 'get_version_by_flow', side_effect=fake_get_current_version), \ + patch.object(FlowVersionDao, 'update_version', side_effect=lambda version: version), \ + patch.object(FlowService, 'create_flow_hook'): + flow, version = ExternalWorkflowService._create_workflow_draft_sync( + login_user=SimpleNamespace(user_id=1), + name='demo', + graph_data={ + 'nodes': [{ + 'id': 'input-1', + 'type': 'input', + 'name': 'Ticket Input', + 'params': {}, + }], + 'edges': [], + }, + ) + + self.assertEqual(flow.id, 'flow-1') + self.assertEqual(version.id, 11) + node_ids = {node['id'] for node in captured['graph_data']['nodes']} + self.assertIn('input-1', node_ids) + input_node = next(node for node in captured['graph_data']['nodes'] if node['id'] == 'input-1') + self.assertEqual(input_node['type'], 'flowNode') + self.assertEqual(input_node['data']['type'], 'input') + self.assertEqual(input_node['data']['name'], 'Ticket Input') + self.assertEqual(input_node['position'], {'x': 0.0, 'y': 0.0}) + + async def test_update_workflow_draft_accepts_normalized_graph_descriptor(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value, description='', guide_word='') + version_graph = make_graph() + version_graph['nodes'][0]['type'] = 'flowNode' + version_graph['nodes'][0]['position'] = {'x': 128, 'y': 64} + version = SimpleNamespace(id=11, data=version_graph, is_current=1) + validate_calls = [] + persisted = [] + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + def fake_validate(login_user, graph_data, flow_name, flow_id=None): + validate_calls.append(deepcopy(graph_data)) + + def fake_update_version(version_info): + persisted.append(deepcopy(version_info.data)) + return version_info + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version), \ + patch.object(ExternalWorkflowService, '_validate_draft_graph', side_effect=fake_validate), \ + patch.object(FlowVersionDao, 'update_version', side_effect=fake_update_version): + _, updated_version = await ExternalWorkflowService.update_workflow_draft( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + graph_data={ + 'nodes': [{ + 'id': 'node-1', + 'type': 'llm', + 'name': 'Router', + 'params': { + 'temperature': {'value': 0.3}, + 'system_prompt': {'value': 'route tickets'}, + }, + }], + 'edges': [], + }, + ) + + self.assertEqual(updated_version.id, 11) + self.assertEqual(len(validate_calls), 1) + rebuilt_node = validate_calls[0]['nodes'][0] + self.assertEqual(rebuilt_node['type'], 'flowNode') + self.assertEqual(rebuilt_node['position'], {'x': 128.0, 'y': 64.0}) + self.assertEqual(rebuilt_node['data']['name'], 'Router') + params = rebuilt_node['data']['group_params'][0]['params'] + self.assertEqual(params[0]['value'], 0.3) + self.assertEqual(params[1]['value'], 'route tickets') + self.assertEqual(params[2]['value'], 'secret') + self.assertEqual(len(persisted), 1) + + async def test_create_workflow_draft_uses_async_thread_wrapper(self): + expected_flow = SimpleNamespace(id='flow-1') + expected_version = SimpleNamespace(id=11, data=make_graph()) + + async_to_thread = AsyncMock(return_value=(expected_flow, expected_version)) + with patch('bisheng.api.services.external_workflow.asyncio.to_thread', async_to_thread): + flow, version = await ExternalWorkflowService.create_workflow_draft( + login_user=SimpleNamespace(user_id=1), + name='demo', + graph_data=make_graph(), + ) + + self.assertEqual(flow.id, 'flow-1') + self.assertEqual(version.id, 11) + args = async_to_thread.await_args.args + self.assertIs(args[0].__func__, ExternalWorkflowService._create_workflow_draft_sync.__func__) + self.assertEqual(args[2], 'demo') + + def test_create_workflow_draft_sync_validates_scaffolded_initial_graph(self): + captured = {} + + def fake_validate(login_user, graph_data, flow_name, flow_id=None): + captured['graph_data'] = deepcopy(graph_data) + + def fake_create_flow(flow_info, flow_type): + flow_info.id = 'flow-1' + return flow_info + + def fake_get_current_version(flow_id): + return SimpleNamespace(id=11, data=deepcopy(captured['graph_data'])) + + with patch.object(ExternalWorkflowService, '_assert_workflow_name_available'), \ + patch.object(ExternalWorkflowService, '_validate_draft_graph', side_effect=fake_validate), \ + patch.object(FlowDao, 'create_flow', side_effect=fake_create_flow), \ + patch.object(FlowVersionDao, 'get_version_by_flow', side_effect=fake_get_current_version), \ + patch.object(FlowVersionDao, 'update_version', side_effect=lambda version: version), \ + patch.object(FlowService, 'create_flow_hook'): + flow, version = ExternalWorkflowService._create_workflow_draft_sync( + login_user=SimpleNamespace(user_id=1), + name='demo', + graph_data={'nodes': [], 'edges': []}, + ) + + self.assertEqual(flow.id, 'flow-1') + self.assertEqual(version.id, 11) + scaffold_types = [node['data']['type'] for node in captured['graph_data']['nodes']] + self.assertEqual(scaffold_types, ['start', 'end']) + self.assertEqual([node['type'] for node in captured['graph_data']['nodes']], ['flowNode', 'flowNode']) + self.assertEqual(len(captured['graph_data']['edges']), 1) + + def test_normalize_workflow_editor_graph_rewrites_legacy_node_types(self): + graph = { + 'nodes': [{ + 'id': 'start-1', + 'type': 'start', + 'position': {'x': 0, 'y': 0}, + 'data': { + 'id': 'start-1', + 'type': 'start', + 'name': 'Start', + 'group_params': [], + }, + }, { + 'id': 'end-1', + 'type': 'end', + 'position': {'x': 320, 'y': 0}, + 'data': { + 'id': 'end-1', + 'type': 'end', + 'name': 'End', + 'group_params': [], + }, + }, { + 'id': 'condition-1', + 'type': 'condition', + 'position': {'x': 160, 'y': 0}, + 'data': { + 'id': 'condition-1', + 'type': 'condition', + 'name': 'Condition', + 'group_params': [{ + 'params': [{ + 'key': 'condition', + 'type': 'condition', + 'value': [{ + 'id': 'case_a', + 'operator': 'and', + 'conditions': [{ + 'id': 'rule_1', + 'left_var': 'code_1.score', + 'comparison_operation': 'greater_than', + 'right_value_type': 'const', + 'right_value': '80', + }], + }], + }], + }], + }, + }], + 'edges': [], + } + + normalized = FlowService._normalize_workflow_editor_graph(graph) + + self.assertEqual([node['type'] for node in normalized['nodes']], ['flowNode', 'flowNode', 'flowNode']) + self.assertEqual([node['data']['type'] for node in normalized['nodes']], ['start', 'end', 'condition']) + self.assertEqual([node['type'] for node in graph['nodes']], ['start', 'end', 'condition']) + condition_item = normalized['nodes'][2]['data']['group_params'][0]['params'][0]['value'][0]['conditions'][0] + self.assertEqual(condition_item['right_value_type'], 'input') + self.assertEqual(condition_item['left_label'], '') + self.assertEqual(condition_item['right_label'], '') + + def test_normalize_condition_cases_keeps_editor_friendly_shape(self): + normalized = ExternalWorkflowService._normalize_condition_cases([{ + 'id': 'case_a', + 'operator': 'and', + 'conditions': [{ + 'id': 'rule_1', + 'left_var': 'code_1.score', + 'left_label': 'Score/priority_score', + 'comparison_operation': 'greater_than_or_equal', + 'right_value_type': 'const', + 'right_value': '90', + 'right_label': '', + 'variable_key_value': {}, + }], + 'variable_key_value': {}, + }]) + + condition_item = normalized[0]['conditions'][0] + self.assertEqual(condition_item['left_label'], 'Score/priority_score') + self.assertEqual(condition_item['right_label'], '') + self.assertEqual(condition_item['right_value_type'], 'input') + + def test_normalize_workflow_editor_graph_hydrates_condition_labels(self): + graph = { + 'nodes': [{ + 'id': 'code-1', + 'type': 'flowNode', + 'position': {'x': 0, 'y': 0}, + 'data': { + 'id': 'code-1', + 'type': 'code', + 'name': 'Score Priority', + 'group_params': [{ + 'name': '出参', + 'params': [{ + 'key': 'code_output', + 'type': 'code_output', + 'global': 'code:value.map(el => ({ label: el.key, value: el.key }))', + 'value': [{'key': 'priority_score', 'type': 'str'}], + }], + }], + }, + }, { + 'id': 'condition-1', + 'type': 'flowNode', + 'position': {'x': 320, 'y': 0}, + 'data': { + 'id': 'condition-1', + 'type': 'condition', + 'name': 'Condition', + 'group_params': [{ + 'params': [{ + 'key': 'condition', + 'type': 'condition', + 'value': [{ + 'id': 'case_a', + 'operator': 'and', + 'conditions': [{ + 'id': 'rule_1', + 'left_var': 'code-1.priority_score', + 'comparison_operation': 'greater_than_or_equal', + 'right_value_type': 'const', + 'right_value': '90', + }], + }], + }], + }], + }, + }], + 'edges': [], + } + + normalized = FlowService._normalize_workflow_editor_graph(graph) + condition_item = normalized['nodes'][1]['data']['group_params'][0]['params'][0]['value'][0]['conditions'][0] + + self.assertEqual(condition_item['left_label'], 'Score Priority/priority_score') + self.assertEqual(condition_item['right_label'], '') + self.assertEqual(condition_item['right_value_type'], 'input') + + def test_get_existing_external_draft_version_limits_recent_versions(self): + captured = {'statements': []} + + class FakeResult: + def __init__(self, versions): + self._versions = versions + + def all(self): + return self._versions + + class FakeSession: + def __init__(self): + self.calls = 0 + + def exec(self, statement): + captured['statements'].append(statement) + self.calls += 1 + if self.calls == 1: + versions = [SimpleNamespace(data={'nodes': [], 'edges': []}) for _ in range( + ExternalWorkflowService._MAX_EXTERNAL_DRAFT_SCAN)] + return FakeResult(versions) + return FakeResult([ + SimpleNamespace(data=ExternalWorkflowService._mark_graph_as_draft({'nodes': [], 'edges': []})) + ]) + + class FakeContext: + def __enter__(self): + if 'session' not in captured: + captured['session'] = FakeSession() + return captured['session'] + + def __exit__(self, exc_type, exc, tb): + return False + + with patch('bisheng.api.services.external_workflow.get_sync_db_session', return_value=FakeContext()): + result = ExternalWorkflowService._get_existing_external_draft_version('flow-1') + + self.assertIsNotNone(result) + self.assertEqual(len(captured['statements']), 2) + self.assertIsNotNone(captured['statements'][0]._limit_clause) + self.assertEqual(captured['statements'][0]._limit_clause.value, ExternalWorkflowService._MAX_EXTERNAL_DRAFT_SCAN) + self.assertEqual(captured['statements'][1]._offset_clause.value, ExternalWorkflowService._MAX_EXTERNAL_DRAFT_SCAN) + + async def test_get_workflow_node_params_returns_extended_metadata(self): + version = SimpleNamespace(id=11, data=make_graph()) + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return SimpleNamespace(id=flow_id, name='demo', status=FlowStatus.OFFLINE.value), version + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version): + result = await ExternalWorkflowService.get_workflow_node_params( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='node-1', + ) + + temperature = result['params']['temperature'] + self.assertEqual(temperature['group_name'], 'model') + self.assertEqual(temperature['scope'], [0, 1]) + self.assertEqual(temperature['placeholder'], '0.0 ~ 1.0') + self.assertTrue(temperature['refresh']) + self.assertEqual(temperature['options'], [{'key': 0.3, 'value': 0.3}, {'key': 0.7, 'value': 0.7}]) + self.assertNotIn('openai_api_key', result['params']) + self.assertNotIn('hidden_internal', result['params']) + + async def test_update_workflow_node_params_revalidates_before_persist(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + version = SimpleNamespace(id=11, data=make_graph(), is_current=1) + validate_calls = [] + persisted = [] + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + def fake_validate(login_user, graph_data, flow_name, flow_id=None): + validate_calls.append((flow_name, flow_id, deepcopy(graph_data))) + + def fake_update_version(version_info): + persisted.append(deepcopy(version_info.data)) + return version_info + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version), \ + patch.object(ExternalWorkflowService, '_validate_draft_graph', side_effect=fake_validate), \ + patch.object(FlowVersionDao, 'update_version', side_effect=fake_update_version): + _, updated_version = await ExternalWorkflowService.update_workflow_node_params( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='node-1', + updates={'temperature': 0.3}, + ) + + self.assertEqual(updated_version.id, 11) + self.assertEqual(len(validate_calls), 1) + self.assertEqual(validate_calls[0][0], 'demo') + self.assertEqual(validate_calls[0][1], 'flow-1') + self.assertEqual(validate_calls[0][2]['nodes'][0]['data']['group_params'][0]['params'][0]['value'], 0.3) + self.assertEqual(len(persisted), 1) + self.assertTrue(persisted[0]['_external_workflow_meta']['draft']) + self.assertEqual(persisted[0]['nodes'][0]['data']['group_params'][0]['params'][0]['value'], 0.3) + self.assertEqual(persisted[0]['_external_workflow_meta']['revision'], 1) + + async def test_update_workflow_node_params_rejects_revision_mismatch(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + version = SimpleNamespace(id=11, data=ExternalWorkflowService._mark_graph_as_draft(make_graph()), is_current=1) + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version): + with self.assertRaises(WorkFlowVersionUpdateError): + await ExternalWorkflowService.update_workflow_node_params( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='node-1', + updates={'temperature': 0.3}, + expected_revision=999, + ) + + async def test_update_workflow_node_params_rejects_invalid_graph_before_persist(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + version = SimpleNamespace(id=11, data=make_graph(), is_current=1) + persist_calls = [] + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + def fake_validate(login_user, graph_data, flow_name, flow_id=None): + raise WorkFlowInitError(msg='invalid graph') + + def fake_update_version(version_info): + persist_calls.append(version_info) + return version_info + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version), \ + patch.object(ExternalWorkflowService, '_validate_draft_graph', side_effect=fake_validate), \ + patch.object(FlowVersionDao, 'update_version', side_effect=fake_update_version): + with self.assertRaises(WorkFlowInitError): + await ExternalWorkflowService.update_workflow_node_params( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='node-1', + updates={'temperature': 0.3}, + ) + + self.assertEqual(persist_calls, []) + + async def test_update_workflow_draft_rejects_duplicate_name(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value, description='', guide_word='') + version = SimpleNamespace(id=11, data=make_graph(), is_current=1) + update_flow = AsyncMock() + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + def fake_assert_workflow_name_available(login_user, name, exclude_flow_id=None): + raise WorkflowNameExistsError() + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version), \ + patch.object(ExternalWorkflowService, '_assert_workflow_name_available', + side_effect=fake_assert_workflow_name_available), \ + patch('bisheng.api.services.external_workflow.FlowDao.aupdate_flow', update_flow): + with self.assertRaises(WorkflowNameExistsError): + await ExternalWorkflowService.update_workflow_draft( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + graph_data=make_graph(), + name='duplicate', + ) + + update_flow.assert_not_called() + + async def test_publish_workflow_restores_draft_marker_on_failure(self): + draft_graph = ExternalWorkflowService._mark_graph_as_draft(make_graph()) + flow = SimpleNamespace(id='flow-1') + version = SimpleNamespace(id=11, data=deepcopy(draft_graph), is_current=1) + updated_payloads = [] + + async def fake_validate_workflow(login_user, flow_id, version_id): + return flow, version + + async def fake_update_flow_status(login_user, flow_id, version_id, status): + raise RuntimeError('publish failed') + + def fake_update_version(version_info): + updated_payloads.append(deepcopy(version_info.data)) + return version_info + + with patch.object(ExternalWorkflowService, 'validate_workflow', side_effect=fake_validate_workflow), \ + patch.object(FlowVersionDao, 'update_version', side_effect=fake_update_version), \ + patch.object(WorkFlowService, 'update_flow_status', side_effect=fake_update_flow_status): + with self.assertRaises(RuntimeError): + await ExternalWorkflowService.publish_workflow( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + version_id=11, + ) + + self.assertEqual(len(updated_payloads), 2) + self.assertNotIn('_external_workflow_meta', updated_payloads[0]) + self.assertTrue(updated_payloads[1]['_external_workflow_meta']['draft']) + + async def test_add_workflow_node_revalidates_before_persist(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + version = SimpleNamespace(id=11, data=make_graph(), is_current=1) + persisted = [] + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + def fake_validate(login_user, graph_data, flow_name, flow_id=None): + return None + + def fake_update_version(version_info): + persisted.append(deepcopy(version_info.data)) + return version_info + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version), \ + patch.object(ExternalWorkflowService, '_validate_draft_graph', side_effect=fake_validate), \ + patch.object(FlowVersionDao, 'update_version', side_effect=fake_update_version): + _, updated_version, node_id = await ExternalWorkflowService.add_workflow_node( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_type='code', + name='Code Node', + position_x=120, + position_y=260, + ) + + self.assertEqual(updated_version.id, 11) + self.assertTrue(node_id.startswith('code_')) + self.assertEqual(len(persisted), 1) + self.assertEqual(len(persisted[0]['nodes']), 2) + self.assertEqual(persisted[0]['nodes'][1]['id'], node_id) + self.assertEqual(persisted[0]['nodes'][1]['type'], 'flowNode') + self.assertEqual(persisted[0]['nodes'][1]['position'], {'x': 120, 'y': 260}) + self.assertEqual(persisted[0]['nodes'][1]['data']['type'], 'code') + self.assertEqual(persisted[0]['nodes'][1]['data']['name'], 'Code Node') + self.assertTrue(persisted[0]['_external_workflow_meta']['draft']) + + async def test_remove_workflow_node_cascades_related_edges(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + graph = make_graph() + graph['nodes'].append({ + 'id': 'node-2', + 'data': { + 'id': 'node-2', + 'type': 'output', + 'name': 'Output Node', + 'group_params': [], + }, + }) + graph['edges'].append({ + 'id': 'edge-1', + 'source': 'node-1', + 'sourceHandle': 'output', + 'target': 'node-2', + 'targetHandle': 'input', + }) + version = SimpleNamespace(id=11, data=graph, is_current=1) + persisted = [] + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + def fake_validate(login_user, graph_data, flow_name, flow_id=None): + return None + + def fake_update_version(version_info): + persisted.append(deepcopy(version_info.data)) + return version_info + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version), \ + patch.object(ExternalWorkflowService, '_validate_draft_graph', side_effect=fake_validate), \ + patch.object(FlowVersionDao, 'update_version', side_effect=fake_update_version): + await ExternalWorkflowService.remove_workflow_node( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='node-2', + ) + + self.assertEqual(len(persisted), 1) + self.assertEqual(len(persisted[0]['nodes']), 1) + self.assertEqual(persisted[0]['nodes'][0]['id'], 'node-1') + self.assertEqual(persisted[0]['edges'], []) + + async def test_remove_workflow_node_rejects_connected_node_when_cascade_disabled(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + graph = make_graph() + graph['nodes'].append({ + 'id': 'node-2', + 'data': { + 'id': 'node-2', + 'type': 'output', + 'name': 'Output Node', + 'group_params': [], + }, + }) + graph['edges'].append({ + 'id': 'edge-1', + 'source': 'node-1', + 'sourceHandle': 'output', + 'target': 'node-2', + 'targetHandle': 'input', + }) + version = SimpleNamespace(id=11, data=graph, is_current=1) + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version): + with self.assertRaises(WorkFlowInitError): + await ExternalWorkflowService.remove_workflow_node( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='node-2', + cascade=False, + ) + + async def test_connect_and_disconnect_workflow_nodes_persist_edge_updates(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + graph = make_graph() + graph['nodes'].append({ + 'id': 'node-2', + 'data': { + 'id': 'node-2', + 'type': 'output', + 'name': 'Output Node', + 'group_params': [], + }, + }) + version = SimpleNamespace(id=11, data=graph, is_current=1) + persisted = [] + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + def fake_validate(login_user, graph_data, flow_name, flow_id=None): + return None + + def fake_update_version(version_info): + persisted.append(deepcopy(version_info.data)) + version.data = deepcopy(version_info.data) + return version_info + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version), \ + patch.object(ExternalWorkflowService, '_validate_draft_graph', side_effect=fake_validate), \ + patch.object(FlowVersionDao, 'update_version', side_effect=fake_update_version): + _, _, edge_id = await ExternalWorkflowService.connect_workflow_nodes( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + source_node_id='node-1', + target_node_id='node-2', + source_handle='output', + target_handle='input', + ) + _, _, removed_edge_id = await ExternalWorkflowService.disconnect_workflow_edge( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + edge_id=edge_id, + ) + + self.assertEqual(edge_id, removed_edge_id) + self.assertEqual(len(persisted), 2) + self.assertEqual(len(persisted[0]['edges']), 1) + self.assertEqual(persisted[0]['edges'][0]['id'], edge_id) + self.assertEqual(persisted[0]['edges'][0]['sourceType'], 'llm') + self.assertEqual(persisted[0]['edges'][0]['targetType'], 'output') + self.assertEqual(persisted[1]['edges'], []) + + async def test_connect_workflow_nodes_rejects_duplicate_edge(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + graph = make_graph() + graph['nodes'].append({ + 'id': 'node-2', + 'data': { + 'id': 'node-2', + 'type': 'output', + 'name': 'Output Node', + 'group_params': [], + }, + }) + graph['edges'].append({ + 'id': 'edge-1', + 'source': 'node-1', + 'sourceHandle': 'output', + 'target': 'node-2', + 'targetHandle': 'input', + }) + version = SimpleNamespace(id=11, data=graph, is_current=1) + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version): + with self.assertRaises(WorkFlowInitError): + await ExternalWorkflowService.connect_workflow_nodes( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + source_node_id='node-1', + target_node_id='node-2', + source_handle='output', + target_handle='input', + ) + + async def test_disconnect_workflow_edge_rejects_ambiguous_selector(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + graph = make_graph() + graph['nodes'].append({ + 'id': 'node-2', + 'data': { + 'id': 'node-2', + 'type': 'output', + 'name': 'Output Node', + 'group_params': [], + }, + }) + graph['nodes'].append({ + 'id': 'node-3', + 'data': { + 'id': 'node-3', + 'type': 'output', + 'name': 'Output Node 2', + 'group_params': [], + }, + }) + graph['edges'].extend([{ + 'id': 'edge-1', + 'source': 'node-1', + 'sourceHandle': 'output', + 'target': 'node-2', + 'targetHandle': 'input', + }, { + 'id': 'edge-2', + 'source': 'node-1', + 'sourceHandle': 'output', + 'target': 'node-3', + 'targetHandle': 'input', + }]) + version = SimpleNamespace(id=11, data=graph, is_current=1) + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version): + with self.assertRaises(WorkFlowInitError): + await ExternalWorkflowService.disconnect_workflow_edge( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + source_node_id='node-1', + source_handle='output', + ) + + async def test_get_condition_node_config_returns_cases_and_routes(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + version = SimpleNamespace(id=11, data=make_condition_graph(), is_current=1) + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version): + result = await ExternalWorkflowService.get_condition_node_config( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='condition-1', + ) + + self.assertEqual(result['node_id'], 'condition-1') + self.assertEqual(result['condition_cases'][0]['id'], 'case_a') + self.assertIn('case_a', result['route_handles']) + self.assertIn('right_handle', result['route_handles']) + self.assertEqual(result['outgoing_edges']['case_a'][0]['edge_id'], 'edge-1') + + async def test_get_condition_node_config_rejects_non_condition_node(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + version = SimpleNamespace(id=11, data=make_graph(), is_current=1) + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version): + with self.assertRaises(WorkFlowInitError): + await ExternalWorkflowService.get_condition_node_config( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='node-1', + ) + + async def test_update_condition_node_persists_structured_cases(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + version = SimpleNamespace(id=11, data=make_condition_graph(), is_current=1) + persisted = [] + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + def fake_validate(login_user, graph_data, flow_name, flow_id=None): + return None + + def fake_update_version(version_info): + persisted.append(deepcopy(version_info.data)) + return version_info + + new_cases = [{ + 'id': 'case_b', + 'operator': 'or', + 'conditions': [{ + 'id': 'rule_2', + 'left_var': 'intent', + 'comparison_operation': 'equals', + 'right_value_type': 'const', + 'right_value': 'vip', + 'variable_key_value': {}, + }], + 'variable_key_value': {}, + }] + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version), \ + patch.object(ExternalWorkflowService, '_validate_draft_graph', side_effect=fake_validate), \ + patch.object(FlowVersionDao, 'update_version', side_effect=fake_update_version): + _, updated_version = await ExternalWorkflowService.update_condition_node( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='condition-1', + condition_cases=new_cases, + ) + + self.assertEqual(updated_version.id, 11) + self.assertEqual(persisted[0]['nodes'][0]['data']['group_params'][0]['params'][0]['value'][0]['id'], 'case_b') + self.assertTrue(persisted[0]['_external_workflow_meta']['draft']) + + async def test_update_condition_node_rejects_case_id_without_matching_edge(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + version = SimpleNamespace(id=11, data=make_condition_graph(), is_current=1) + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version): + with self.assertRaises(WorkFlowInitError): + await ExternalWorkflowService.update_condition_node( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='condition-1', + condition_cases=[{ + 'id': 'case_renamed', + 'operator': 'and', + 'conditions': [{ + 'id': 'rule_1', + 'left_var': 'score', + 'comparison_operation': 'greater_than', + 'right_value_type': 'const', + 'right_value': '80', + 'variable_key_value': {}, + }], + 'variable_key_value': {}, + }], + ) + + async def test_update_condition_node_rejects_duplicate_case_ids(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + version = SimpleNamespace(id=11, data=make_condition_graph(), is_current=1) + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version): + with self.assertRaises(WorkFlowInitError): + await ExternalWorkflowService.update_condition_node( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='condition-1', + condition_cases=[{ + 'id': 'case_a', + 'operator': 'and', + 'conditions': [], + 'variable_key_value': {}, + }, { + 'id': 'case_a', + 'operator': 'or', + 'conditions': [], + 'variable_key_value': {}, + }], + ) + + def test_validate_condition_node_routes_rejects_missing_fallback_edge(self): + graph = make_condition_graph() + graph['edges'] = [edge for edge in graph['edges'] if edge['sourceHandle'] != 'right_handle'] + + with self.assertRaises(WorkFlowInitError): + ExternalWorkflowService._validate_special_node_routes(graph) + + async def test_large_graph_editing_scenario_supports_sequential_mutations(self): + flow = SimpleNamespace(id='flow-1', name='complex-flow', status=FlowStatus.OFFLINE.value) + version = SimpleNamespace( + id=21, + data=ExternalWorkflowService._mark_graph_as_draft(make_large_graph()), + is_current=1, + ) + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return flow, version + + def fake_update_version(version_info): + version.data = deepcopy(version_info.data) + return version_info + + with patch.object(ExternalWorkflowService, '_get_editable_version', side_effect=fake_get_editable_version), \ + patch.object(ExternalWorkflowService, '_validate_draft_graph'), \ + patch.object(FlowVersionDao, 'update_version', side_effect=fake_update_version): + _, updated_version, added_node_id = await ExternalWorkflowService.add_workflow_node( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_type='tool', + name='Late Tool', + position_x=640, + position_y=240, + expected_revision=1, + ) + self.assertEqual(updated_version.id, 21) + self.assertEqual(ExternalWorkflowService.get_graph_revision(version.data), 2) + self.assertTrue(any(node['id'] == added_node_id for node in version.data['nodes'])) + + _, _, added_edge_id = await ExternalWorkflowService.connect_workflow_nodes( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + source_node_id='node-8', + target_node_id=added_node_id, + source_handle='output', + target_handle='input', + expected_revision=2, + ) + self.assertEqual(ExternalWorkflowService.get_graph_revision(version.data), 3) + self.assertTrue(any(edge['id'] == added_edge_id for edge in version.data['edges'])) + + await ExternalWorkflowService.update_workflow_node_params( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='llm-1', + updates={'temperature': 0.2}, + expected_revision=3, + ) + llm_node = next(node for node in version.data['nodes'] if node['id'] == 'llm-1') + self.assertEqual(llm_node['data']['group_params'][0]['params'][0]['value'], 0.2) + self.assertEqual(ExternalWorkflowService.get_graph_revision(version.data), 4) + + await ExternalWorkflowService.update_condition_node( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id='condition-1', + condition_cases=[{ + 'id': 'case_a', + 'operator': 'and', + 'conditions': [{ + 'id': 'rule_1', + 'left_var': 'score', + 'comparison_operation': 'greater_than_or_equal', + 'right_value_type': 'const', + 'right_value': '90', + 'variable_key_value': {}, + }], + 'variable_key_value': {}, + }], + expected_revision=4, + ) + condition_payload = next( + param for param in next(node for node in version.data['nodes'] if node['id'] == 'condition-1')['data'][ + 'group_params'][0]['params'] + if param['key'] == 'condition' + ) + self.assertEqual(condition_payload['value'][0]['conditions'][0]['right_value'], '90') + self.assertEqual(ExternalWorkflowService.get_graph_revision(version.data), 5) + + await ExternalWorkflowService.disconnect_workflow_edge( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + edge_id=added_edge_id, + expected_revision=5, + ) + self.assertEqual(ExternalWorkflowService.get_graph_revision(version.data), 6) + + await ExternalWorkflowService.remove_workflow_node( + login_user=SimpleNamespace(user_id=1), + flow_id='flow-1', + node_id=added_node_id, + expected_revision=6, + ) + self.assertFalse(any(node['id'] == added_node_id for node in version.data['nodes'])) + self.assertEqual(ExternalWorkflowService.get_graph_revision(version.data), 7) diff --git a/src/backend/test/test_mcp_auth.py b/src/backend/test/test_mcp_auth.py new file mode 100644 index 0000000000..9d2e658ea3 --- /dev/null +++ b/src/backend/test/test_mcp_auth.py @@ -0,0 +1,154 @@ +from types import SimpleNamespace +from unittest import IsolatedAsyncioTestCase, TestCase +from unittest.mock import AsyncMock, patch + +from fastapi import FastAPI, WebSocket +from fastapi.testclient import TestClient +import jwt + +from bisheng.common.exceptions.auth import JWTDecodeError +from bisheng.mcp_server import auth as mcp_auth +from bisheng.mcp_server.auth import McpAuthorizationMiddleware, create_mcp_access_token, get_login_user_from_mcp_token +from bisheng.user.domain.services.auth import AuthJwt + + +def create_test_client(): + app = FastAPI() + + @app.get('/secure') + async def secure(): + login_user = await get_login_user_from_mcp_token() + return { + 'user_id': login_user.user_id, + 'user_name': login_user.user_name, + } + + @app.websocket('/ws') + async def secure_websocket(websocket: WebSocket): + login_user = await get_login_user_from_mcp_token() + await websocket.accept() + await websocket.send_json({ + 'user_id': login_user.user_id, + 'user_name': login_user.user_name, + }) + await websocket.close() + + return TestClient(McpAuthorizationMiddleware(app)) + + +class TestMcpAuthorizationMiddleware(TestCase): + def test_missing_bearer_token_returns_401(self): + client = create_test_client() + + response = client.get('/secure') + + self.assertEqual(response.status_code, 401) + self.assertEqual(response.json()['error'], 'invalid_request') + self.assertIn('WWW-Authenticate', response.headers) + self.assertIn('Bearer realm="bisheng-mcp"', response.headers['WWW-Authenticate']) + + def test_invalid_bearer_token_returns_401(self): + client = create_test_client() + + with patch('bisheng.mcp_server.auth._validate_mcp_access_token', + AsyncMock(side_effect=JWTDecodeError(status_code=422, message='bad token'))): + response = client.get('/secure', headers={'Authorization': 'Bearer invalid-token'}) + + self.assertEqual(response.status_code, 401) + self.assertEqual(response.json()['error'], 'invalid_token') + + def test_valid_bearer_token_populates_login_user(self): + client = create_test_client() + + with patch( + 'bisheng.mcp_server.auth._validate_mcp_access_token', + AsyncMock(return_value=(SimpleNamespace(user_id=7, user_name='admin'), ('workflow.read',))), + ): + response = client.get('/secure', headers={'Authorization': 'Bearer valid-token'}) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), {'user_id': 7, 'user_name': 'admin'}) + + def test_cross_origin_request_is_rejected(self): + client = create_test_client() + + with patch( + 'bisheng.mcp_server.auth._validate_mcp_access_token', + AsyncMock(return_value=(SimpleNamespace(user_id=7, user_name='admin'), ('workflow.read',))), + ): + response = client.get( + '/secure', + headers={ + 'Authorization': 'Bearer valid-token', + 'Origin': 'https://evil.example.com', + 'Host': '127.0.0.1:7860', + }, + ) + + self.assertEqual(response.status_code, 403) + self.assertEqual(response.json()['error'], 'forbidden_origin') + + def test_valid_websocket_bearer_token_populates_login_user(self): + client = create_test_client() + + with patch( + 'bisheng.mcp_server.auth._validate_mcp_access_token', + AsyncMock(return_value=(SimpleNamespace(user_id=7, user_name='admin'), ('workflow.read',))), + ): + with client.websocket_connect('/ws', headers={'Authorization': 'Bearer valid-token'}) as websocket: + self.assertEqual(websocket.receive_json(), {'user_id': 7, 'user_name': 'admin'}) + + +class TestMcpAccessToken(IsolatedAsyncioTestCase): + async def test_create_and_validate_mcp_token(self): + login_user = SimpleNamespace(user_id=7, user_name='admin') + + token, payload = create_mcp_access_token( + login_user, + 'parent-access-token', + scopes=['workflow.read', 'workflow.write'], + expires_in=600, + ) + + fake_redis = SimpleNamespace(aget=AsyncMock(return_value='parent-access-token')) + resolved_user = SimpleNamespace(user_id=7, user_name='admin') + with patch('bisheng.mcp_server.auth.get_redis_client', AsyncMock(return_value=fake_redis)), \ + patch('bisheng.mcp_server.auth.UserPayload.init_login_user', + AsyncMock(return_value=resolved_user)): + user, scopes = await mcp_auth._validate_mcp_access_token(token) + + self.assertEqual(user.user_id, 7) + self.assertEqual(user.user_name, 'admin') + self.assertEqual(scopes, ('workflow.read', 'workflow.write')) + self.assertEqual(payload['scopes'], ['workflow.read', 'workflow.write']) + + async def test_validate_mcp_token_uses_top_level_user_claims(self): + login_user = SimpleNamespace(user_id=7, user_name='admin') + token, _ = create_mcp_access_token(login_user, 'parent-access-token') + fake_redis = SimpleNamespace(aget=AsyncMock(return_value='parent-access-token')) + resolved_user = SimpleNamespace(user_id=7, user_name='admin') + + with patch('bisheng.mcp_server.auth.get_redis_client', AsyncMock(return_value=fake_redis)), \ + patch('bisheng.mcp_server.auth.UserPayload.init_login_user', + AsyncMock(return_value=resolved_user)): + user, scopes = await mcp_auth._validate_mcp_access_token(token) + + self.assertEqual(user.user_id, 7) + self.assertEqual(user.user_name, 'admin') + self.assertEqual(scopes, ('workflow.read', 'workflow.write', 'workflow.publish')) + + async def test_validate_mcp_token_rejects_non_json_legacy_subject(self): + token = jwt.encode({ + 'sub': '7', + 'iss': 'bisheng-mcp', + 'aud': 'bisheng-workflow-mcp', + 'iat': 1, + 'exp': 9999999999, + 'jti': 'bad-token', + 'token_type': 'mcp_access_token', + 'scope': ['workflow.read'], + 'parent_session_hash': 'hash', + }, AuthJwt().jwt_secret, algorithm='HS256') + + with self.assertRaises(JWTDecodeError): + await mcp_auth._validate_mcp_access_token(token) diff --git a/src/backend/test/test_mcp_device_flow.py b/src/backend/test/test_mcp_device_flow.py new file mode 100644 index 0000000000..aaf960553b --- /dev/null +++ b/src/backend/test/test_mcp_device_flow.py @@ -0,0 +1,137 @@ +from types import SimpleNamespace +from unittest import IsolatedAsyncioTestCase, TestCase +from unittest.mock import AsyncMock, patch + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from bisheng.mcp_server.device_flow import McpDeviceSession, load_device_session_by_device_code +from bisheng.user.api.user import router + + +class FakeRedis: + def __init__(self): + self.store = {} + + async def aset(self, key, value, expiration=3600): + self.store[key] = value + return True + + async def aget(self, key): + return self.store.get(key) + + async def adelete(self, key): + self.store.pop(key, None) + return 1 + + +def create_test_client(): + app = FastAPI() + app.include_router(router, prefix='/api/v1') + return TestClient(app) + + +class TestMcpDeviceFlow(TestCase): + def test_device_authorization_creates_codes(self): + fake_redis = FakeRedis() + client = create_test_client() + + with patch('bisheng.user.api.user.get_redis_client', AsyncMock(return_value=fake_redis)): + response = client.post('/api/v1/user/mcp/device/authorize', json={ + 'client_id': 'codex-cli', + 'client_name': 'Codex CLI', + 'scope': 'workflow.read workflow.write', + 'expires_in': 600, + 'interval': 5, + }) + + self.assertEqual(response.status_code, 200) + payload = response.json()['data'] + self.assertEqual(payload['scope'], 'workflow.read workflow.write') + self.assertIn('verification_uri_complete', payload) + self.assertIn('device_code', payload) + self.assertIn('user_code', payload) + + def test_device_token_returns_pending_before_approval(self): + fake_redis = FakeRedis() + client = create_test_client() + + with patch('bisheng.user.api.user.get_redis_client', AsyncMock(return_value=fake_redis)): + created = client.post('/api/v1/user/mcp/device/authorize', json={ + 'client_id': 'codex-cli', + 'scope': 'workflow.read', + }).json()['data'] + response = client.post('/api/v1/user/mcp/device/token', json={ + 'device_code': created['device_code'], + 'client_id': 'codex-cli', + }) + + self.assertEqual(response.status_code, 400) + self.assertEqual(response.json()['error'], 'authorization_pending') + + def test_verify_page_requires_login_before_approval(self): + fake_redis = FakeRedis() + client = create_test_client() + + with patch('bisheng.user.api.user.get_redis_client', AsyncMock(return_value=fake_redis)): + created = client.post('/api/v1/user/mcp/device/authorize', json={ + 'client_id': 'codex-cli', + 'client_name': 'Codex CLI', + }).json()['data'] + response = client.get(f"/api/v1/user/mcp/device/verify?user_code={created['user_code']}") + + self.assertEqual(response.status_code, 200) + self.assertIn('Login required', response.text) + + def test_verify_and_exchange_device_token(self): + fake_redis = FakeRedis() + client = create_test_client() + + with patch('bisheng.user.api.user.get_redis_client', AsyncMock(return_value=fake_redis)): + created = client.post('/api/v1/user/mcp/device/authorize', json={ + 'client_id': 'codex-cli', + 'client_name': 'Codex CLI', + 'scope': 'workflow.read workflow.write', + }).json()['data'] + + with patch('bisheng.user.api.user.get_request_bisheng_access_token', return_value='parent-access-token'), \ + patch( + 'bisheng.user.api.user.resolve_login_user_from_bisheng_access_token', + AsyncMock(return_value=SimpleNamespace(user_id=7, user_name='demo')), + ): + approval = client.post('/api/v1/user/mcp/device/verify', data={ + 'user_code': created['user_code'], + 'action': 'approve', + }) + + self.assertEqual(approval.status_code, 200) + self.assertIn('Request approved', approval.text) + + token_response = client.post('/api/v1/user/mcp/device/token', json={ + 'device_code': created['device_code'], + 'client_id': 'codex-cli', + 'grant_type': 'urn:ietf:params:oauth:grant-type:device_code', + }) + + self.assertEqual(token_response.status_code, 200) + token_payload = token_response.json() + self.assertEqual(token_payload['token_type'], 'Bearer') + self.assertEqual(token_payload['scope'], 'workflow.read workflow.write') + self.assertEqual(token_payload['mcp_url'], '/mcp') + self.assertTrue(token_payload['access_token']) + + +class TestMcpDeviceFlowStorage(IsolatedAsyncioTestCase): + async def test_approved_session_is_persisted(self): + fake_redis = FakeRedis() + session = McpDeviceSession( + device_code='device-code', + user_code='ABCD-EFGH', + client_id='codex-cli', + scopes=['workflow.read'], + expires_at=4102444800, + ) + fake_redis.store['mcp:device:code:device-code'] = session.model_dump() + stored = await load_device_session_by_device_code(fake_redis, 'device-code') + self.assertIsNotNone(stored) + self.assertEqual(stored.user_code, 'ABCD-EFGH') diff --git a/src/backend/test/test_workflow_authoring_service.py b/src/backend/test/test_workflow_authoring_service.py new file mode 100644 index 0000000000..795c953d54 --- /dev/null +++ b/src/backend/test/test_workflow_authoring_service.py @@ -0,0 +1,155 @@ +from types import SimpleNamespace +from unittest import IsolatedAsyncioTestCase, TestCase +from unittest.mock import AsyncMock, patch + +from bisheng.api.services.workflow_authoring import WorkflowAuthoringService +from bisheng.common.errcode.flow import WorkFlowInitError +from bisheng.common.errcode.http_error import NotFoundError +from bisheng.database.models.flow import FlowStatus +from bisheng.workflow.authoring.registry import get_node_template_descriptor, list_node_type_descriptors + + +def make_graph(): + return { + 'nodes': [{ + 'id': 'node-1', + 'data': { + 'id': 'node-1', + 'type': 'llm', + 'name': 'LLM Node', + 'tab': { + 'value': 'single', + 'options': [{'label': 'Single', 'key': 'single'}, {'label': 'Batch', 'key': 'batch'}], + }, + 'group_params': [{ + 'name': 'model', + 'params': [{ + 'key': 'temperature', + 'type': 'slide', + 'required': True, + 'scope': [0, 1], + 'placeholder': '0.0 ~ 1.0', + 'refresh': True, + 'options': [{'key': 0.3, 'value': 0.3}, {'key': 0.7, 'value': 0.7}], + 'value': 0.7, + }, { + 'key': 'system_prompt', + 'type': 'var_textarea', + 'value': 'hello', + }], + }], + }, + }], + 'edges': [{ + 'id': 'edge-1', + 'source': 'node-1', + 'sourceHandle': 'output', + 'target': 'node-2', + 'targetHandle': 'input', + }], + } + + +class TestWorkflowAuthoringRegistry(TestCase): + def test_list_node_types_includes_dynamic_tool(self): + node_types = {item.type: item for item in list_node_type_descriptors()} + + self.assertIn('llm', node_types) + self.assertIn('condition', node_types) + self.assertIn('tool', node_types) + self.assertTrue(node_types['tool'].dynamic_template) + + def test_get_node_template_returns_normalized_llm_template(self): + template = get_node_template_descriptor('llm') + + self.assertIsNotNone(template) + self.assertEqual(template.node_type, 'llm') + self.assertEqual(template.display_name, 'LLM') + self.assertEqual(template.tab.value, 'single') + self.assertIn('model_id', template.params) + self.assertIn('temperature', template.params) + self.assertIn('system_prompt', template.params) + + +class TestWorkflowAuthoringService(IsolatedAsyncioTestCase): + async def test_list_workflows_filters_by_write_access(self): + flows = [ + SimpleNamespace(id='flow-1', user_id=10, name='one', description='', status=FlowStatus.OFFLINE.value), + SimpleNamespace(id='flow-2', user_id=11, name='two', description='', status=FlowStatus.ONLINE.value), + ] + login_user = SimpleNamespace( + user_id=1, + async_access_check=AsyncMock(side_effect=[True, False]), + ) + + with patch.object(WorkflowAuthoringService, '_list_candidate_workflows', AsyncMock(return_value=flows)), \ + patch.object(WorkflowAuthoringService, '_build_manifest', side_effect=lambda flow: flow.id): + result = await WorkflowAuthoringService.list_workflows(login_user) + + self.assertEqual(result, ['flow-1']) + + async def test_get_workflow_graph_returns_normalized_graph(self): + version = SimpleNamespace(id=11, data=make_graph()) + login_user = SimpleNamespace(user_id=1) + + async def fake_get_editable_version(login_user, flow_id, version_id=None): + return SimpleNamespace(id=flow_id, name='demo', status=FlowStatus.OFFLINE.value), version + + with patch('bisheng.api.services.workflow_authoring.ExternalWorkflowService._get_editable_version', + side_effect=fake_get_editable_version): + graph = await WorkflowAuthoringService.get_workflow_graph(login_user, 'flow-1') + + self.assertEqual(graph.flow_id, 'flow-1') + self.assertEqual(graph.version_id, 11) + self.assertEqual(len(graph.nodes), 1) + self.assertEqual(graph.nodes[0].id, 'node-1') + self.assertEqual(graph.nodes[0].type, 'llm') + self.assertEqual(graph.nodes[0].tab.value, 'single') + self.assertIn('temperature', graph.nodes[0].params) + self.assertEqual(graph.edges[0]['source'], 'node-1') + + async def test_get_workflow_versions_marks_editable_draft(self): + flow = SimpleNamespace(id='flow-1', name='demo', status=FlowStatus.OFFLINE.value) + versions = [ + SimpleNamespace(id=11, name='v1', description='base', is_current=1, create_time=None, update_time=None), + SimpleNamespace(id=12, name='draft', description='draft', is_current=0, create_time=None, update_time=None), + ] + detailed_versions = { + 11: SimpleNamespace(id=11, data=make_graph(), original_version_id=None), + 12: SimpleNamespace( + id=12, + data={'nodes': make_graph()['nodes'], 'edges': make_graph()['edges'], + '_external_workflow_meta': {'draft': True, 'revision': 3}}, + original_version_id=11, + ), + } + + with patch('bisheng.api.services.workflow_authoring.ExternalWorkflowService._get_workflow_with_write_access', + AsyncMock(return_value=flow)), \ + patch.object(WorkflowAuthoringService, '_editable_version_without_side_effect', + return_value=(detailed_versions[11], detailed_versions[12])), \ + patch('bisheng.api.services.workflow_authoring.FlowVersionDao.get_list_by_flow', + return_value=versions), \ + patch('bisheng.api.services.workflow_authoring.FlowVersionDao.get_version_by_id', + side_effect=lambda version_id: detailed_versions[version_id]): + result = await WorkflowAuthoringService.get_workflow_versions(SimpleNamespace(user_id=1), 'flow-1') + + self.assertEqual(len(result), 2) + self.assertTrue(result[1].is_editable) + self.assertTrue(result[1].is_external_draft) + self.assertEqual(result[1].draft_revision, 3) + self.assertEqual(result[1].original_version_id, 11) + + def test_get_node_template_raises_for_unknown_type(self): + with self.assertRaises(NotFoundError): + WorkflowAuthoringService.get_node_template('missing') + + def test_diagnostics_from_exception_extracts_field_path(self): + diagnostics = WorkflowAuthoringService.diagnostics_from_exception( + WorkFlowInitError(msg='Param temperature must be within scope [0, 1]') + ) + + self.assertEqual(len(diagnostics), 1) + self.assertEqual(diagnostics[0].severity.value, 'error') + self.assertEqual(diagnostics[0].field_path, 'temperature') + self.assertEqual(diagnostics[0].suggested_fix, 'Review the parameter value and node template requirements.') diff --git a/src/backend/test/test_workflow_mcp_e2e.py b/src/backend/test/test_workflow_mcp_e2e.py new file mode 100644 index 0000000000..014ff1a4ca --- /dev/null +++ b/src/backend/test/test_workflow_mcp_e2e.py @@ -0,0 +1,274 @@ +import json +from contextlib import AsyncExitStack, asynccontextmanager +from types import SimpleNamespace +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, patch + +import httpx +from mcp import ClientSession +from mcp.client.streamable_http import streamablehttp_client + +from bisheng.api.services.external_workflow import ExternalWorkflowService +from bisheng.common.errcode.flow import WorkFlowInitError +from bisheng.api.services.flow import FlowService +from bisheng.database.models.flow import FlowDao +from bisheng.database.models.flow_version import FlowVersionDao +from bisheng.mcp_server.auth import McpAuthorizationMiddleware +from bisheng.mcp_server.workflow import create_workflow_mcp_server +from bisheng.workflow.authoring import WorkflowManifest + + +class TestWorkflowMcpE2E(IsolatedAsyncioTestCase): + def setUp(self): + self.server = create_workflow_mcp_server() + self.app = McpAuthorizationMiddleware(self.server.streamable_http_app()) + self._stack = AsyncExitStack() + + async def asyncTearDown(self): + await self._stack.aclose() + + def _httpx_client_factory(self, headers=None, timeout=None, auth=None): + @asynccontextmanager + async def _factory(): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=self.app), + base_url='http://testserver', + headers=headers, + timeout=timeout, + auth=auth, + follow_redirects=True, + ) as client: + yield client + + return _factory() + + @staticmethod + def _decode_tool_result(result) -> dict: + if not result.content: + raise AssertionError('Tool result did not contain any MCP content payload') + return json.loads(result.content[0].text) + + async def _open_session(self, scopes=('workflow.read', 'workflow.write')): + login_user = SimpleNamespace(user_id=7, user_name='admin') + auth_patch = patch( + 'bisheng.mcp_server.auth._validate_mcp_access_token', + AsyncMock(return_value=(login_user, scopes)), + ) + self.addCleanup(auth_patch.stop) + auth_patch.start() + await self._stack.enter_async_context(self.server.session_manager.run()) + streams = await self._stack.enter_async_context( + streamablehttp_client( + 'http://testserver/', + headers={'Authorization': 'Bearer test-token'}, + httpx_client_factory=self._httpx_client_factory, + ) + ) + read_stream, write_stream, _ = streams + session = await self._stack.enter_async_context(ClientSession(read_stream, write_stream)) + await session.initialize() + return session + + async def test_http_rejects_missing_bearer_token(self): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=self.app), + base_url='http://testserver', + follow_redirects=True, + ) as client: + response = await client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': 1, + 'method': 'initialize', + 'params': { + 'protocolVersion': '2025-06-18', + 'capabilities': {}, + 'clientInfo': {'name': 'pytest', 'version': '1.0.0'}, + }, + }, + ) + + self.assertEqual(response.status_code, 401) + self.assertEqual(response.json()['error'], 'invalid_request') + self.assertIn('Missing Bearer token', response.headers['WWW-Authenticate']) + + async def test_streamable_http_lists_expected_tools(self): + session = await self._open_session() + + tools = await session.list_tools() + tool_names = {tool.name for tool in tools.tools} + + self.assertIn('ping', tool_names) + self.assertIn('list_workflows', tool_names) + self.assertIn('add_node', tool_names) + self.assertIn('update_condition_node', tool_names) + + ping_result = await session.call_tool('ping', {}) + payload = self._decode_tool_result(ping_result) + + self.assertTrue(payload['ok']) + self.assertTrue(payload['authenticated']) + self.assertEqual(payload['user_id'], 7) + self.assertEqual(payload['scopes'], []) + + async def test_streamable_http_list_workflows_returns_json_payload(self): + session = await self._open_session() + + with patch( + 'bisheng.mcp_server.workflow.WorkflowAuthoringService.list_workflows', + AsyncMock(return_value=[WorkflowManifest(flow_id='flow-1', name='demo')]), + ): + result = await session.call_tool('list_workflows', {}) + + payload = self._decode_tool_result(result) + self.assertTrue(payload['ok']) + self.assertEqual(payload['workflows'][0]['flow_id'], 'flow-1') + self.assertEqual(payload['workflows'][0]['name'], 'demo') + + async def test_streamable_http_add_node_returns_mutation_payload(self): + session = await self._open_session(scopes=('workflow.read', 'workflow.write')) + flow = SimpleNamespace(id='flow-1') + version = SimpleNamespace(id=11, data={'nodes': [], 'edges': [], '_external_workflow_meta': {'revision': 3}}) + + with patch( + 'bisheng.mcp_server.workflow.ExternalWorkflowService.add_workflow_node', + AsyncMock(return_value=(flow, version, 'llm_1234')), + ), patch( + 'bisheng.mcp_server.workflow.ExternalWorkflowService.get_graph_revision', + return_value=3, + ): + result = await session.call_tool( + 'add_node', + { + 'flow_id': 'flow-1', + 'node_type': 'llm', + 'name': 'LLM Node', + 'position_x': 120, + 'position_y': 48, + 'initial_params': {'temperature': 0.3}, + }, + ) + + payload = self._decode_tool_result(result) + self.assertTrue(payload['ok']) + self.assertEqual(payload['flow_id'], 'flow-1') + self.assertEqual(payload['version_id'], 11) + self.assertEqual(payload['draft_revision'], 3) + self.assertEqual(payload['node_id'], 'llm_1234') + + async def test_streamable_http_create_workflow_draft_scaffolds_empty_graph(self): + session = await self._open_session(scopes=('workflow.write',)) + captured = {} + + def fake_validate_runtime(login_user, graph_data, flow_name, flow_id=None): + captured['graph_data'] = json.loads(json.dumps(graph_data)) + + def fake_create_flow(flow_info, flow_type): + flow_info.id = 'flow-1' + captured['created_graph'] = json.loads(json.dumps(flow_info.data)) + return flow_info + + def fake_get_current_version(flow_id): + return SimpleNamespace(id=11, data=json.loads(json.dumps(captured['created_graph']))) + + with patch.object(ExternalWorkflowService, '_assert_workflow_name_available'), \ + patch.object(ExternalWorkflowService, '_validate_workflow_runtime', side_effect=fake_validate_runtime), \ + patch.object(FlowDao, 'create_flow', side_effect=fake_create_flow), \ + patch.object(FlowVersionDao, 'get_version_by_flow', side_effect=fake_get_current_version), \ + patch.object(FlowVersionDao, 'update_version', side_effect=lambda version: version), \ + patch.object(FlowService, 'create_flow_hook'): + result = await session.call_tool( + 'create_workflow_draft', + { + 'name': 'demo', + 'graph_data': {'nodes': [], 'edges': []}, + }, + ) + + payload = self._decode_tool_result(result) + self.assertTrue(payload['ok']) + self.assertEqual(payload['flow_id'], 'flow-1') + self.assertEqual(payload['version_id'], 11) + scaffold_types = [node['data']['type'] for node in captured['graph_data']['nodes']] + self.assertEqual(scaffold_types, ['start', 'end']) + self.assertEqual(len(captured['graph_data']['edges']), 1) + + async def test_streamable_http_create_workflow_draft_accepts_normalized_graph_descriptor(self): + session = await self._open_session(scopes=('workflow.write',)) + captured = {} + + def fake_validate_runtime(login_user, graph_data, flow_name, flow_id=None): + captured['graph_data'] = json.loads(json.dumps(graph_data)) + + def fake_create_flow(flow_info, flow_type): + flow_info.id = 'flow-1' + captured['created_graph'] = json.loads(json.dumps(flow_info.data)) + return flow_info + + def fake_get_current_version(flow_id): + return SimpleNamespace(id=11, data=json.loads(json.dumps(captured['created_graph']))) + + with patch.object(ExternalWorkflowService, '_assert_workflow_name_available'), \ + patch.object(ExternalWorkflowService, '_validate_workflow_runtime', side_effect=fake_validate_runtime), \ + patch.object(FlowDao, 'create_flow', side_effect=fake_create_flow), \ + patch.object(FlowVersionDao, 'get_version_by_flow', side_effect=fake_get_current_version), \ + patch.object(FlowVersionDao, 'update_version', side_effect=lambda version: version), \ + patch.object(FlowService, 'create_flow_hook'): + result = await session.call_tool( + 'create_workflow_draft', + { + 'name': 'demo', + 'graph_data': { + 'nodes': [{ + 'id': 'input-1', + 'type': 'input', + 'name': 'Ticket Input', + 'params': {}, + }], + 'edges': [], + }, + }, + ) + + payload = self._decode_tool_result(result) + self.assertTrue(payload['ok']) + input_node = next(node for node in captured['graph_data']['nodes'] if node['id'] == 'input-1') + self.assertEqual(input_node['type'], 'flowNode') + self.assertEqual(input_node['data']['type'], 'input') + self.assertEqual(input_node['data']['name'], 'Ticket Input') + + async def test_streamable_http_returns_structured_validation_error_payload(self): + session = await self._open_session() + + with patch( + 'bisheng.mcp_server.workflow.ExternalWorkflowService.validate_workflow', + AsyncMock(side_effect=WorkFlowInitError(msg='Param temperature must be within scope [0, 1]')), + ): + result = await session.call_tool( + 'validate_workflow', + {'flow_id': 'flow-1', 'version_id': 11}, + ) + + payload = self._decode_tool_result(result) + self.assertFalse(payload['ok']) + self.assertFalse(payload['valid']) + self.assertEqual(payload['error_code'], 10526) + self.assertEqual(payload['errors'], ['Param temperature must be within scope [0, 1]']) + self.assertEqual(payload['diagnostics'][0]['field_path'], 'temperature') + + async def test_streamable_http_scope_failure_returns_wrapped_tool_payload(self): + session = await self._open_session(scopes=('workflow.read',)) + + result = await session.call_tool( + 'add_node', + { + 'flow_id': 'flow-1', + 'node_type': 'llm', + }, + ) + + payload = self._decode_tool_result(result) + self.assertFalse(payload['ok']) + self.assertEqual(payload['error_code'], 403) + self.assertIn('workflow.write', payload['message']) diff --git a/src/backend/test/test_workflow_mcp_tools.py b/src/backend/test/test_workflow_mcp_tools.py new file mode 100644 index 0000000000..fb8dba08fa --- /dev/null +++ b/src/backend/test/test_workflow_mcp_tools.py @@ -0,0 +1,125 @@ +from types import SimpleNamespace +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, patch + +from bisheng.common.errcode.flow import WorkFlowInitError +from bisheng.common.errcode.http_error import UnAuthorizedError +from bisheng.common.errcode.http_error import NotFoundError +from bisheng.mcp_server.workflow import create_workflow_mcp_server +from bisheng.workflow.authoring import WorkflowManifest + + +def get_tool(name: str): + mcp = create_workflow_mcp_server() + return mcp._tool_manager._tools[name].fn + + +class TestWorkflowMcpTools(IsolatedAsyncioTestCase): + async def test_ping_hides_scopes_but_whoami_returns_them(self): + login_user = SimpleNamespace(user_id=7, user_name='admin') + ping = get_tool('ping') + whoami = get_tool('whoami') + + with patch('bisheng.mcp_server.workflow.get_login_user_from_mcp_token', AsyncMock(return_value=login_user)), \ + patch('bisheng.mcp_server.workflow.get_current_token_scopes', + return_value=('workflow.read', 'workflow.write')): + ping_result = await ping() + whoami_result = await whoami() + + self.assertTrue(ping_result.ok) + self.assertEqual(ping_result.scopes, []) + self.assertEqual(whoami_result.scopes, ['workflow.read', 'workflow.write']) + + async def test_list_workflows_enforces_read_scope_and_wraps_result(self): + login_user = SimpleNamespace(user_id=7, user_name='admin') + list_workflows = get_tool('list_workflows') + workflow = WorkflowManifest(flow_id='flow-1', name='demo') + + with patch('bisheng.mcp_server.workflow.get_login_user_from_mcp_token', AsyncMock(return_value=login_user)), \ + patch('bisheng.mcp_server.workflow.require_mcp_scopes') as require_scopes, \ + patch('bisheng.mcp_server.workflow.WorkflowAuthoringService.list_workflows', + AsyncMock(return_value=[workflow])): + result = await list_workflows() + + require_scopes.assert_called_once_with('workflow.read') + self.assertTrue(result.ok) + self.assertEqual(result.workflows, [workflow]) + + async def test_add_node_wraps_service_error(self): + login_user = SimpleNamespace(user_id=7, user_name='admin') + add_node = get_tool('add_node') + + with patch('bisheng.mcp_server.workflow.get_login_user_from_mcp_token', AsyncMock(return_value=login_user)), \ + patch('bisheng.mcp_server.workflow.require_mcp_scopes'), \ + patch('bisheng.mcp_server.workflow.ExternalWorkflowService.add_workflow_node', + AsyncMock(side_effect=NotFoundError(msg='missing node template'))): + result = await add_node(flow_id='flow-1', node_type='missing') + + self.assertFalse(result.ok) + self.assertEqual(result.error_code, 404) + self.assertEqual(result.message, 'missing node template') + + async def test_list_workflows_wraps_scope_error(self): + login_user = SimpleNamespace(user_id=7, user_name='admin') + list_workflows = get_tool('list_workflows') + + with patch('bisheng.mcp_server.workflow.get_login_user_from_mcp_token', AsyncMock(return_value=login_user)), \ + patch('bisheng.mcp_server.workflow.require_mcp_scopes', + side_effect=UnAuthorizedError(msg='forbidden')): + result = await list_workflows() + + self.assertFalse(result.ok) + self.assertEqual(result.error_code, 403) + self.assertEqual(result.message, 'forbidden') + + async def test_get_condition_node_wraps_service_payload(self): + login_user = SimpleNamespace(user_id=7, user_name='admin') + get_condition_node = get_tool('get_condition_node') + payload = { + 'flow_id': 'flow-1', + 'version_id': 11, + 'draft_revision': 2, + 'node_id': 'condition-1', + 'node_name': 'Condition Node', + 'condition_cases': [{'id': 'case_a', 'operator': 'and', 'conditions': [], 'variable_key_value': {}}], + 'route_handles': ['case_a', 'right_handle'], + 'outgoing_edges': {'case_a': [{'edge_id': 'edge-1', 'target_node_id': 'node-2', 'target_handle': 'input'}]}, + } + + with patch('bisheng.mcp_server.workflow.get_login_user_from_mcp_token', AsyncMock(return_value=login_user)), \ + patch('bisheng.mcp_server.workflow.require_mcp_scopes'), \ + patch('bisheng.mcp_server.workflow.ExternalWorkflowService.get_condition_node_config', + AsyncMock(return_value=payload)): + result = await get_condition_node(flow_id='flow-1', node_id='condition-1') + + self.assertTrue(result.ok) + self.assertEqual(result.node_id, 'condition-1') + self.assertEqual(result.route_handles, ['case_a', 'right_handle']) + + async def test_validate_workflow_returns_structured_diagnostics_on_error(self): + login_user = SimpleNamespace(user_id=7, user_name='admin') + validate_workflow = get_tool('validate_workflow') + + with patch('bisheng.mcp_server.workflow.get_login_user_from_mcp_token', AsyncMock(return_value=login_user)), \ + patch('bisheng.mcp_server.workflow.require_mcp_scopes'), \ + patch('bisheng.mcp_server.workflow.ExternalWorkflowService.validate_workflow', + AsyncMock(side_effect=WorkFlowInitError(msg='Param temperature must be within scope [0, 1]'))): + result = await validate_workflow(flow_id='flow-1', version_id=11) + + self.assertFalse(result.ok) + self.assertFalse(result.valid) + self.assertEqual(result.error_code, 10526) + self.assertEqual(result.errors, ['Param temperature must be within scope [0, 1]']) + self.assertEqual(result.diagnostics[0].field_path, 'temperature') + + async def test_ping_returns_connection_error_when_authentication_fails(self): + ping = get_tool('ping') + + with patch('bisheng.mcp_server.workflow.get_login_user_from_mcp_token', + AsyncMock(side_effect=UnAuthorizedError(msg='missing token'))): + result = await ping() + + self.assertFalse(result.ok) + self.assertFalse(result.authenticated) + self.assertEqual(result.error_code, 403) + self.assertEqual(result.message, 'missing token') diff --git a/src/backend/uv.lock b/src/backend/uv.lock index b0856f3991..7b65443c61 100644 --- a/src/backend/uv.lock +++ b/src/backend/uv.lock @@ -421,7 +421,7 @@ wheels = [ [[package]] name = "backend" -version = "2.3.0" +version = "2.4.0b1" source = { virtual = "." } dependencies = [ { name = "aiofiles" }, @@ -504,6 +504,11 @@ dependencies = [ { name = "zhipuai" }, ] +[package.dev-dependencies] +dev = [ + { name = "pytest" }, +] + [package.metadata] requires-dist = [ { name = "aiofiles", specifier = ">=25.1.0" }, @@ -585,6 +590,9 @@ requires-dist = [ { name = "zhipuai", specifier = ">=1.0.7" }, ] +[package.metadata.requires-dev] +dev = [{ name = "pytest", specifier = ">=9.0.3" }] + [[package]] name = "backoff" version = "2.2.1" @@ -2316,6 +2324,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/ed/1f1afb2e9e7f38a545d628f864d562a5ae64fe6f7a10e28ffb9b185b4e89/importlib_resources-6.5.2-py3-none-any.whl", hash = "sha256:789cfdc3ed28c78b67a06acb8126751ced69a3d5f79c095a98298cd8a760ccec", size = 37461, upload-time = "2025-01-03T18:51:54.306Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "jieba" version = "0.42.1" @@ -4578,6 +4595,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/60/fe31d7e6b8907789dcb0584f88be741ba388413e4fbce35f1eba4e3073de/playwright-1.57.0-py3-none-win_arm64.whl", hash = "sha256:5f065f5a133dbc15e6e7c71e7bc04f258195755b1c32a432b792e28338c8335e", size = 32837940, upload-time = "2025-12-09T08:06:42.268Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "portalocker" version = "3.2.0" @@ -5482,6 +5508,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/0a/c99fb7d7e176f8b176ef19704a32e6a9c6aafdf19ef75a187f701fc15801/pysbd-0.3.4-py3-none-any.whl", hash = "sha256:cd838939b7b0b185fcf86b0baf6636667dfb6e474743beeff878e9f42e022953", size = 71082, upload-time = "2021-02-11T16:36:33.351Z" }, ] +[[package]] +name = "pytest" +version = "9.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" diff --git a/src/frontend/client/src/hooks/AuthContext.tsx b/src/frontend/client/src/hooks/AuthContext.tsx index c6c6755e34..16a9328ee6 100644 --- a/src/frontend/client/src/hooks/AuthContext.tsx +++ b/src/frontend/client/src/hooks/AuthContext.tsx @@ -151,6 +151,15 @@ const AuthContextProvider = ({ if (userQuery.data) { setUser(userQuery.data); } else if (userQuery.isError) { + // Dev mode: mock user to bypass auth + if (import.meta.env.DEV && !isAuthenticated) { + setUserContext({ + token: 'dev-mock-token', + isAuthenticated: true, + user: { id: 'dev', username: 'dev', email: 'dev@test.com', name: 'Dev User', role: 'admin', avatar: '', provider: 'local', plugins: [], createdAt: '', updatedAt: '' } as any, + }); + return; + } doSetError((userQuery.error as Error).message); // navigate(`/${__APP_ENV__.BISHENG_HOST}`, { replace: true }); } diff --git a/src/frontend/client/src/routes/index.tsx b/src/frontend/client/src/routes/index.tsx index 4f6438cbab..4bd29fa7df 100644 --- a/src/frontend/client/src/routes/index.tsx +++ b/src/frontend/client/src/routes/index.tsx @@ -22,6 +22,7 @@ import Subscription from '~/pages/subscription'; import AppRoot from './AppRoot'; import Root from './Root'; import Knowledge from '~/pages/knowledge'; +import PrototypePage from '~/pages/prototype'; const AuthLayout = () => ( @@ -80,6 +81,7 @@ export const router = createBrowserRouter([ path: 'chat/:conversationId/:fid/:type', element: `/app/${p.conversationId}/${p.fid}/${p.type}`} /> }, + { path: 'prototype', element: }, { path: 'apps', element: }, { path: 'apps/explore', element: }, { path: 'channel', element: }, @@ -93,4 +95,4 @@ export const router = createBrowserRouter([ { path: '/html', element: }, { path: '/404', element: }, { path: "*", element: } -], baseConfig); \ No newline at end of file +], baseConfig); diff --git a/src/frontend/client/vite.config.ts b/src/frontend/client/vite.config.ts index 0e20b6aad7..c9555f60b8 100644 --- a/src/frontend/client/vite.config.ts +++ b/src/frontend/client/vite.config.ts @@ -13,6 +13,9 @@ const app_env = { BASE_URL: '/workspace', BISHENG_HOST: '/admin' } + +const proxyTarget = process.env.VITE_PROXY_TARGET || 'http://127.0.0.1:7860'; +const fileServiceTarget = process.env.VITE_FILE_SERVICE_TARGET || proxyTarget; // https://vitejs.dev/config/ export default defineConfig(({ command }) => ({ base: app_env.BASE_URL || '/', @@ -30,7 +33,7 @@ export default defineConfig(({ command }) => ({ // changeOrigin: true, // }, '^(/workspace)?/bisheng': { - target: "http://192.168.106.120:3002", + target: fileServiceTarget, changeOrigin: true, secure: false, rewrite: (path) => { @@ -38,7 +41,7 @@ export default defineConfig(({ command }) => ({ }, }, '/workspace/api': { - target: 'http://192.168.106.120:3002', + target: proxyTarget, changeOrigin: true, secure: false, ws: true, @@ -52,7 +55,7 @@ export default defineConfig(({ command }) => ({ }, }, '/workspace/tmp-dir': { - target: 'http://192.168.106.120:3002', + target: fileServiceTarget, changeOrigin: true, secure: false, rewrite: (path) => { @@ -328,4 +331,4 @@ export function sourcemapExclude(opts?: SourcemapExclude): Plugin { } }, }; -} \ No newline at end of file +} diff --git a/src/frontend/platform/src/pages/BuildPage/flow/FlowNode/component/ConditionItem.tsx b/src/frontend/platform/src/pages/BuildPage/flow/FlowNode/component/ConditionItem.tsx index d8dbc54b8f..3445767ddf 100644 --- a/src/frontend/platform/src/pages/BuildPage/flow/FlowNode/component/ConditionItem.tsx +++ b/src/frontend/platform/src/pages/BuildPage/flow/FlowNode/component/ConditionItem.tsx @@ -24,6 +24,27 @@ interface Item { del: boolean } +const normalizeConditionRule = (item?: Partial) => ({ + id: item?.id || generateUUID(8), + left_var: typeof item?.left_var === 'string' ? item.left_var : '', + left_label: typeof item?.left_label === 'string' ? item.left_label : '', + comparison_operation: typeof item?.comparison_operation === 'string' ? item.comparison_operation : '', + right_value_type: item?.right_value_type === 'ref' ? 'ref' : 'input', + right_value: typeof item?.right_value === 'string' ? item.right_value : '', + right_label: typeof item?.right_label === 'string' ? item.right_label : '', +}); + +const normalizeConditionBranches = (value) => { + if (!Array.isArray(value)) return []; + return value.map((branch) => ({ + id: branch?.id || generateUUID(8), + operator: branch?.operator === 'or' ? 'or' : 'and', + conditions: Array.isArray(branch?.conditions) + ? branch.conditions.map((item) => normalizeConditionRule(item)) + : [], + })); +}; + const Item = ({ nodeId, item, index, del, required, varErrors, onUpdateItem, onDeleteItem }) => { const { t } = useTranslation('flow'); @@ -159,6 +180,11 @@ export default function ConditionItem({ nodeId, node, data: paramItem, onChange, const { t } = useTranslation('flow'); // 获取翻译函数 const [value, setValue] = useState([]); const [required, setRequired] = useState(false); + const normalizedParamValue = useMemo(() => normalizeConditionBranches(paramItem?.value), [paramItem?.value]); + const needsParamNormalization = useMemo( + () => JSON.stringify(normalizedParamValue) !== JSON.stringify(paramItem?.value ?? []), + [normalizedParamValue, paramItem?.value] + ); const handleAddCondition = () => { setRequired(false); @@ -170,12 +196,20 @@ export default function ConditionItem({ nodeId, node, data: paramItem, onChange, }; useEffect(() => { - if (paramItem.value && paramItem.value.length) { - setValue(paramItem.value); - } else { - handleAddCondition(); + if (normalizedParamValue.length) { + setValue(normalizedParamValue); + if (needsParamNormalization) { + onChange(normalizedParamValue); + } + return; } - }, []); + setValue((current) => { + if (current.length) return current; + const initialValue = [{ id: generateUUID(8), operator: 'and', conditions: [] }]; + onChange(initialValue); + return initialValue; + }); + }, [needsParamNormalization, normalizedParamValue, onChange, paramItem?.value]); const deleteCondition = (id) => { setValue((val) => { @@ -219,8 +253,8 @@ export default function ConditionItem({ nodeId, node, data: paramItem, onChange, setTimeout(() => { setRequired(true); }, 100); - if (paramItem.value.length === 0) return t('conditionBranchCannotBeEmpty'); // 条件分支不可为空 - const res = paramItem.value.some((item) => { + if (value.length === 0) return t('conditionBranchCannotBeEmpty'); // 条件分支不可为空 + const res = value.some((item) => { if (!item.conditions.length) return true; return item.conditions.some((cds) => { if (!cds.left_label) return true; @@ -237,7 +271,7 @@ export default function ConditionItem({ nodeId, node, data: paramItem, onChange, }); return () => onValidate(() => { }); - }, [paramItem.value]); + }, [onValidate, t, value]); // 校验变量是否可用 const { flow } = useFlowStore(); @@ -265,7 +299,7 @@ export default function ConditionItem({ nodeId, node, data: paramItem, onChange, useEffect(() => { onVarEvent && onVarEvent(validateVarAvailble); return () => onVarEvent && onVarEvent(() => { }); - }, [paramItem, value]); + }, [onVarEvent, paramItem, value]); // Update Preset Questions // const [_, forceUpdate] = useState(false) @@ -361,4 +395,4 @@ export default function ConditionItem({ nodeId, node, data: paramItem, onChange, ); -} \ No newline at end of file +} diff --git a/src/frontend/platform/src/util/flowCompatible.ts b/src/frontend/platform/src/util/flowCompatible.ts index 3aa81e54b6..b4ea19124f 100644 --- a/src/frontend/platform/src/util/flowCompatible.ts +++ b/src/frontend/platform/src/util/flowCompatible.ts @@ -12,10 +12,35 @@ export const flowVersionCompatible = (flow) => { case 'llm': comptibleLLM(node.data); break; case 'rag': comptibleRag(node.data); break; case 'knowledge_retriever': comptibleKnowledgeRetriever(node.data); break; + case 'condition': comptibleCondition(node.data); break; } }) return flow } + +const normalizeConditionRule = (rule) => ({ + id: rule?.id || generateUUID(8), + left_var: typeof rule?.left_var === 'string' ? rule.left_var : '', + left_label: typeof rule?.left_label === 'string' ? rule.left_label : '', + comparison_operation: typeof rule?.comparison_operation === 'string' ? rule.comparison_operation : '', + right_value_type: rule?.right_value_type === 'ref' ? 'ref' : 'input', + right_value: typeof rule?.right_value === 'string' ? rule.right_value : '', + right_label: typeof rule?.right_label === 'string' ? rule.right_label : '', +}); + +const comptibleCondition = (node) => { + if (!Array.isArray(node?.group_params) || !node.group_params[0]?.params?.length) return; + + const conditionParam = node.group_params[0].params.find((param) => param.key === 'condition'); + if (!conditionParam) return; + + const rawBranches = Array.isArray(conditionParam.value) ? conditionParam.value : []; + conditionParam.value = rawBranches.map((branch) => ({ + id: branch?.id || generateUUID(8), + operator: branch?.operator === 'or' ? 'or' : 'and', + conditions: Array.isArray(branch?.conditions) ? branch.conditions.map(normalizeConditionRule) : [], + })); +} const comptibleRag = (node) => { if (!node.v) { node.v = 1 @@ -391,4 +416,4 @@ const comptibleLLM = (node) => { node.v = 2 } -} \ No newline at end of file +} diff --git a/src/frontend/platform/vite.config.mts b/src/frontend/platform/vite.config.mts index 90f2a0008f..c9175cc609 100644 --- a/src/frontend/platform/vite.config.mts +++ b/src/frontend/platform/vite.config.mts @@ -13,10 +13,8 @@ import svgr from "vite-plugin-svgr"; */ const app_env = { BASE_URL: '' } // /custom -// Use environment variable to determine the target. -// const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860"; -const target = process.env.VITE_PROXY_TARGET || "http://192.168.106.120:3002"; -const fileServiceTarget = "http://192.168.106.116:9000"; +const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860"; +const fileServiceTarget = process.env.VITE_FILE_SERVICE_TARGET || target; // 公共代理配置 const commonProxyOptions = {