Files
pobsync/src/pobsync_backend/tests/test_ssh_credentials.py

75 lines
3.2 KiB
Python
Raw Normal View History

from __future__ import annotations
import subprocess
from pathlib import Path
from tempfile import TemporaryDirectory
from django import forms
from django.core.management import call_command
from django.test import SimpleTestCase, TestCase, override_settings
from pobsync_backend.forms import normalize_private_key, validate_ssh_private_key
from pobsync_backend.models import GlobalConfig, SshCredential
from pobsync_backend.ssh_keys import merge_known_hosts
class SshCredentialValidationTests(SimpleTestCase):
    """Tests for ``normalize_private_key`` and ``validate_ssh_private_key``."""

    def test_normalize_private_key_repairs_wrapped_openssh_body(self) -> None:
        """A re-wrapped OpenSSH key body is repaired by ``normalize_private_key``.

        A real ed25519 key is generated with ``ssh-keygen``. Its base64 body is
        split on whitespace and re-joined with space-padded newlines, and the
        trailing newline after the END marker is dropped. The normalized result
        must validate to the same value as the pristine key.
        """
        import shutil

        # The fixture needs the OpenSSH client tools; skip rather than error with
        # FileNotFoundError on hosts where ssh-keygen is not installed.
        if shutil.which("ssh-keygen") is None:
            self.skipTest("ssh-keygen is not available on PATH")

        with TemporaryDirectory() as tmp:
            key_path = Path(tmp) / "identity"
            # -N "" -> no passphrase; output is captured to keep test logs clean.
            subprocess.run(
                ["ssh-keygen", "-t", "ed25519", "-N", "", "-C", "test", "-f", str(key_path)],
                check=True,
                capture_output=True,
                text=True,
            )
            private_key = key_path.read_text(encoding="utf-8")

            begin_marker = "-----BEGIN OPENSSH PRIVATE KEY-----"
            end_marker = "-----END OPENSSH PRIVATE KEY-----"
            # Extract only the base64 body between the armor lines.
            body = private_key.split(begin_marker, 1)[1].split(end_marker, 1)[0]
            # Simulate a key mangled by copy/paste: every line break padded with spaces.
            damaged_body = " \n ".join(body.split())
            damaged_key = f"{begin_marker}\n{damaged_body}\n{end_marker}"

            normalized_key = normalize_private_key(damaged_key)
            self.assertEqual(
                validate_ssh_private_key(normalized_key),
                validate_ssh_private_key(private_key),
            )

    def test_validate_private_key_rejects_pem_key_with_actionable_message(self) -> None:
        """A PEM (``BEGIN RSA PRIVATE KEY``) key is rejected with an explanatory message."""
        with self.assertRaises(forms.ValidationError) as exc:
            validate_ssh_private_key("-----BEGIN RSA PRIVATE KEY-----\nabc\n-----END RSA PRIVATE KEY-----")
        self.assertIn("PEM private keys are not supported", str(exc.exception))
class SshCredentialManagementTests(TestCase):
    """Tests for the ``ensure_pobsync_ssh_key`` command and ``merge_known_hosts``."""

    def test_ensure_ssh_key_command_generates_default_key(self) -> None:
        """The command creates a generated ed25519 credential whose key file exists."""
        # POBSYNC_HOME points at a throwaway directory; presumably the command
        # writes its key material under it (assumption — confirm in the command).
        with TemporaryDirectory() as tmp, override_settings(POBSYNC_HOME=str(Path(tmp) / "home")):
            call_command("ensure_pobsync_ssh_key", "--name", "default")
            credential = SshCredential.objects.get(name="default")
            self.assertTrue(credential.generated)
            # Checked inside the ``with`` block, while the temporary directory still exists.
            self.assertTrue(Path(credential.key_path).exists())
            self.assertTrue(credential.public_key.startswith("ssh-ed25519 "))

    def test_ensure_ssh_key_command_sets_global_default_when_available(self) -> None:
        """``--set-global-default`` links the created credential to the GlobalConfig row."""
        global_config = GlobalConfig.objects.create(name="default", backup_root="/backups")
        with TemporaryDirectory() as tmp, override_settings(POBSYNC_HOME=str(Path(tmp) / "home")):
            call_command("ensure_pobsync_ssh_key", "--name", "default", "--set-global-default")
        # Reload so the assertion sees what the command persisted.
        global_config.refresh_from_db()
        default_credential = global_config.default_ssh_credential
        # Guard first so an unset default reports as an assertion failure rather
        # than an AttributeError on ``None.name``.
        self.assertIsNotNone(default_credential)
        self.assertEqual(default_credential.name, "default")

    def test_merge_known_hosts_appends_unique_entries(self) -> None:
        """Lines already present are kept once; new lines are appended after them."""
        merged = merge_known_hosts(
            "web-01.example.test ssh-ed25519 AAAAOLD\n",
            "web-01.example.test ssh-ed25519 AAAAOLD\nweb-01.example.test ssh-rsa AAAANEW\n",
        )
        self.assertEqual(
            merged,
            "web-01.example.test ssh-ed25519 AAAAOLD\nweb-01.example.test ssh-rsa AAAANEW\n",
        )