Reformat and lint
This commit is contained in:
@ -10,15 +10,26 @@ from ipaddress import IPv4Network, IPv6Network
|
||||
from sshecret_sshd.ssh_server import run_ssh_server
|
||||
from sshecret_sshd.settings import ClientRegistrationSettings
|
||||
|
||||
from .types import Client, ClientKey, ClientRegistry, SshServerFixtureFun, SshServerFixture
|
||||
from .types import (
|
||||
Client,
|
||||
ClientKey,
|
||||
ClientRegistry,
|
||||
SshServerFixtureFun,
|
||||
SshServerFixture,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client_registry() -> ClientRegistry:
|
||||
clients = {}
|
||||
secrets = {}
|
||||
|
||||
async def add_client(name: str, secret_names: list[str]|None=None, policies: list[str]|None=None) -> str:
|
||||
private_key = asyncssh.generate_private_key('ssh-rsa')
|
||||
async def add_client(
|
||||
name: str,
|
||||
secret_names: list[str] | None = None,
|
||||
policies: list[str] | None = None,
|
||||
) -> str:
|
||||
private_key = asyncssh.generate_private_key("ssh-rsa")
|
||||
public_key = private_key.export_public_key()
|
||||
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
|
||||
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
|
||||
@ -26,11 +37,12 @@ def client_registry() -> ClientRegistry:
|
||||
|
||||
return {"clients": clients, "secrets": secrets, "add_client": add_client}
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
backend = MagicMock()
|
||||
clients_data = client_registry['clients']
|
||||
secrets_data = client_registry['secrets']
|
||||
clients_data = client_registry["clients"]
|
||||
secrets_data = client_registry["secrets"]
|
||||
|
||||
async def get_client(name: str) -> Client | None:
|
||||
client_key = clients_data.get(name)
|
||||
@ -41,7 +53,7 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
description=f"Mock client {name}",
|
||||
public_key=client_key.public_key,
|
||||
secrets=[s for (c, s) in secrets_data if c == name],
|
||||
policies=[IPv4Network('0.0.0.0/0'), IPv6Network('::/0')],
|
||||
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
|
||||
)
|
||||
return response_model
|
||||
return None
|
||||
@ -57,7 +69,9 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
Otherwise we can't test this...
|
||||
"""
|
||||
if "template" not in clients_data:
|
||||
raise RuntimeError("Error, must have a client called template for this to work.")
|
||||
raise RuntimeError(
|
||||
"Error, must have a client called template for this to work."
|
||||
)
|
||||
clients_data[name] = clients_data["template"]
|
||||
for secret_key, secret in secrets_data.items():
|
||||
s_client, secret_name = secret_key
|
||||
@ -76,18 +90,22 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServerFixtureFun:
|
||||
async def ssh_server(
|
||||
mock_backend: MagicMock, unused_tcp_port: int
|
||||
) -> SshServerFixtureFun:
|
||||
port = unused_tcp_port
|
||||
|
||||
private_key = asyncssh.generate_private_key('ssh-ed25519')
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||
key_str = private_key.export_private_key()
|
||||
with tempfile.NamedTemporaryFile('w+', delete=True) as key_file:
|
||||
|
||||
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
|
||||
key_file.write(key_str.decode())
|
||||
key_file.flush()
|
||||
|
||||
registration_settings = ClientRegistrationSettings(enabled=True, allow_from=[IPv4Network("0.0.0.0/0")])
|
||||
registration_settings = ClientRegistrationSettings(
|
||||
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
|
||||
)
|
||||
server = await run_ssh_server(
|
||||
backend=mock_backend,
|
||||
listen_address="localhost",
|
||||
@ -104,6 +122,7 @@ async def ssh_server(mock_backend: MagicMock, unused_tcp_port: int) -> SshServer
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegistry):
|
||||
"""Run a single command.
|
||||
@ -113,7 +132,7 @@ def ssh_command_runner(ssh_server: SshServerFixture, client_registry: ClientRegi
|
||||
_, port = ssh_server
|
||||
|
||||
async def run_command_as(name: str, command: str):
|
||||
client_key = client_registry['clients'][name]
|
||||
client_key = client_registry["clients"][name]
|
||||
conn = await asyncssh.connect(
|
||||
"127.0.0.1",
|
||||
port=port,
|
||||
|
||||
Reference in New Issue
Block a user