-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Expand file tree
/
Copy pathmain.py
More file actions
213 lines (168 loc) · 7.52 KB
/
main.py
File metadata and controls
213 lines (168 loc) · 7.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
#!/usr/bin/env python3
"""
Simple MCP streamable private gateway client example without authentication.
This client connects to an MCP server over the streamable HTTP transport.
"""
import asyncio
import os
from typing import Any
import httpx
from mcp.client.session import ClientSession
from mcp.client.streamable_http import streamable_http_client
class ExtendedHttpxClient(httpx.AsyncClient):
    """An ``httpx.AsyncClient`` that also remembers a dict of custom extensions.

    The dict is exposed as the ``custom_extensions`` attribute so that code
    holding the client can read it. This class only stores the value; it does
    not itself attach the extensions to outgoing requests.
    """

    def __init__(self, *args: Any, custom_extensions: dict[str, str] | None = None, **kwargs: Any):
        """Create the client and record ``custom_extensions`` on it.

        Args:
            *args: Forwarded unchanged to ``httpx.AsyncClient``.
            custom_extensions: Extensions to remember on the instance. ``None``
                or an empty dict results in a fresh empty dict.
            **kwargs: Forwarded unchanged to ``httpx.AsyncClient``.
        """
        super().__init__(*args, **kwargs)
        # A truthy dict is kept as-is (same object); anything falsy is
        # replaced by a new empty dict.
        self.custom_extensions = custom_extensions if custom_extensions else {}
class SimpleStreamablePrivateGateway:
"""Simple MCP streamable private gateway client without authentication."""
def __init__(self, server_url: str, server_hostname: str, transport_type: str = "streamable-http"):
self.server_url = server_url
self.server_hostname = server_hostname
self.transport_type = transport_type
self.session: ClientSession | None = None
async def connect(self):
"""Connect to the MCP server."""
print(f"🔗 Attempting to connect to {self.server_url}...")
try:
print("📡 Opening StreamableHTTP transport connection...")
# Note: terminate_on_close=False prevents SSL handshake failures during exit
# Some servers may not handle session termination gracefully over SSL
# Create custom httpx client with headers, timeout, and extensions
async with ExtendedHttpxClient(
headers={"Host": self.server_hostname},
timeout=httpx.Timeout(60.0),
custom_extensions={"sni_hostname": self.server_hostname},
) as custom_client:
async with streamable_http_client(
url=self.server_url,
http_client=custom_client,
terminate_on_close=False, # Skip session termination to avoid SSL errors
) as (read_stream, write_stream, get_session_id):
await self._run_session(read_stream, write_stream, get_session_id)
except Exception as e:
print(f"❌ Failed to connect: {e}")
import traceback
traceback.print_exc()
async def _run_session(self, read_stream, write_stream, get_session_id):
"""Run the MCP session with the given streams."""
print("🤝 Initializing MCP session...")
async with ClientSession(read_stream, write_stream) as session:
self.session = session
print("⚡ Starting session initialization...")
await session.initialize()
print("✨ Session initialization complete!")
print(f"\n✅ Connected to MCP server at {self.server_url}")
if get_session_id:
session_id = get_session_id()
if session_id:
print(f"Session ID: {session_id}")
# Run interactive loop
await self.interactive_loop()
async def list_tools(self):
"""List available tools from the server."""
if not self.session:
print("❌ Not connected to server")
return
try:
result = await self.session.list_tools()
if hasattr(result, "tools") and result.tools:
print("\n📋 Available tools:")
for i, tool in enumerate(result.tools, 1):
print(f"{i}. {tool.name}")
if tool.description:
print(f" Description: {tool.description}")
print()
else:
print("No tools available")
except Exception as e:
print(f"❌ Failed to list tools: {e}")
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None):
"""Call a specific tool."""
if not self.session:
print("❌ Not connected to server")
return
try:
result = await self.session.call_tool(tool_name, arguments or {})
print(f"\n🔧 Tool '{tool_name}' result:")
if hasattr(result, "content"):
for content in result.content:
if content.type == "text":
print(content.text)
else:
print(content)
else:
print(result)
except Exception as e:
print(f"❌ Failed to call tool '{tool_name}': {e}")
async def interactive_loop(self):
"""Run interactive command loop."""
print("\n🎯 Interactive Streamable Private Gateway")
print("Commands:")
print(" list - List available tools")
print(" call <tool_name> [args] - Call a tool")
print(" quit - Exit the client")
print()
while True:
try:
command = input("mcp> ").strip()
if not command:
continue
if command == "quit":
print("👋 Goodbye!")
break
elif command == "list":
await self.list_tools()
elif command.startswith("call "):
parts = command.split(maxsplit=2)
tool_name = parts[1] if len(parts) > 1 else ""
if not tool_name:
print("❌ Please specify a tool name")
continue
# Parse arguments (simple JSON-like format)
arguments = {}
if len(parts) > 2:
import json
try:
arguments = json.loads(parts[2])
except json.JSONDecodeError:
print("❌ Invalid arguments format (expected JSON)")
continue
await self.call_tool(tool_name, arguments)
else:
print("❌ Unknown command. Try 'list', 'call <tool_name>', or 'quit'")
except KeyboardInterrupt:
print("\n\n👋 Goodbye!")
break
except EOFError:
print("\n👋 Goodbye!")
break
async def main():
    """Build the gateway client from the environment and run it.

    ``MCP_SERVER_PORT`` (default ``"8000"``) selects the port of the
    ``https://localhost:<port>/mcp`` URL that is connected to, and
    ``MCP_SERVER_HOSTNAME`` (default ``"localhost"``) is the hostname handed
    to the client together with that URL. Only the port of the URL can be
    changed through the environment; the host part is always ``localhost``.
    """
    environ = os.environ
    port = environ.get("MCP_SERVER_PORT", "8000")
    hostname = environ.get("MCP_SERVER_HOSTNAME", "localhost")
    transport = "streamable-http"
    # Most MCP streamable HTTP servers use /mcp as the endpoint.
    url = f"https://localhost:{port}/mcp"

    banner = (
        "🚀 Simple Streamable Private Gateway",
        f"Connecting to: {url}",
        f"Server hostname: {hostname}",
        f"Transport type: {transport}",
    )
    for line in banner:
        print(line)

    # Hand over to the client; connect() returns once the session has ended.
    gateway = SimpleStreamablePrivateGateway(url, hostname, transport)
    await gateway.connect()
def cli():
    """Synchronous entry point: run ``main`` to completion on a new event loop."""
    entry_coroutine = main()
    asyncio.run(entry_coroutine)


# Run the client only when this file is executed as a script, not on import.
if __name__ == "__main__":
    cli()