Skip to content

Commit c96ae32

Browse files
committed
mpssh: Share stdin
1 parent 1d1edec commit c96ae32

1 file changed

Lines changed: 20 additions & 4 deletions

File tree

mpssh

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,20 @@ def escape_shell_command(command: str) -> str:
7979

8080

8181
async def ssh_connect_and_run_command(
82-
host: str, command: str, thread_color: str, raw: bool = False, timeout: int = 30
82+
host: str, command: str, thread_color: str, raw: bool = False, timeout: int = 30, stdin_data: Optional[bytes] = None
8383
) -> SSHResult:
8484
try:
8585
escaped_command = escape_shell_command(command)
8686
ssh_command = f"ssh -o ConnectTimeout={timeout} {host} '{escaped_command}'"
8787

8888
process = await asyncio.create_subprocess_shell(
89-
ssh_command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
89+
ssh_command,
90+
stdout=asyncio.subprocess.PIPE,
91+
stderr=asyncio.subprocess.PIPE,
92+
stdin=asyncio.subprocess.PIPE if stdin_data is not None else None
9093
)
9194

92-
stdout, stderr = (x.decode("utf-8") for x in await process.communicate())
95+
stdout, stderr = (x.decode("utf-8") for x in await process.communicate(input=stdin_data))
9396

9497
if stdout:
9598
for line in stdout.splitlines():
@@ -169,6 +172,9 @@ async def main() -> int:
169172
help="The JSON file containing the list of hosts",
170173
default=default_json_file,
171174
)
175+
parser.add_argument(
176+
"--stdin", action="store_true", help="Read stdin and pass it to the remote command"
177+
)
172178
parser.add_argument(
173179
"hostgroup", type=str, help="The hostgroup to run the command on"
174180
)
@@ -209,6 +215,16 @@ async def main() -> int:
209215
logging.error(f"Hostgroup '{args.hostgroup}' not found in JSON file")
210216
return 1
211217

218+
stdin_data = None
219+
if args.stdin:
220+
stdin_data = sys.stdin.buffer.read()
221+
else:
222+
if not sys.stdin.isatty():
223+
print('123')
224+
sys.stdin.read()
225+
print('456')
226+
227+
212228
thread_colors = [Fore.GREEN, Fore.YELLOW, Fore.BLUE, Fore.MAGENTA, Fore.CYAN]
213229
tasks = []
214230

@@ -221,7 +237,7 @@ async def main() -> int:
221237
thread_color = thread_colors[i % len(thread_colors)]
222238
task = asyncio.create_task(
223239
ssh_connect_and_run_command(
224-
host, " ".join(args.command), thread_color, args.raw, args.timeout
240+
host, " ".join(args.command), thread_color, args.raw, args.timeout, stdin_data
225241
)
226242
)
227243
tasks.append(task)

0 commit comments

Comments
 (0)