Skip to content

Commit 4aec9ef

Browse files
committed
feat(cli): add cross-platform socket support for Windows compatibility
- Refactor CLI platform to support both Unix Socket and TCP Socket - Add platform detection module to auto-detect OS capabilities - Add socket factory pattern for creating appropriate socket servers - Split monolithic cli_adapter.py into modular components: * platform_detector.py: OS and socket capability detection * socket_abstract.py: Abstract base class for socket servers * socket_factory.py: Factory for creating socket servers * tcp_socket_server.py: TCP socket server implementation * unix_socket_server.py: Unix socket server implementation * connection_info_writer.py: Connection info file writer - Update astrbot-cli client to support both socket types - Fix logger imports to use AstrBot's custom logger system This enables CLI platform to work on Windows (using TCP Socket) while maintaining Unix Socket support on Linux/Mac.
1 parent 6e2e9b0 commit 4aec9ef

8 files changed

Lines changed: 1371 additions & 52 deletions

File tree

astrbot-cli

Lines changed: 143 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#!/usr/bin/env python3
22
"""
3-
AstrBot CLI Tool - Unix Socket客户端
3+
AstrBot CLI Tool - 跨平台Socket客户端
4+
5+
支持Unix Socket和TCP Socket连接
46
57
用法:
68
astrbot-cli "你好"
@@ -14,6 +16,7 @@ import os
1416
import socket
1517
import sys
1618
import uuid
19+
from typing import Optional
1720

1821

1922
def get_data_path() -> str:
@@ -50,22 +53,129 @@ def load_auth_token() -> str:
5053
return ""
5154

5255

56+
def load_connection_info(data_dir: str) -> Optional[dict]:
57+
"""加载连接信息
58+
59+
从.cli_connection文件读取Socket连接信息
60+
61+
Args:
62+
data_dir: 数据目录路径
63+
64+
Returns:
65+
连接信息字典,如果文件不存在则返回None
66+
67+
Example:
68+
Unix Socket: {"type": "unix", "path": "/tmp/astrbot.sock"}
69+
TCP Socket: {"type": "tcp", "host": "127.0.0.1", "port": 12345}
70+
"""
71+
connection_file = os.path.join(data_dir, ".cli_connection")
72+
try:
73+
with open(connection_file, encoding="utf-8") as f:
74+
connection_info = json.load(f)
75+
return connection_info
76+
except FileNotFoundError:
77+
return None
78+
except json.JSONDecodeError as e:
79+
print(
80+
f"[ERROR] Invalid JSON in connection file: {connection_file}",
81+
file=sys.stderr,
82+
)
83+
print(f"[ERROR] {e}", file=sys.stderr)
84+
return None
85+
except Exception as e:
86+
print(
87+
f"[ERROR] Failed to load connection info: {e}",
88+
file=sys.stderr,
89+
)
90+
return None
91+
92+
93+
def connect_to_server(
94+
connection_info: dict, timeout: float = 30.0
95+
) -> socket.socket:
96+
"""连接到服务器
97+
98+
根据连接信息类型选择Unix Socket或TCP Socket连接
99+
100+
Args:
101+
connection_info: 连接信息字典
102+
timeout: 超时时间(秒)
103+
104+
Returns:
105+
socket连接对象
106+
107+
Raises:
108+
ValueError: 无效的连接类型
109+
ConnectionError: 连接失败
110+
"""
111+
socket_type = connection_info.get("type")
112+
113+
if socket_type == "unix":
114+
# Unix Socket连接
115+
socket_path = connection_info.get("path")
116+
if not socket_path:
117+
raise ValueError("Unix socket path is missing in connection info")
118+
119+
try:
120+
client_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
121+
client_socket.settimeout(timeout)
122+
client_socket.connect(socket_path)
123+
return client_socket
124+
except FileNotFoundError:
125+
raise ConnectionError(
126+
f"Socket file not found: {socket_path}. Is AstrBot running?"
127+
)
128+
except ConnectionRefusedError:
129+
raise ConnectionError(
130+
"Connection refused. Is AstrBot running in socket mode?"
131+
)
132+
except Exception as e:
133+
raise ConnectionError(f"Unix socket connection error: {e}")
134+
135+
elif socket_type == "tcp":
136+
# TCP Socket连接
137+
host = connection_info.get("host")
138+
port = connection_info.get("port")
139+
if not host or not port:
140+
raise ValueError("TCP host or port is missing in connection info")
141+
142+
try:
143+
client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
144+
client_socket.settimeout(timeout)
145+
client_socket.connect((host, port))
146+
return client_socket
147+
except ConnectionRefusedError:
148+
raise ConnectionError(
149+
f"Connection refused to {host}:{port}. Is AstrBot running?"
150+
)
151+
except socket.timeout:
152+
raise ConnectionError(f"Connection timeout to {host}:{port}")
153+
except Exception as e:
154+
raise ConnectionError(f"TCP socket connection error: {e}")
155+
156+
else:
157+
raise ValueError(
158+
f"Invalid socket type: {socket_type}. Expected 'unix' or 'tcp'"
159+
)
160+
161+
53162
def send_message(
54163
message: str, socket_path: str | None = None, timeout: float = 30.0
55164
) -> dict:
56165
"""发送消息到AstrBot并获取响应
57166
167+
支持自动检测连接类型(Unix Socket或TCP Socket)
168+
58169
Args:
59170
message: 要发送的消息
60-
socket_path: Unix socket路径(默认使用临时目录下的astrbot.sock)
171+
socket_path: Unix socket路径(仅用于向后兼容,优先使用.cli_connection)
61172
timeout: 超时时间(秒)
62173
63174
Returns:
64175
响应字典
65176
"""
66-
# 使用默认socket路径
67-
if socket_path is None:
68-
socket_path = os.path.join(get_temp_path(), "astrbot.sock")
177+
# [ENTRY] send_message
178+
data_dir = get_data_path()
69179

70180
# 加载认证token
71181
auth_token = load_auth_token()
@@ -77,30 +187,33 @@ def send_message(
77187
if auth_token:
78188
request["auth_token"] = auth_token
79189

80-
# 连接到Unix socket
190+
# [PROCESS] 尝试加载连接信息
191+
connection_info = load_connection_info(data_dir)
192+
193+
# 连接到服务器
81194
try:
82-
client_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
83-
client_socket.settimeout(timeout)
84-
client_socket.connect(socket_path)
85-
except FileNotFoundError:
86-
return {
87-
"status": "error",
88-
"error": f"Socket file not found: {socket_path}. Is AstrBot running?",
89-
}
90-
except ConnectionRefusedError:
91-
return {
92-
"status": "error",
93-
"error": "Connection refused. Is AstrBot running in socket mode?",
94-
}
195+
if connection_info is not None:
196+
# [PROCESS] 使用连接信息文件
197+
client_socket = connect_to_server(connection_info, timeout)
198+
else:
199+
# [PROCESS] 向后兼容:使用默认Unix Socket路径
200+
if socket_path is None:
201+
socket_path = os.path.join(get_temp_path(), "astrbot.sock")
202+
203+
fallback_info = {"type": "unix", "path": socket_path}
204+
client_socket = connect_to_server(fallback_info, timeout)
205+
206+
except (ValueError, ConnectionError) as e:
207+
return {"status": "error", "error": str(e)}
95208
except Exception as e:
96209
return {"status": "error", "error": f"Connection error: {e}"}
97210

98211
try:
99-
# 发送请求
212+
# [PROCESS] 发送请求
100213
request_data = json.dumps(request, ensure_ascii=False).encode("utf-8")
101214
client_socket.sendall(request_data)
102215

103-
# 接收响应(循环接收所有数据,支持大响应如base64图片)
216+
# [PROCESS] 接收响应(循环接收所有数据,支持大响应如base64图片)
104217
response_data = b""
105218
while True:
106219
chunk = client_socket.recv(4096)
@@ -110,18 +223,22 @@ def send_message(
110223
# 尝试解析JSON,如果成功说明接收完整
111224
try:
112225
response = json.loads(response_data.decode("utf-8"))
226+
# [EXIT] send_message success
113227
return response
114228
except json.JSONDecodeError:
115229
# JSON不完整,继续接收
116230
continue
117231

118232
# 如果循环结束仍未成功解析,尝试最后一次
119233
response = json.loads(response_data.decode("utf-8"))
234+
# [EXIT] send_message success
120235
return response
121236

122237
except TimeoutError:
238+
# [ERROR] Request timeout
123239
return {"status": "error", "error": "Request timeout"}
124240
except Exception as e:
241+
# [ERROR] Communication error
125242
return {"status": "error", "error": f"Communication error: {e}"}
126243
finally:
127244
client_socket.close()
@@ -130,14 +247,18 @@ def send_message(
130247
def main():
131248
"""主函数"""
132249
parser = argparse.ArgumentParser(
133-
description="AstrBot CLI Tool - Send messages to AstrBot via Unix Socket",
250+
description="AstrBot CLI Tool - Send messages to AstrBot (Unix Socket or TCP Socket)",
134251
formatter_class=argparse.RawDescriptionHelpFormatter,
135252
epilog="""
136253
Examples:
137254
astrbot-cli "你好"
138255
astrbot-cli "/help"
139256
astrbot-cli --socket /tmp/custom.sock "测试消息"
140257
echo "你好" | astrbot-cli
258+
259+
Connection:
260+
Automatically detects connection type from .cli_connection file.
261+
Falls back to default Unix Socket if file not found.
141262
""",
142263
)
143264

astrbot/core/platform/sources/cli/cli_adapter.py

Lines changed: 46 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
from ...register import register_platform_adapter
2929
from .cli_event import CLIMessageEvent
30+
from .connection_info_writer import write_connection_info
31+
from .platform_detector import detect_platform
32+
from .socket_factory import create_socket_server
3033

3134

3235
@register_platform_adapter(
@@ -36,7 +39,10 @@
3639
"type": "cli",
3740
"enable": False, # 默认关闭,开发时手动启用
3841
"mode": "socket", # 默认使用Socket模式
39-
"socket_path": None, # None表示使用动态路径(temp_dir/astrbot.sock)
42+
"socket_type": "auto", # Socket类型: "auto"(自动检测) | "unix" | "tcp"
43+
"socket_path": None, # Unix Socket路径,None表示使用动态路径
44+
"tcp_host": "127.0.0.1", # TCP Socket监听地址
45+
"tcp_port": 0, # TCP Socket监听端口,0表示随机端口
4046
"whitelist": [], # 空白名单表示允许所有
4147
"use_isolated_sessions": False, # 是否启用会话隔离(每个请求独立会话)
4248
"session_ttl": 30, # 会话过期时间(秒),仅在use_isolated_sessions=True时生效,测试用30秒,生产建议1800秒(30分钟)
@@ -113,10 +119,13 @@ def __init__(
113119
)
114120
self.poll_interval = platform_config.get("poll_interval", 1.0)
115121

116-
# Unix Socket配置
122+
# Socket配置(跨平台)
123+
self.socket_type = platform_config.get("socket_type", "auto")
117124
self.socket_path = platform_config.get(
118125
"socket_path", os.path.join(get_astrbot_temp_path(), "astrbot.sock")
119126
)
127+
self.tcp_host = platform_config.get("tcp_host", "127.0.0.1")
128+
self.tcp_port = platform_config.get("tcp_port", 0)
120129

121130
# Token认证配置
122131
self.auth_token = self._ensure_auth_token()
@@ -342,56 +351,63 @@ async def _run_file_mode(self) -> None:
342351
logger.info("[EXIT] CLIPlatformAdapter._run_file_mode return=None")
343352

344353
async def _run_socket_mode(self) -> None:
345-
"""Unix Socket服务器模式
354+
"""跨平台Socket服务器模式
346355
347356
管道流程:
348-
客户端连接接收JSON请求解析消息创建事件等待响应 → 返回JSON
357+
平台检测创建Socket服务器写入连接信息接受连接处理请求
349358
"""
350-
import os
351-
import socket
359+
logger.info("[ENTRY] _run_socket_mode inputs={}")
352360

