Files
pobsync/src/pobsync_backend/ssh_keys.py
Peter van Arkel d3ffca1843 (feature) Add host key scanning for SSH credentials
Add a host detail action that scans the target SSH host key with
ssh-keyscan and stores it on the selected SSH credential.

Merge scanned known_hosts entries without duplicates and let the
existing runtime config pass them through as UserKnownHostsFile for
unattended rsync over SSH.

Extend host checks to warn when the selected credential has no known_hosts
entries, making host key verification failures actionable from Django.
2026-05-19 19:55:40 +02:00

146 lines
4.6 KiB
Python

from __future__ import annotations
import os
import shutil
import subprocess
from pathlib import Path
from django.conf import settings
from .models import SshCredential
class SshKeyError(RuntimeError):
    """Raised when an SSH key or host-key operation in this module fails."""
def credential_dir(credential: SshCredential) -> Path:
    """Return the directory that holds the on-disk files for *credential*.

    The location is ``<POBSYNC_HOME>/state/ssh-credentials/<pk>``, built from
    the Django setting ``POBSYNC_HOME`` and the credential's primary key.
    """
    credentials_root = Path(settings.POBSYNC_HOME).joinpath("state", "ssh-credentials")
    return credentials_root / str(credential.pk)
def identity_path(credential: SshCredential) -> Path:
    """Return the private-key path for *credential*.

    A non-empty ``key_path`` stored on the credential wins; otherwise the
    default ``identity`` file inside the credential's directory is used.
    """
    explicit = credential.key_path
    return Path(explicit) if explicit else credential_dir(credential) / "identity"
def generate_ssh_key(credential: SshCredential, *, key_type: str = "ed25519", force: bool = False) -> SshCredential:
    """Generate a passphrase-less SSH key pair for *credential* and record it.

    The pair is written as ``identity`` / ``identity.pub`` inside the
    credential's directory (forced to mode 0700). On success the credential's
    ``public_key``, ``key_path``, ``key_type``, ``fingerprint`` and
    ``generated`` fields are updated, ``private_key`` is blanked, and the row
    is saved.

    Args:
        credential: A saved credential (``pk`` must not be ``None``).
        key_type: Value handed to ``ssh-keygen -t``.
        force: Replace an existing key pair instead of refusing.

    Returns:
        The same credential instance, after saving.

    Raises:
        SshKeyError: If the credential is unsaved, ``ssh-keygen`` is missing,
            a key already exists and *force* is false, or key generation,
            reading or fingerprinting fails or times out.
    """
    if credential.pk is None:
        raise SshKeyError("Credential must be saved before generating an SSH key.")
    if shutil.which("ssh-keygen") is None:
        raise SshKeyError("ssh-keygen is not available.")

    key_dir = credential_dir(credential)
    key_dir.mkdir(mode=0o700, parents=True, exist_ok=True)
    # mkdir's mode is subject to the umask and ignored for an existing
    # directory, so enforce the permissions explicitly.
    os.chmod(key_dir, 0o700)

    private_key = key_dir / "identity"
    public_key_file = key_dir / "identity.pub"
    if force:
        private_key.unlink(missing_ok=True)
        public_key_file.unlink(missing_ok=True)
    elif private_key.exists() or public_key_file.exists():
        raise SshKeyError(f"SSH key already exists for {credential.name}.")

    try:
        result = subprocess.run(
            [
                "ssh-keygen",
                "-t",
                key_type,
                "-N",
                "",
                "-C",
                f"pobsync:{credential.name}",
                "-f",
                str(private_key),
            ],
            check=False,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=15,
        )
        if result.returncode != 0:
            raise SshKeyError(result.stderr.strip() or "ssh-keygen failed.")
        os.chmod(private_key, 0o600)
        public_key = public_key_file.read_text(encoding="utf-8").strip()
        fingerprint = fingerprint_for_key(private_key)
    except (OSError, subprocess.TimeoutExpired, SshKeyError) as exc:
        # Neither file existed before this attempt (checked or removed above),
        # so anything present now is an unusable leftover. Remove it so a
        # later non-forced call is not blocked by "SSH key already exists".
        private_key.unlink(missing_ok=True)
        public_key_file.unlink(missing_ok=True)
        if isinstance(exc, SshKeyError):
            raise
        raise SshKeyError(f"Could not generate SSH key for {credential.name}: {exc}") from exc

    credential.private_key = ""
    credential.public_key = public_key
    credential.key_path = str(private_key)
    credential.key_type = key_type
    credential.fingerprint = fingerprint
    credential.generated = True
    credential.save(update_fields=["private_key", "public_key", "key_path", "key_type", "fingerprint", "generated", "updated_at"])
    return credential
