diff --git a/src/pobsync_backend/host_ops.py b/src/pobsync_backend/host_ops.py index 557dd2c..32fe0fb 100644 --- a/src/pobsync_backend/host_ops.py +++ b/src/pobsync_backend/host_ops.py @@ -62,6 +62,17 @@ def collect_host_checks(host: HostConfig, global_config: GlobalConfig | None = N "Generated filesystem keys are recommended for native systemd installs.", ) ) + if credential.known_hosts.strip(): + checks.append(SelfCheck("Host known_hosts", "ok", "Selected credential has known_hosts entries.")) + else: + checks.append( + SelfCheck( + "Host known_hosts", + "warning", + "Selected credential has no known_hosts entries.", + "Use Scan SSH host key before queueing unattended backups.", + ) + ) host_root = resolve_host_root(global_config.backup_root, host.host) checks.append(_host_path_check("Host backup root", host_root, must_exist=True, must_be_writable=True)) diff --git a/src/pobsync_backend/ssh_keys.py b/src/pobsync_backend/ssh_keys.py index 8aa50d2..01e9141 100644 --- a/src/pobsync_backend/ssh_keys.py +++ b/src/pobsync_backend/ssh_keys.py @@ -108,3 +108,38 @@ def delete_generated_key_files(credential: SshCredential) -> None: path.with_suffix(path.suffix + ".pub").unlink(missing_ok=True) if path.name == "identity": (path.parent / "identity.pub").unlink(missing_ok=True) + + +def scan_known_host(address: str, *, port: int = 22, timeout: int = 5) -> str: + if shutil.which("ssh-keyscan") is None: + raise SshKeyError("ssh-keyscan is not available.") + + command = ["ssh-keyscan", "-T", str(timeout), "-p", str(port), address] + result = subprocess.run( + command, + check=False, + stdin=subprocess.DEVNULL, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=timeout + 2, + ) + if result.returncode != 0 and not result.stdout.strip(): + raise SshKeyError(result.stderr.strip() or f"Could not scan SSH host key for {address}.") + + lines = [line.strip() for line in result.stdout.splitlines() if line.strip() and not line.startswith("#")] + if not lines: + raise SshKeyError(f"ssh-keyscan returned no host keys for {address}.") + return "\n".join(lines) + + +def merge_known_hosts(existing: str, scanned: str) -> str: + lines: list[str] = [] + seen: set[str] = set() + for line in [*existing.splitlines(), *scanned.splitlines()]: + normalized = line.strip() + if not normalized or normalized in seen: + continue + seen.add(normalized) + lines.append(normalized) + return "\n".join(lines) + ("\n" if lines else "") diff --git a/src/pobsync_backend/templates/pobsync_backend/host_detail.html b/src/pobsync_backend/templates/pobsync_backend/host_detail.html index 43a2d5a..8a25f97 100644 --- a/src/pobsync_backend/templates/pobsync_backend/host_detail.html +++ b/src/pobsync_backend/templates/pobsync_backend/host_detail.html @@ -17,6 +17,10 @@ {% csrf_token %} +
+ {% csrf_token %} + +
diff --git a/src/pobsync_backend/tests/test_ssh_credentials.py b/src/pobsync_backend/tests/test_ssh_credentials.py index a1fd728..cac8f9e 100644 --- a/src/pobsync_backend/tests/test_ssh_credentials.py +++ b/src/pobsync_backend/tests/test_ssh_credentials.py @@ -10,6 +10,7 @@ from django.test import SimpleTestCase, TestCase, override_settings from pobsync_backend.forms import normalize_private_key, validate_ssh_private_key from pobsync_backend.models import GlobalConfig, SshCredential +from pobsync_backend.ssh_keys import merge_known_hosts class SshCredentialValidationTests(SimpleTestCase): @@ -60,3 +61,14 @@ class SshCredentialManagementTests(TestCase): global_config.refresh_from_db() self.assertEqual(global_config.default_ssh_credential.name, "default") + + def test_merge_known_hosts_appends_unique_entries(self) -> None: + merged = merge_known_hosts( + "web-01.example.test ssh-ed25519 AAAAOLD\n", + "web-01.example.test ssh-ed25519 AAAAOLD\nweb-01.example.test ssh-rsa AAAANEW\n", + ) + + self.assertEqual( + merged, + "web-01.example.test ssh-ed25519 AAAAOLD\nweb-01.example.test ssh-rsa AAAANEW\n", + ) diff --git a/src/pobsync_backend/tests/test_views.py b/src/pobsync_backend/tests/test_views.py index ad7a41c..1efcbb0 100644 --- a/src/pobsync_backend/tests/test_views.py +++ b/src/pobsync_backend/tests/test_views.py @@ -522,6 +522,24 @@ class ViewTests(TestCase): self.assertTrue((backup_root / host.host / "manual").is_dir()) self.assertTrue((backup_root / host.host / ".incomplete").is_dir()) + def test_scan_host_known_key_action_updates_selected_credential(self) -> None: + self.client.force_login(self.staff_user) + credential = SshCredential.objects.create(name="default-key", key_path="/var/lib/pobsync/state/ssh-credentials/1/identity") + GlobalConfig.objects.create(name="default", backup_root="/backups", default_ssh_credential=credential, ssh_port=2222) + host = HostConfig.objects.create(host="web-01", address="web-01.example.test") + + with patch( + "pobsync_backend.views.scan_known_host", + return_value="web-01.example.test ssh-ed25519 AAAASCANNED", + ) as scan: + response = self.client.post(reverse("scan_host_known_key", args=[host.host]), follow=True) + + self.assertRedirects(response, reverse("host_detail", args=[host.host])) + self.assertContains(response, "Stored SSH host key for web-01") + scan.assert_called_once_with("web-01.example.test", port=2222) + credential.refresh_from_db() + self.assertEqual(credential.known_hosts, "web-01.example.test ssh-ed25519 AAAASCANNED\n") + def test_host_detail_surfaces_active_backup_run(self) -> None: self.client.force_login(self.staff_user) GlobalConfig.objects.create(name="default", backup_root="/backups") diff --git a/src/pobsync_backend/views.py b/src/pobsync_backend/views.py index 505e302..8fdd794 100644 --- a/src/pobsync_backend/views.py +++ b/src/pobsync_backend/views.py @@ -30,7 +30,7 @@ from .models import BackupRun, GlobalConfig, HostConfig, ScheduleConfig, Snapsho from .retention import run_sql_retention_apply, run_sql_retention_plan from .self_check import collect_self_checks, summarize_self_checks from .snapshot_discovery import discover_snapshots, inspect_snapshot_discovery -from .ssh_keys import SshKeyError, delete_generated_key_files, generate_ssh_key +from .ssh_keys import SshKeyError, delete_generated_key_files, generate_ssh_key, merge_known_hosts, scan_known_host @staff_member_required @@ -289,6 +289,28 @@ def prepare_host_directories(request, host: str): return redirect("host_detail", host=host_config.host) +@staff_member_required +@require_POST +def scan_host_known_key(request, host: str): + host_config = get_object_or_404(HostConfig, host=host) + global_config = GlobalConfig.objects.filter(name="default").first() + credential = host_config.ssh_credential or (global_config.default_ssh_credential if global_config else None) + if credential is None: + messages.error(request, f"No SSH credential is selected for {host_config.host}.") + return redirect("host_detail", host=host_config.host) + + port = host_config.ssh_port or (global_config.ssh_port if global_config else 22) + try: + scanned = scan_known_host(host_config.address, port=int(port or 22)) + except SshKeyError as exc: + messages.error(request, f"Could not scan SSH host key for {host_config.host}: {exc}") + else: + credential.known_hosts = merge_known_hosts(credential.known_hosts, scanned) + credential.save(update_fields=["known_hosts", "updated_at"]) + messages.success(request, f"Stored SSH host key for {host_config.host} on credential {credential.name}.") + return redirect("host_detail", host=host_config.host) + + @staff_member_required @require_POST def queue_manual_backup(request, host: str): diff --git a/src/pobsync_server/urls.py b/src/pobsync_server/urls.py index 4cfea68..9b91de3 100644 --- a/src/pobsync_server/urls.py +++ b/src/pobsync_server/urls.py @@ -20,6 +20,7 @@ urlpatterns = [ path("hosts//", views.host_detail, name="host_detail"), path("hosts//config/", views.edit_host_config, name="edit_host_config"), path("hosts//prepare-directories/", views.prepare_host_directories, name="prepare_host_directories"), + path("hosts//scan-known-key/", views.scan_host_known_key, name="scan_host_known_key"), path("hosts//queue-backup/", views.queue_manual_backup, name="queue_manual_backup"), path("hosts//discover-snapshots/", views.discover_host_snapshots, name="discover_host_snapshots"), path("hosts//retention-apply/", views.apply_host_retention, name="apply_host_retention"),