This repository was archived by the owner on Jan 23, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathclient.py
More file actions
302 lines (255 loc) · 11.4 KB
/
client.py
File metadata and controls
302 lines (255 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import os
import shlex
import subprocess
import tempfile
from dataclasses import dataclass
from urllib.parse import urlparse
import click
from jumpstarter_driver_composite.client import CompositeClient
from jumpstarter_driver_network.adapters import TcpPortforwardAdapter
from jumpstarter.client.core import DriverMethodNotImplemented
from jumpstarter.client.decorators import driver_click_command
@dataclass
class SSHCommandRunResult:
"""Result of executing an SSH command"""
return_code: int
stdout: str | bytes
stderr: str | bytes
@staticmethod
def from_completed_process(result: subprocess.CompletedProcess) -> "SSHCommandRunResult":
return SSHCommandRunResult(
return_code=result.returncode,
stdout=result.stdout or "",
stderr=result.stderr or "",
)
@dataclass
class SSHCommandRunOptions:
    """
    Settings that control how a single SSH command is run.

    Attributes:
        direct: When True, connect straight to the host's TCP address;
            when False, go through SSH port forwarding.
        capture_output: When True, stdout and stderr are captured;
            when False, they are inherited from the parent process.
        capture_as_text: When True and output is captured, stdout and
            stderr are decoded as text; otherwise they are captured as bytes.
    """

    direct: bool = False
    capture_output: bool = True
    capture_as_text: bool = True
@dataclass(kw_only=True)
class SSHWrapperClient(CompositeClient):
    """
    Client interface for SSHWrapper driver

    This client provides methods to interact with SSH connections via CLI.
    It builds an argument list for a local SSH executable and runs it with
    ``subprocess``, reaching the target either through its direct TCP address
    or through a local port forward.
    """

    # Single-letter SSH flags that do not expect a parameter.
    _SSH_FLAGS_NO_PARAM = frozenset({
        '-4', '-6', '-A', '-a', '-C', '-f', '-G', '-g', '-K', '-k', '-M', '-N',
        '-n', '-q', '-s', '-T', '-t', '-V', '-v', '-X', '-x', '-Y', '-y',
    })
    # Single-letter SSH flags that expect a parameter.
    _SSH_FLAGS_WITH_PARAM = frozenset({
        '-B', '-b', '-c', '-D', '-E', '-e', '-F', '-I', '-i', '-J', '-L', '-l',
        '-m', '-O', '-o', '-P', '-p', '-Q', '-R', '-S', '-W', '-w',
    })

    def cli(self):
        """Return the ``ssh`` click command exposed by this client."""

        @driver_click_command(
            self,
            context_settings={"ignore_unknown_options": True},
            help="Run SSH command with arguments",
        )
        @click.option("--direct", is_flag=True, help="Use direct TCP address")
        @click.argument("args", nargs=-1)
        def ssh(direct, args):
            options = SSHCommandRunOptions(
                direct=direct,
                # For the CLI, we never capture output so that interactive shells
                # and long-running commands stream their output directly.
                capture_output=False,
            )
            result = self.run(options, args)
            self.logger.debug("SSH exit code: %s", result.return_code)

            if result.stdout:
                click.echo(result.stdout, nl=False)
            if result.stderr:
                click.echo(result.stderr, nl=False, err=True)

            # Propagate the SSH exit status as the CLI's own exit status.
            if result.return_code != 0:
                click.get_current_context().exit(result.return_code)
            return result.return_code

        return ssh

    # wrap the underlying tcp stream connections, so we can still use tcp forwarding or
    # the fabric driver adapter on top of client.ssh
    def stream(self, method="connect"):
        """Delegate to the ``tcp`` child's ``stream``."""
        return self.tcp.stream(method)

    async def stream_async(self, method="connect"):
        """Delegate to the ``tcp`` child's ``stream_async`` (same default as ``stream``)."""
        return await self.tcp.stream_async(method)

    @property
    def command(self) -> str:
        """Get the base SSH command"""
        return self.call("get_ssh_command")

    @property
    def identity(self) -> str | None:
        """
        Get the SSH identity (private key) as a string.

        Returns:
            The SSH identity key content, or None if not configured.

        Raises:
            ConfigurationError: If `ssh_identity_file` is configured on the
                driver but cannot be read.
        """
        return self.call("get_ssh_identity")

    @property
    def username(self) -> str:
        """Get the default SSH username"""
        return self.call("get_default_username")

    def run(self, options: SSHCommandRunOptions, args) -> SSHCommandRunResult:
        """
        Run SSH command with the given parameters and arguments.

        Args:
            options: How to connect and whether/how to capture output.
            args: Sequence of strings: SSH options followed by the remote
                command and its arguments.

        Returns:
            The exit status and (optionally captured) output of the SSH process.
        """
        if options.direct:
            endpoint = self._resolve_direct_endpoint()
            if endpoint is not None:
                host, port = endpoint
                self.logger.debug("Using direct TCP connection for SSH - host: %s, port: %s", host, port)
                # Deliberately outside any fallback handling: an error raised
                # while building or running the command is not an address
                # problem and must not trigger a second SSH invocation.
                return self._run_ssh_local(host, port, options, args)

        # Use SSH port forwarding (default behavior, and fallback for --direct)
        self.logger.debug("Using SSH port forwarding for SSH connection")
        with TcpPortforwardAdapter(
            client=self.tcp,
        ) as addr:
            host, port = addr
            self.logger.debug("SSH port forward established - host: %s, port: %s", host, port)
            return self._run_ssh_local(host, port, options, args)

    def _resolve_direct_endpoint(self):
        """Return ``(host, port)`` of the direct TCP address, or None if unavailable."""
        try:
            address = self.tcp.address()  # (format: "tcp://host:port")
            parsed = urlparse(address)
            host = parsed.hostname
            port = parsed.port
            if not host or not port:
                raise ValueError(f"Invalid address format: {address}")
        except (DriverMethodNotImplemented, ValueError) as e:
            self.logger.error("Direct address connection failed (%s), falling back to SSH port forwarding", e)
            return None
        return host, port

    def _write_identity_file(self, ssh_identity):
        """Write the identity to a private temporary file and return its path."""
        temp_file = None
        try:
            temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_ssh_key')
            # The context manager closes the handle even if the write fails.
            with temp_file:
                temp_file.write(ssh_identity)
            # Set proper permissions (600) for SSH key
            os.chmod(temp_file.name, 0o600)
        except Exception as e:
            self.logger.error("Failed to create temporary identity file: %s", e)
            if temp_file is not None:
                try:
                    os.unlink(temp_file.name)
                except OSError as cleanup_error:
                    # Best effort: the original failure is what gets re-raised.
                    self.logger.debug(
                        "Could not remove partial identity file %s: %s", temp_file.name, cleanup_error
                    )
            raise
        self.logger.debug("Created temporary identity file: %s", temp_file.name)
        return temp_file.name

    def _run_ssh_local(self, host, port, options, args):
        """
        Run SSH command with the given host, port, and arguments.

        If the driver provides an identity, it is written to a temporary file
        for the duration of the command and removed afterwards.
        """
        ssh_identity = self.identity
        identity_file = self._write_identity_file(ssh_identity) if ssh_identity else None

        try:
            # Build SSH command arguments
            ssh_args = self._build_ssh_command_args(port, identity_file, args)

            # Separate SSH options from command arguments
            ssh_options, command_args = self._separate_ssh_options_and_command_args(args)

            # Build final SSH command
            ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args)

            # Execute the command
            return self._execute_ssh_command(ssh_args, options)
        finally:
            # Clean up temporary identity file
            if identity_file:
                try:
                    os.unlink(identity_file)
                    self.logger.debug("Cleaned up temporary identity file: %s", identity_file)
                except Exception as e:
                    self.logger.warning("Failed to clean up temporary identity file %s: %s", identity_file, str(e))

    @classmethod
    def _parse_option_token(cls, token):
        """
        Interpret ``token`` as a getopt-style cluster of SSH flags.

        Returns:
            None if ``token`` is not made up of known SSH flags. Otherwise a
            pair ``(param_flag, attached_value)``: ``param_flag`` is the
            value-taking flag that ends the cluster (e.g. ``"-l"``) or None if
            the cluster holds only simple flags; ``attached_value`` is the text
            glued after that flag (``"user"`` for ``-luser``), or ``""`` when
            the value, if any, is expected in the next argument.
        """
        if not token.startswith('-'):
            return None
        letters = token[1:]
        if not letters:
            return None
        for position, letter in enumerate(letters):
            flag = f"-{letter}"
            if flag in cls._SSH_FLAGS_NO_PARAM:
                continue
            if flag in cls._SSH_FLAGS_WITH_PARAM:
                # Everything after a value-taking flag is its value.
                return flag, letters[position + 1:]
            return None
        return None, ""

    def _has_login_option(self, ssh_options):
        """Return True if ``ssh_options`` already supplies ``-l`` with a value."""
        last = len(ssh_options) - 1
        for position, token in enumerate(ssh_options):
            parsed = self._parse_option_token(token)
            if parsed is None:
                continue
            param_flag, attached_value = parsed
            # Either "-luser", or "-l" followed by another token.
            if param_flag == "-l" and (attached_value or position < last):
                return True
        return False

    def _build_ssh_command_args(self, port, identity_file, args):
        """
        Build initial SSH command arguments.

        Args:
            port: Port to connect to; ``-p`` is only added when it is not 22.
            identity_file: Path to a private key file, or None.
            args: User-supplied arguments, inspected only for a ``-l`` option.

        Returns:
            The base command split into arguments, followed by the identity,
            port and default-username options that apply.
        """
        # Split the SSH command into individual arguments
        ssh_args = shlex.split(self.command)
        default_username = self.username

        # Add identity file if provided
        if identity_file:
            ssh_args.extend(["-i", identity_file])

        # Add port if specified
        if port and port != 22:
            ssh_args.extend(["-p", str(port)])

        # Only look for -l among the SSH options, never in the remote command,
        # to avoid false positives such as "ls -l".
        ssh_options, _ = self._separate_ssh_options_and_command_args(args)

        # Add default username if no -l flag provided and we have a default
        if default_username and not self._has_login_option(ssh_options):
            ssh_args.extend(["-l", default_username])

        return ssh_args

    def _separate_ssh_options_and_command_args(self, args):
        """
        Separate SSH options from command arguments.

        Leading arguments are consumed as SSH options for as long as they are
        known flags, including clustered (``-tt``) and attached-value
        (``-p2222``) forms. The first argument that is not one starts the
        remote command; it and everything after it are returned untouched.

        Returns:
            A pair ``(ssh_options, command_args)``.
        """
        ssh_options = []
        command_args = []

        index = 0
        while index < len(args):
            token = args[index]
            parsed = self._parse_option_token(token)
            if parsed is None:
                # This is a command argument - everything from here on is part of the command
                command_args = args[index:]
                break

            ssh_options.append(token)
            param_flag, attached_value = parsed
            expects_next = param_flag is not None and not attached_value
            # A following argument is taken as the value unless it looks like
            # another option.
            if expects_next and index + 1 < len(args) and not args[index + 1].startswith('-'):
                ssh_options.append(args[index + 1])
                index += 1
            index += 1

        # Debug output
        self.logger.debug("SSH options: %s", ssh_options)
        self.logger.debug("Command args: %s", command_args)
        return ssh_options, command_args

    def _build_final_ssh_command(self, ssh_args, ssh_options, host, command_args):
        """
        Build the final SSH command with all components.

        Extends ``ssh_args`` in place with the options, the host (if any) and
        the remote command, and returns that same list.
        """
        # Add SSH options
        ssh_args.extend(ssh_options)

        # Add hostname before command arguments
        if host:
            ssh_args.append(host)

        # Add command arguments
        ssh_args.extend(command_args)

        self.logger.debug("Running SSH command: %s", ssh_args)
        return ssh_args

    def _execute_ssh_command(self, ssh_args, options: SSHCommandRunOptions) -> SSHCommandRunResult:
        """
        Execute the SSH command and return the result.

        If the executable cannot be found, this is logged and reported as a
        result with return code 127 instead of raising.
        """
        try:
            result = subprocess.run(ssh_args, capture_output=options.capture_output, text=options.capture_as_text)
            return SSHCommandRunResult.from_completed_process(result)
        except FileNotFoundError:
            self.logger.error(
                "SSH command '%s' not found. Please ensure SSH is installed and available in PATH.",
                ssh_args[0],
            )
            return SSHCommandRunResult(
                return_code=127,  # Standard exit code for "command not found"
                stdout="",
                stderr=f"SSH command '{ssh_args[0]}' not found",
            )