353361
self._running = True
354362

355-
# 删除旧的socket文件
356-
if os.path.exists(self.socket_path):
357-
os.remove(self.socket_path)
358-
logger.info(f"[PROCESS] Removed old socket file: {self.socket_path}")
359-
360-
# 创建Unix socket
361-
server_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
362-
server_socket.bind(self.socket_path)
363+
# 检测平台信息
364+
platform_info = detect_platform()
365+
logger.info(
366+
"[PROCESS] Platform detected: os=%s, python=%s, unix_socket=%s",
367+
platform_info.os_type,
368+
platform_info.python_version,
369+
platform_info.supports_unix_socket,
370+
)
363371

364-
# 设置严格权限(仅所有者可访问)
365-
os.chmod(self.socket_path, 0o600)
366-
logger.info(f"[SECURITY] Socket permissions set to 600: {self.socket_path}")
372+
# 创建Socket服务器(工厂模式)
373+
config = {
374+
"socket_type": self.socket_type,
375+
"socket_path": self.socket_path,
376+
"tcp_host": self.tcp_host,
377+
"tcp_port": self.tcp_port,
378+
}
379+
server = create_socket_server(platform_info, config, self.auth_token)
380+
logger.info("[PROCESS] Socket server created: %s", type(server).__name__)
367381

