Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions poetry.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[virtualenvs]
in-project = true
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ dataclasses-json = "*"
pyyaml = "*"
prompt_toolkit = "*"
pexpect = "*"
cryptography = "<43"

[tool.poetry.scripts]
sshg = "sshg:main"

[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]
pytest = "^7.2.0"
pytest-cov = "^2"

Expand Down
48 changes: 46 additions & 2 deletions sshg.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ class HostConfig(DataClassJsonMixin):
via: typing.Optional["HostConfig"] = make_field(mm_field=fields.Field(), default=None)
_parent: typing.Optional["HostConfig"] = make_field(mm_field=fields.Field(), default=None, init=False, repr=False)

def get_password(self) -> str:
if isinstance(self.password, int):
return str(self.password)
return self.password or ""

def post_load(self):
if self._parent:
if not self.user:
Expand Down Expand Up @@ -151,6 +156,43 @@ def output_filter(line):
s.interact()


def _resolve_key_passphrase(keypath: pathlib.Path) -> str:
"""Try loading key without passphrase, prompt until correct if needed."""
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_ssh_private_key

key_data = keypath.read_bytes()
Comment on lines +159 to +163

def _try_load(password: typing.Optional[bytes]) -> typing.Optional[bool]:
"""Return True if the key loads, False if a passphrase is required/incorrect, None if the key is invalid."""
errors: typing.List[Exception] = []
for loader in (load_ssh_private_key, load_pem_private_key):
try:
loader(key_data, password=password)
return True
except (TypeError, ValueError) as e:
errors.append(e)

msg = " ".join(str(e).lower() for e in errors)
if any(k in msg for k in ("password", "passphrase", "bad decrypt", "incorrect")):
return False
return None

res = _try_load(None)
if res is None:
raise ValueError(f"Unsupported or invalid private key: {keypath}")
if res:
return ""

while True:
password = getpass.getpass(f"Enter passphrase for key {keypath}: ")
res = _try_load(password.encode())
if res is None:
raise ValueError(f"Unsupported or invalid private key: {keypath}")
if res:
return password
print("Wrong passphrase, try again.")


def spawn_ssh(host_config: HostConfig, is_local: bool = True, ssh_client: pxssh.pxssh = None, reset_prompt: bool = None) -> pxssh.pxssh:
# https://pexpect.readthedocs.io/en/stable/api/pxssh.html
cmdargs = host_config.build_cmdargs()
Expand All @@ -164,6 +206,8 @@ def spawn_ssh(host_config: HostConfig, is_local: bool = True, ssh_client: pxssh.
if keypath.stat().st_mode & 0o077 != 0:
print("Warning: keypath mode change to 0600")
keypath.chmod(0o600)
if not host_config.get_password():
host_config.password = _resolve_key_passphrase(keypath)

s.SSH_OPTS += " -o StrictHostKeyChecking=no"
if reset_prompt is None:
Expand All @@ -172,7 +216,7 @@ def spawn_ssh(host_config: HostConfig, is_local: bool = True, ssh_client: pxssh.
if is_local:
s.login(host_config.host,
username=host_config.user,
password=host_config.password,
password=host_config.get_password(),
port=host_config.port,
ssh_key=keypath,
quiet=False,
Expand All @@ -182,7 +226,7 @@ def spawn_ssh(host_config: HostConfig, is_local: bool = True, ssh_client: pxssh.
else:
s.login(host_config.host,
username=host_config.user,
password=host_config.password,
password=host_config.get_password(),
port=host_config.port,
ssh_key=keypath,
quiet=False,
Expand Down
Loading