def fingerprint_for_key(private_key: Path) -> str:
    """Return the stripped ``ssh-keygen -lf`` output for the key at *private_key*.

    Raises:
        SshKeyError: If ``ssh-keygen`` cannot be started, does not finish
            within five seconds, or exits with a non-zero status.
    """
    command = ["ssh-keygen", "-lf", str(private_key)]
    try:
        result = subprocess.run(
            command,
            check=False,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=5,
        )
    except subprocess.TimeoutExpired as exc:
        raise SshKeyError("Timed out fingerprinting SSH key.") from exc
    except OSError as exc:
        # Raised when the binary is missing or cannot be executed; this
        # function does not check for ssh-keygen itself before running it.
        raise SshKeyError(f"Could not run ssh-keygen: {exc}") from exc
    if result.returncode != 0:
        raise SshKeyError(result.stderr.strip() or "Could not fingerprint SSH key.")
    return result.stdout.strip()
def delete_generated_key_files(credential: SshCredential) -> None:
    """Delete the private key of *credential* and its ``.pub`` companion.

    Files are only removed when they live below
    ``<POBSYNC_HOME>/state/ssh-credentials``. Missing files are ignored.

    Raises:
        SshKeyError: If the key path, or the file it ultimately points to,
            lies outside the allowed directory tree.
    """
    path = identity_path(credential)
    allowed_root = (Path(settings.POBSYNC_HOME) / "state" / "ssh-credentials").resolve()

    # Where the directory entry itself lives: resolve the parent but do not
    # follow a symlink in the final component. unlink() removes this entry.
    entry = path.parent.resolve() / path.name
    # Where the path ultimately points once every symlink is followed.
    target = path.resolve()

    # Both locations must be inside the root. Checking only the resolved
    # target would let a symlink located outside the root (but pointing into
    # it) pass, and the link plus its ".pub" sibling would then be deleted
    # outside the root.
    for candidate in (entry, target):
        if allowed_root not in candidate.parents:
            raise SshKeyError(f"Refusing to delete key outside {allowed_root}.")

    entry.unlink(missing_ok=True)
    # "<name>.pub" next to the private key; for the default "identity" file
    # this is "identity.pub".
    entry.with_name(entry.name + ".pub").unlink(missing_ok=True)
def scan_known_host(address: str, *, port: int = 22, timeout: int = 5) -> str:
    """Scan *address* with ``ssh-keyscan`` and return its host-key lines.

    Args:
        address: Host name or IP address to scan.
        port: TCP port passed to ``ssh-keyscan -p``.
        timeout: Seconds passed to ``ssh-keyscan -T``; the subprocess itself
            is given two extra seconds before it is aborted.

    Returns:
        The non-empty, non-comment output lines joined with newlines (no
        trailing newline).

    Raises:
        SshKeyError: If ``ssh-keyscan`` is unavailable or cannot be run, the
            address is empty or looks like a command-line option, the scan
            times out or fails without output, or no host keys are returned.
    """
    if shutil.which("ssh-keyscan") is None:
        raise SshKeyError("ssh-keyscan is not available.")
    # The address ends up on a command line: refuse values that would be
    # parsed as an ssh-keyscan option rather than a host.
    if not address.strip() or address.startswith("-"):
        raise SshKeyError(f"Invalid SSH host address: {address!r}.")

    # "--" ends option parsing so the address is always treated as an operand.
    command = ["ssh-keyscan", "-T", str(timeout), "-p", str(port), "--", address]
    try:
        result = subprocess.run(
            command,
            check=False,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=timeout + 2,
        )
    except subprocess.TimeoutExpired as exc:
        raise SshKeyError(f"Timed out scanning SSH host key for {address}.") from exc
    except OSError as exc:
        raise SshKeyError(f"Could not run ssh-keyscan: {exc}") from exc

    # A non-zero exit is tolerated as long as some output was produced.
    if result.returncode != 0 and not result.stdout.strip():
        raise SshKeyError(result.stderr.strip() or f"Could not scan SSH host key for {address}.")

    # Test the stripped text so indented comment lines are dropped as well.
    stripped_lines = (raw.strip() for raw in result.stdout.splitlines())
    lines = [entry for entry in stripped_lines if entry and not entry.startswith("#")]
    if not lines:
        raise SshKeyError(f"ssh-keyscan returned no host keys for {address}.")
    return "\n".join(lines)
def merge_known_hosts(existing: str, scanned: str) -> str:
    """Combine two known_hosts texts into one without duplicate lines.

    Lines are stripped, blank lines are dropped, and the first occurrence of
    each line is kept in order (*existing* first, then *scanned*). The result
    ends with a newline unless it is empty.
    """
    combined = existing.splitlines() + scanned.splitlines()
    # dict keys keep insertion order, giving an ordered de-duplication.
    unique = dict.fromkeys(entry for entry in map(str.strip, combined) if entry)
    if not unique:
        return ""
    return "\n".join(unique) + "\n"