368-
server_socket.listen(5)
369-
server_socket.setblocking(False)
382+
try:
383+
# 启动服务器
384+
await server.start()
385+
logger.info("[PROCESS] Socket server started")
370386

371-
logger.info(f"[PROCESS] Unix Socket server started: {self.socket_path}")
387+
# 写入连接信息供客户端读取
388+
connection_info = server.get_connection_info()
389+
write_connection_info(connection_info, get_astrbot_data_path())
390+
logger.info("[PROCESS] Connection info written: %s", connection_info)
372391

373-
try:
392+
# 接受连接循环
374393
while self._running:
375394
try:
376-
# 接受连接(非阻塞)
377-
loop = asyncio.get_running_loop()
378-
client_socket, _ = await loop.sock_accept(server_socket)
395+
client_socket, client_addr = await server.accept_connection()
396+
logger.debug("[PROCESS] Client connected: %s", client_addr)
379397

380398
# 处理连接(异步)
381399
asyncio.create_task(self._handle_socket_client(client_socket))
382400

383401
except Exception as e:
384-
logger.error(f"[ERROR] Socket accept error: {e}")
402+
logger.error("[ERROR] Socket accept error: %s", e)
385403
await asyncio.sleep(0.1)
386404

387405
except Exception as e:
388-
logger.error(f"[ERROR] Socket mode error: {e}")
406+
logger.error("[ERROR] Socket mode error: %s", e)
389407
finally:
390408
self._running = False
391-
server_socket.close()
392-
if os.path.exists(self.socket_path):
393-
os.remove(self.socket_path)
394-
logger.info("[EXIT] CLIPlatformAdapter._run_socket_mode return=None")
409+
await server.stop()
410+
logger.info("[EXIT] _run_socket_mode return=None")
395411

396412
async def _handle_socket_client(self, client_socket) -> None:
397413
"""[原子模块] SocketHandler: 处理单个socket客户端连接

0 commit comments

Comments
 (0)