Compare commits
11 Commits
7c65d5bb93
...
197c8a7c05
| Author | SHA1 | Date | |
|---|---|---|---|
| 197c8a7c05 | |||
| 80e2c339e3 | |||
| 458863de3d | |||
| a2ec2173ac | |||
| 090ec4dc3f | |||
| a07fba9560 | |||
| d3d99775d9 | |||
| b34c49d3e3 | |||
| d0b92b220e | |||
| 3dfd03688b | |||
| 388200fd52 |
36
.coveragerc
Normal file
36
.coveragerc
Normal file
@ -0,0 +1,36 @@
|
||||
[run]
|
||||
branch = True
|
||||
source =
|
||||
src/sshecret
|
||||
packages/sshecret-admin/src/sshecret_admin
|
||||
packages/sshecret-backend/src/sshecret_backend
|
||||
packages/sshecret-sshd/src/sshecret_sshd
|
||||
|
||||
omit =
|
||||
*/__init__.py
|
||||
*/types.py
|
||||
*/testing.py
|
||||
*/settings.py
|
||||
*/main.py
|
||||
*/cli.py
|
||||
*/tests/*
|
||||
*/test_*.py
|
||||
*/conftest.py
|
||||
*/site-packages/*
|
||||
concurrency = multiprocessing
|
||||
|
||||
[report]
|
||||
show_missing = True
|
||||
skip_covered = True
|
||||
|
||||
exclude_lines =
|
||||
if __name__ == .__main__.:
|
||||
def __repr__
|
||||
def __str__
|
||||
def __eq__
|
||||
def __ne__
|
||||
raise NotImplementedError
|
||||
except ImportError
|
||||
|
||||
[html]
|
||||
directory = coverage_html_report
|
||||
@ -21,7 +21,7 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
"""Create clients router."""
|
||||
app = APIRouter()
|
||||
app = APIRouter(dependencies=[Depends(dependencies.get_current_active_user)])
|
||||
|
||||
@app.get("/clients/")
|
||||
async def get_clients(
|
||||
|
||||
@ -19,7 +19,7 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
"""Create secrets router."""
|
||||
app = APIRouter()
|
||||
app = APIRouter(dependencies=[Depends(dependencies.get_current_active_user)])
|
||||
|
||||
@app.get("/secrets/")
|
||||
async def get_secret_names(
|
||||
|
||||
@ -55,7 +55,9 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
raise HTTPException(status_code=400, detail="Inactive or disabled user")
|
||||
return current_user
|
||||
|
||||
async def get_admin_backend(session: Annotated[Session, Depends(dependencies.get_db_session)]):
|
||||
async def get_admin_backend(
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)]
|
||||
):
|
||||
"""Get admin backend API."""
|
||||
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||
if not password_db:
|
||||
@ -65,11 +67,13 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
||||
yield admin
|
||||
|
||||
app = APIRouter(
|
||||
prefix=f"/api/{API_VERSION}", dependencies=[Depends(get_current_active_user)]
|
||||
app = APIRouter(prefix=f"/api/{API_VERSION}")
|
||||
|
||||
endpoint_deps = AdminDependencies.create(
|
||||
dependencies, get_admin_backend, get_current_active_user
|
||||
)
|
||||
|
||||
endpoint_deps = AdminDependencies.create(dependencies, get_admin_backend)
|
||||
LOG.debug("Registering sub-routers")
|
||||
|
||||
app.include_router(auth.create_router(endpoint_deps))
|
||||
app.include_router(clients.create_router(endpoint_deps))
|
||||
|
||||
@ -93,3 +93,9 @@ def decode_token(settings: AdminServerSettings, token: str) -> TokenData | None:
|
||||
except jwt.InvalidTokenError as e:
|
||||
LOG.debug("Could not decode token: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash password."""
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_password = bcrypt.hashpw(password.encode(), salt)
|
||||
return hashed_password.decode()
|
||||
|
||||
@ -46,10 +46,6 @@ class PasswordDB(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
def init_db(engine: sa.Engine) -> None:
|
||||
"""Create database."""
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
|
||||
class TokenData(SQLModel):
|
||||
"""Token data."""
|
||||
@ -69,3 +65,6 @@ class LoginError(SQLModel):
|
||||
title: str
|
||||
message: str
|
||||
|
||||
def init_db(engine: sa.Engine) -> None:
|
||||
"""Create database."""
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
@ -5,13 +5,13 @@ import code
|
||||
from collections.abc import Awaitable
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
import bcrypt
|
||||
import click
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
import uvicorn
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import Session, create_engine, select
|
||||
from sshecret_admin.auth.models import init_db, User, PasswordDB
|
||||
from sshecret_admin.auth.authentication import hash_password
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
@ -19,17 +19,6 @@ formatter = logging.Formatter(
|
||||
"%(asctime)s [%(processName)s: %(process)d] [%(threadName)s: %(thread)d] [%(levelname)s] %(name)s: %(message)s"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
LOG = logging.getLogger()
|
||||
LOG.addHandler(handler)
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash password."""
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_password = bcrypt.hashpw(password.encode(), salt)
|
||||
return hashed_password.decode()
|
||||
|
||||
|
||||
def create_user(session: Session, username: str, password: str) -> None:
|
||||
"""Create a user."""
|
||||
@ -44,8 +33,14 @@ def create_user(session: Session, username: str, password: str) -> None:
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, debug: bool) -> None:
|
||||
"""Sshecret Admin."""
|
||||
LOG = logging.getLogger()
|
||||
LOG.addHandler(handler)
|
||||
|
||||
if debug:
|
||||
click.echo("Setting logging to debug level")
|
||||
LOG.setLevel(logging.DEBUG)
|
||||
else:
|
||||
LOG.setLevel(logging.INFO)
|
||||
try:
|
||||
settings = AdminServerSettings() # pyright: ignore[reportCallIssue]
|
||||
except ValidationError as e:
|
||||
|
||||
@ -12,7 +12,7 @@ def setup_database(
|
||||
) -> tuple[sa.Engine, Callable[[], Generator[Session, None, None]]]:
|
||||
"""Setup database."""
|
||||
|
||||
engine = create_engine(db_url, echo=True)
|
||||
engine = create_engine(db_url, echo=False)
|
||||
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
"""Get DB Session."""
|
||||
|
||||
@ -2,9 +2,10 @@
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
from typing import Awaitable, Self
|
||||
|
||||
from sqlmodel import Session
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.services import AdminBackend
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
@ -13,6 +14,8 @@ DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||
|
||||
AdminDep = Callable[[Session], AsyncGenerator[AdminBackend, None]]
|
||||
|
||||
GetUserDep = Callable[[User], Awaitable[User]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseDependencies:
|
||||
@ -21,17 +24,25 @@ class BaseDependencies:
|
||||
settings: AdminServerSettings
|
||||
get_db_session: DBSessionDep
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdminDependencies(BaseDependencies):
|
||||
"""Dependency class with admin."""
|
||||
|
||||
get_admin_backend: AdminDep
|
||||
get_current_active_user: GetUserDep
|
||||
|
||||
@classmethod
|
||||
def create(cls, deps: BaseDependencies, get_admin_backend: AdminDep) -> Self:
|
||||
def create(
|
||||
cls,
|
||||
deps: BaseDependencies,
|
||||
get_admin_backend: AdminDep,
|
||||
get_current_active_user: GetUserDep,
|
||||
) -> Self:
|
||||
"""Create from base dependencies."""
|
||||
return cls(
|
||||
settings=deps.settings,
|
||||
get_db_session=deps.get_db_session,
|
||||
get_admin_backend=get_admin_backend,
|
||||
get_current_active_user=get_current_active_user,
|
||||
)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""SSH Server settings."""
|
||||
|
||||
from pathlib import Path
|
||||
from pydantic import AnyHttpUrl, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from sqlalchemy import URL
|
||||
@ -22,10 +23,9 @@ class AdminServerSettings(BaseSettings):
|
||||
listen_address: str = Field(default="")
|
||||
secret_key: str
|
||||
port: int = DEFAULT_LISTEN_PORT
|
||||
|
||||
database: str = Field(default=DEFAULT_DATABASE)
|
||||
#admin_db: str = Field(default=DEFAULT_DATABASE)
|
||||
debug: bool = False
|
||||
password_manager_directory: Path | None = None
|
||||
|
||||
@property
|
||||
def admin_db(self) -> URL:
|
||||
|
||||
@ -43,7 +43,10 @@ class PasswordContext:
|
||||
)
|
||||
if entry and overwrite:
|
||||
entry.password = secret
|
||||
elif entry:
|
||||
self.keepass.save()
|
||||
return
|
||||
|
||||
if entry:
|
||||
raise ValueError("Error: A secret with this name already exists.")
|
||||
LOG.debug("Add secret entry to keepass: %s", entry_name)
|
||||
entry = self.keepass.add_entry(
|
||||
|
||||
@ -24,11 +24,14 @@ def setup_master_password(
|
||||
|
||||
This method should run just after setting up the database.
|
||||
"""
|
||||
created = _initial_key_setup(settings, filename, regenerate)
|
||||
keyfile = Path(filename)
|
||||
if settings.password_manager_directory:
|
||||
keyfile = settings.password_manager_directory / filename
|
||||
created = _initial_key_setup(settings, keyfile, regenerate)
|
||||
if not created:
|
||||
return None
|
||||
|
||||
return _generate_master_password(settings, filename)
|
||||
return _generate_master_password(settings, keyfile)
|
||||
|
||||
|
||||
def decrypt_master_password(
|
||||
@ -36,10 +39,12 @@ def decrypt_master_password(
|
||||
) -> str:
|
||||
"""Retrieve master password."""
|
||||
keyfile = Path(filename)
|
||||
if settings.password_manager_directory:
|
||||
keyfile = settings.password_manager_directory / filename
|
||||
if not keyfile.exists():
|
||||
raise RuntimeError("Error: Private key has not been generated yet.")
|
||||
|
||||
private_key = load_private_key(KEY_FILENAME, password=settings.secret_key)
|
||||
private_key = load_private_key(str(keyfile.absolute()), password=settings.secret_key)
|
||||
return decode_string(encrypted, private_key)
|
||||
|
||||
|
||||
@ -50,12 +55,10 @@ def _generate_password() -> str:
|
||||
|
||||
def _initial_key_setup(
|
||||
settings: AdminServerSettings,
|
||||
filename: str = KEY_FILENAME,
|
||||
keyfile: Path,
|
||||
regenerate: bool = False,
|
||||
) -> bool:
|
||||
"""Set up initial keys."""
|
||||
keyfile = Path(filename)
|
||||
|
||||
if keyfile.exists() and not regenerate:
|
||||
return False
|
||||
|
||||
@ -67,16 +70,15 @@ def _initial_key_setup(
|
||||
|
||||
|
||||
def _generate_master_password(
|
||||
settings: AdminServerSettings, filename: str = KEY_FILENAME
|
||||
settings: AdminServerSettings, keyfile: Path
|
||||
) -> str:
|
||||
"""Generate master password for password database.
|
||||
|
||||
Returns the encrypted string, base64 encoded.
|
||||
"""
|
||||
keyfile = Path(filename)
|
||||
if not keyfile.exists():
|
||||
raise RuntimeError("Error: Private key has not been generated yet.")
|
||||
private_key = load_private_key(filename, password=settings.secret_key)
|
||||
private_key = load_private_key(str(keyfile.absolute()), password=settings.secret_key)
|
||||
public_key = private_key.public_key()
|
||||
master_password = _generate_password()
|
||||
return encrypt_string(master_password, public_key)
|
||||
|
||||
@ -85,7 +85,7 @@ class SecretUpdate(BaseModel):
|
||||
"""
|
||||
if isinstance(self.value, str):
|
||||
return self.value
|
||||
secret = secrets.token_urlsafe(self.value.length)
|
||||
secret = secrets.token_urlsafe(32)[:self.value.length]
|
||||
return secret
|
||||
|
||||
|
||||
|
||||
@ -1,40 +1,17 @@
|
||||
"""Testing helper functions."""
|
||||
"""Testing helper functions.
|
||||
|
||||
This allows creation of a user from within tests.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import bcrypt
|
||||
|
||||
from sqlmodel import Session
|
||||
from .auth_models import User
|
||||
|
||||
|
||||
def get_test_user_details() -> tuple[str, str]:
|
||||
"""Resolve testing user."""
|
||||
test_user = os.getenv("SSHECRET_TEST_USERNAME") or "test"
|
||||
test_password = os.getenv("SSHECRET_TEST_PASSWORD") or "test"
|
||||
if test_user and test_password:
|
||||
return (test_user, test_password)
|
||||
|
||||
raise RuntimeError(
|
||||
"Error: No testing username and password registered in environment."
|
||||
)
|
||||
|
||||
|
||||
def is_testing_mode() -> bool:
|
||||
"""Check if we're running in test mode.
|
||||
|
||||
We will determine this by looking for the environment variable SSHECRET_TEST_MODE=1
|
||||
"""
|
||||
if os.environ.get("PYTEST_VERSION") is not None:
|
||||
return True
|
||||
return False
|
||||
from sshecret_admin.auth.models import User
|
||||
|
||||
|
||||
def create_test_user(session: Session, username: str, password: str) -> User:
|
||||
"""Create test user.
|
||||
|
||||
We create a user with whatever username and password is supplied.
|
||||
"""
|
||||
"""Create test user."""
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_password = bcrypt.hashpw(password.encode(), salt)
|
||||
user = User(username=username, hashed_password=hashed_password.decode())
|
||||
|
||||
@ -0,0 +1,88 @@
|
||||
"""Initial
|
||||
|
||||
Revision ID: 06af53cdf350
|
||||
Revises:
|
||||
Create Date: 2025-05-06 08:39:33.531696
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '06af53cdf350'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('api_client',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('subsystem', sa.String(), nullable=True),
|
||||
sa.Column('token', sa.String(), nullable=False),
|
||||
sa.Column('read_write', sa.Boolean(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('audit_log',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('subsystem', sa.String(), nullable=False),
|
||||
sa.Column('message', sa.String(), nullable=False),
|
||||
sa.Column('operation', sa.String(), nullable=False),
|
||||
sa.Column('client_id', sa.Uuid(), nullable=True),
|
||||
sa.Column('data', sa.JSON(), nullable=True),
|
||||
sa.Column('client_name', sa.String(), nullable=True),
|
||||
sa.Column('secret_id', sa.Uuid(), nullable=True),
|
||||
sa.Column('secret_name', sa.String(), nullable=True),
|
||||
sa.Column('origin', sa.String(), nullable=True),
|
||||
sa.Column('timestamp', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('client',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('description', sa.String(), nullable=True),
|
||||
sa.Column('public_key', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('name')
|
||||
)
|
||||
op.create_table('client_access_policy',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('source', sa.String(), nullable=False),
|
||||
sa.Column('client_id', sa.Uuid(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['client.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('client_secret',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('description', sa.String(), nullable=True),
|
||||
sa.Column('secret', sa.String(), nullable=False),
|
||||
sa.Column('client_id', sa.Uuid(), nullable=True),
|
||||
sa.Column('invalidated', sa.Boolean(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.ForeignKeyConstraint(['client_id'], ['client.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('client_secret')
|
||||
op.drop_table('client_access_policy')
|
||||
op.drop_table('client')
|
||||
op.drop_table('audit_log')
|
||||
op.drop_table('api_client')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,34 @@
|
||||
"""Update apiclient
|
||||
|
||||
Revision ID: 37329d9b5437
|
||||
Revises: 06af53cdf350
|
||||
Create Date: 2025-05-06 08:53:45.774225
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '37329d9b5437'
|
||||
down_revision: Union[str, None] = '06af53cdf350'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('api_client', sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True))
|
||||
op.drop_column('api_client', 'read_write')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('api_client', sa.Column('read_write', sa.BOOLEAN(), nullable=False))
|
||||
op.drop_column('api_client', 'updated_at')
|
||||
# ### end Alembic commands ###
|
||||
@ -1,2 +0,0 @@
|
||||
[pytest]
|
||||
asyncio_mode = auto
|
||||
@ -1,2 +0,0 @@
|
||||
def hello() -> str:
|
||||
return "Hello from sshecret-sshd!"
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import cast
|
||||
import click
|
||||
from pydantic import ValidationError
|
||||
from .settings import ServerSettings
|
||||
from .ssh_server import start_server
|
||||
from .ssh_server import start_sshecret_sshd
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
@ -51,7 +51,7 @@ def cli_run(ctx: click.Context, host: str | None, port: int | None) -> None:
|
||||
settings.port = port
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(start_server(settings))
|
||||
loop.run_until_complete(start_sshecret_sshd(settings))
|
||||
title = click.style("Sshecret SSH Daemon", fg="red", bold=True)
|
||||
click.echo(f"Starting {title}: {settings.listen_address}:{settings.port}")
|
||||
try:
|
||||
|
||||
@ -6,7 +6,7 @@ ERROR_SOURCE_IP_NOT_ALLOWED = (
|
||||
)
|
||||
ERROR_NO_PUBLIC_KEY = "Error: No valid public key received."
|
||||
ERROR_INVALID_KEY_TYPE = "Error: Invalid key type: Only RSA keys are supported."
|
||||
ERROR_UNKNOWN_COMMAND = "Error: The given command was not understood."
|
||||
ERROR_UNKNOWN_COMMAND = "Error: Unsupported command."
|
||||
SERVER_KEY_TYPE = "ed25519"
|
||||
ERROR_BACKEND_ERROR = "Error: Unexpected response or error from backend"
|
||||
ERROR_INFO_BACKEND_GONE = "Unexpected error: Backend connection lost."
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""SSH Server implementation."""
|
||||
|
||||
from asyncio import _register_task
|
||||
import logging
|
||||
|
||||
import asyncssh
|
||||
@ -32,7 +33,7 @@ class CommandError(Exception):
|
||||
"""Error class for errors during command processing."""
|
||||
|
||||
|
||||
def audit_process(
|
||||
async def audit_process(
|
||||
backend: SshecretBackend,
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
operation: Operation,
|
||||
@ -54,24 +55,25 @@ def audit_process(
|
||||
data["command"] = cmd
|
||||
data["args"] = " ".join(cmd_args)
|
||||
|
||||
backend.audit(SubSystem.SSHD).write(
|
||||
await backend.audit(SubSystem.SSHD).write_async(
|
||||
operation, message, remote_ip, client, secret=None, secret_name=secret, **data
|
||||
)
|
||||
|
||||
|
||||
def audit_event(
|
||||
async def audit_event(
|
||||
backend: SshecretBackend,
|
||||
message: str,
|
||||
operation: Operation,
|
||||
client: Client | None = None,
|
||||
origin: str | None = None,
|
||||
secret: str | None = None,
|
||||
**data: str,
|
||||
) -> None:
|
||||
"""Add an audit event."""
|
||||
if not origin:
|
||||
origin = "UNKNOWN"
|
||||
backend.audit(SubSystem.SSHD).write(
|
||||
operation, message, origin, client, secret=None, secret_name=secret
|
||||
await backend.audit(SubSystem.SSHD).write_async(
|
||||
operation, message, origin, client, secret=None, secret_name=secret, **data
|
||||
)
|
||||
|
||||
|
||||
@ -158,22 +160,14 @@ async def get_stdin_public_key(process: asyncssh.SSHServerProcess[str]) -> str |
|
||||
public_key = verify_key_input(line.rstrip("\n"))
|
||||
if public_key:
|
||||
break
|
||||
process.stdout.write("Invalid key. Must be RSA Public Key.\n")
|
||||
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
|
||||
except asyncssh.BreakReceived:
|
||||
pass
|
||||
process.stdout.write("OK\n")
|
||||
else:
|
||||
process.stdout.write("OK\n")
|
||||
return public_key
|
||||
|
||||
|
||||
def get_info_user_and_public_key(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Get username and public_key from process."""
|
||||
username = cast("str | None", process.get_extra_info("provided_username", None))
|
||||
public_key = cast("str | None", process.get_extra_info("provided_key", None))
|
||||
return (username, public_key)
|
||||
|
||||
|
||||
async def register_client(
|
||||
process: asyncssh.SSHServerProcess[str],
|
||||
backend: SshecretBackend,
|
||||
@ -187,7 +181,7 @@ async def register_client(
|
||||
key = asyncssh.import_public_key(public_key)
|
||||
if key.algorithm.decode() != "ssh-rsa":
|
||||
raise CommandError(constants.ERROR_INVALID_KEY_TYPE)
|
||||
audit_process(backend, process, Operation.CREATE, "Registering new client")
|
||||
await audit_process(backend, process, Operation.CREATE, "Registering new client")
|
||||
LOG.debug("Registering client %s with public key %s", username, public_key)
|
||||
await backend.create_client(username, public_key)
|
||||
|
||||
@ -205,7 +199,7 @@ async def get_secret(
|
||||
if secret_name not in client.secrets:
|
||||
raise CommandError(constants.ERROR_NO_SECRET_FOUND)
|
||||
|
||||
audit_event(
|
||||
await audit_event(
|
||||
backend,
|
||||
"Client requested secret",
|
||||
operation=Operation.READ,
|
||||
@ -247,7 +241,7 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
|
||||
allowed_networks = get_info_allowed_registration(process)
|
||||
if not allowed_networks:
|
||||
process.stdout.write("Unauthorized.\n")
|
||||
audit_process(
|
||||
await audit_process(
|
||||
backend,
|
||||
process,
|
||||
Operation.DENY,
|
||||
@ -266,7 +260,7 @@ async def dispatch_cmd_register(process: asyncssh.SSHServerProcess[str]) -> None
|
||||
if client_address in network:
|
||||
break
|
||||
else:
|
||||
audit_process(
|
||||
await audit_process(
|
||||
backend,
|
||||
process,
|
||||
Operation.DENY,
|
||||
@ -381,8 +375,14 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
|
||||
"""
|
||||
LOG.debug("Started authentication flow for user %s", username)
|
||||
if not self._conn:
|
||||
return True
|
||||
allowed_registration_sources: list[IPvAnyNetwork] = []
|
||||
if self.registration_enabled and not self.allow_registration_from:
|
||||
allowed_registration_sources.append(ipaddress.IPv4Network("0.0.0.0/0"))
|
||||
allowed_registration_sources.append(ipaddress.IPv6Network("::/0"))
|
||||
elif self.registration_enabled and self.allow_registration_from:
|
||||
allowed_registration_sources = self.allow_registration_from
|
||||
|
||||
assert self._conn is not None, "Error: No connection found."
|
||||
if client := await self.backend.get_client(username):
|
||||
LOG.debug("Client lookup sucessful: %r", client)
|
||||
if key := self.resolve_client_key(client):
|
||||
@ -390,40 +390,50 @@ class AsshyncServer(asyncssh.SSHServer):
|
||||
self._conn.set_extra_info(client=client)
|
||||
self._conn.set_authorized_keys(key)
|
||||
else:
|
||||
audit_event(
|
||||
await audit_event(
|
||||
self.backend,
|
||||
"Client denied due to policy",
|
||||
Operation.DENY,
|
||||
client,
|
||||
origin=self.client_ip,
|
||||
)
|
||||
LOG.warning("Client connection denied due to policy.")
|
||||
elif self.registration_enabled:
|
||||
self._conn.set_extra_info(provided_username=username)
|
||||
self._conn.set_extra_info(
|
||||
allow_registration_from=self.allow_registration_from
|
||||
)
|
||||
LOG.warning(
|
||||
"Registration enabled, and client is not recognized. Bypassing authentication."
|
||||
)
|
||||
return False
|
||||
LOG.warning(
|
||||
"Client connection denied. Source: %s, policy: %r.",
|
||||
self.client_ip,
|
||||
client.policies,
|
||||
)
|
||||
elif allowed_registration_sources and self.client_ip:
|
||||
client_ip = ipaddress.ip_address(self.client_ip)
|
||||
for network in allowed_registration_sources:
|
||||
if client_ip.version != network.version:
|
||||
continue
|
||||
if client_ip in network:
|
||||
self._conn.set_extra_info(provided_username=username)
|
||||
self._conn.set_extra_info(
|
||||
allow_registration_from=self.allow_registration_from
|
||||
)
|
||||
LOG.info(
|
||||
"Registration enabled, and client is not recognized. Bypassing authentication."
|
||||
)
|
||||
return False
|
||||
else:
|
||||
await audit_event(
|
||||
self.backend,
|
||||
"Received registration command from unauthorized subnet.",
|
||||
Operation.DENY,
|
||||
origin=self.client_ip,
|
||||
username=username,
|
||||
)
|
||||
|
||||
LOG.warning(
|
||||
"Registration not permitted for username=%s, origin: %s",
|
||||
username,
|
||||
self.client_ip,
|
||||
)
|
||||
|
||||
LOG.debug("Continuing to regular authentication")
|
||||
return True
|
||||
|
||||
@override
|
||||
def validate_public_key(self, username: str, key: asyncssh.SSHKey) -> bool:
|
||||
"""Intercept public key validation."""
|
||||
if not self._conn:
|
||||
return False
|
||||
|
||||
# get an export of the provided public key.
|
||||
keystring = key.export_public_key().decode()
|
||||
self._conn.set_extra_info(provided_username=username)
|
||||
self._conn.set_extra_info(provided_key=keystring)
|
||||
LOG.debug("Intercepting user public key")
|
||||
return False
|
||||
|
||||
def resolve_client_key(self, client: Client) -> asyncssh.SSHAuthorizedKeys | None:
|
||||
"""Resolve the client key.
|
||||
|
||||
@ -492,7 +502,9 @@ async def run_ssh_server(
|
||||
return server
|
||||
|
||||
|
||||
async def start_server(settings: ServerSettings | None = None) -> None:
|
||||
async def start_sshecret_sshd(
|
||||
settings: ServerSettings | None = None,
|
||||
) -> asyncssh.SSHAcceptor:
|
||||
"""Start the server."""
|
||||
server_key = get_server_key()
|
||||
|
||||
@ -500,7 +512,7 @@ async def start_server(settings: ServerSettings | None = None) -> None:
|
||||
settings = ServerSettings() # pyright: ignore[reportCallIssue]
|
||||
|
||||
backend = SshecretBackend(str(settings.backend_url), settings.backend_token)
|
||||
await run_ssh_server(
|
||||
return await run_ssh_server(
|
||||
backend=backend,
|
||||
listen_address=settings.listen_address,
|
||||
port=settings.port,
|
||||
|
||||
@ -6,12 +6,15 @@ test = "pytest ${PWD}"
|
||||
all = [ {ref="fmt"}, {ref="lint"}, {ref="check"}, {ref="test"} ]
|
||||
"ci:fmt" = "ruff format --check ${PWD}" # fail if not formatted
|
||||
"ci:lint" = "ruff check ${PWD}"
|
||||
[tool.poe.tasks.coverage]
|
||||
cmd = "pytest --cov-config=${PWD}/.coveragerc --cov --cov-report=html --cov-report=term-missing"
|
||||
cwd = "${POE_PWD}"
|
||||
|
||||
|
||||
[project]
|
||||
name = "sshecret"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
description = "A simple secret manager with an SSH server."
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Allan Eising", email = "allan@eising.dk" }
|
||||
@ -61,7 +64,9 @@ dev = [
|
||||
"python-dotenv>=1.0.1",
|
||||
]
|
||||
test = [
|
||||
"coverage>=7.8.0",
|
||||
"pytest>=8.3.5",
|
||||
"pytest-asyncio>=0.26.0",
|
||||
"pytest-cov>=6.1.1",
|
||||
"robotframework>=7.2.2",
|
||||
]
|
||||
|
||||
0
tests/integration/__init__.py
Normal file
0
tests/integration/__init__.py
Normal file
27
tests/integration/clients.py
Normal file
27
tests/integration/clients.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""Client helpers."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientData:
|
||||
"""Test client."""
|
||||
|
||||
name: str
|
||||
private_key: rsa.RSAPrivateKey
|
||||
|
||||
@property
|
||||
def public_key(self) -> str:
|
||||
"""Return public key as string."""
|
||||
return generate_public_key_string(self.private_key.public_key())
|
||||
|
||||
|
||||
def create_test_client(name: str) -> ClientData:
|
||||
"""Create test client."""
|
||||
return ClientData(
|
||||
name=name,
|
||||
private_key=generate_private_key()
|
||||
)
|
||||
209
tests/integration/conftest.py
Normal file
209
tests/integration/conftest.py
Normal file
@ -0,0 +1,209 @@
|
||||
"""Test library.
|
||||
|
||||
Strategy:
|
||||
|
||||
We start by spawning the backend server, and create two test keys.
|
||||
|
||||
Then we spawn the sshd and the admin api.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import asyncssh
|
||||
import secrets
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
import uvicorn
|
||||
from sshecret.backend import SshecretBackend
|
||||
from sshecret.crypto import (
|
||||
generate_private_key,
|
||||
generate_public_key_string,
|
||||
write_private_key,
|
||||
)
|
||||
from sshecret_admin.core.app import create_admin_app
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from sshecret_backend.app import create_backend_app
|
||||
from sshecret_backend.settings import BackendSettings
|
||||
from sshecret_backend.testing import create_test_token
|
||||
from sshecret_sshd.settings import ServerSettings
|
||||
from sshecret_sshd.ssh_server import start_sshecret_sshd
|
||||
|
||||
from .clients import ClientData
|
||||
from .helpers import create_sshd_server_key, create_test_admin_user, in_tempdir
|
||||
from .types import PortFactory, TestPorts
|
||||
|
||||
TEST_SCOPE = "function"
|
||||
LOOP_SCOPE = "function"
|
||||
|
||||
|
||||
def make_test_key() -> str:
|
||||
"""Generate a test key."""
|
||||
private_key = generate_private_key()
|
||||
return generate_public_key_string(private_key.public_key())
|
||||
|
||||
|
||||
@pytest.fixture(name="test_ports", scope="session")
|
||||
def generate_test_ports(unused_tcp_port_factory: PortFactory) -> TestPorts:
|
||||
"""Generate the test ports."""
|
||||
test_ports = TestPorts(
|
||||
backend=unused_tcp_port_factory(),
|
||||
admin=unused_tcp_port_factory(),
|
||||
sshd=unused_tcp_port_factory(),
|
||||
)
|
||||
print(f"{test_ports=!r}")
|
||||
return test_ports
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="backend_server", loop_scope=LOOP_SCOPE)
|
||||
async def run_backend_server(test_ports: TestPorts):
|
||||
"""Run the backend server."""
|
||||
port = test_ports.backend
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
backend_work_path = Path(tmp_dir)
|
||||
db_file = backend_work_path / "backend.db"
|
||||
backend_settings = BackendSettings(database=str(db_file.absolute()))
|
||||
backend_app = create_backend_app(backend_settings)
|
||||
token = create_test_token(backend_settings)
|
||||
config = uvicorn.Config(app=backend_app, port=port, loop="asyncio")
|
||||
server = uvicorn.Server(config=config)
|
||||
server_task = asyncio.create_task(server.serve())
|
||||
await asyncio.sleep(0.1)
|
||||
backend_url = f"http://127.0.0.1:{port}"
|
||||
yield (backend_url, token)
|
||||
server.should_exit = True
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="admin_server", loop_scope=LOOP_SCOPE)
|
||||
async def run_admin_server(test_ports: TestPorts, backend_server: tuple[str, str]):
|
||||
"""Run admin server."""
|
||||
backend_url, backend_token = backend_server
|
||||
secret_key = secrets.token_urlsafe(32)
|
||||
port = test_ports.admin
|
||||
with in_tempdir() as admin_work_path:
|
||||
admin_db = admin_work_path / "ssh_admin.db"
|
||||
admin_settings = AdminServerSettings.model_validate(
|
||||
{
|
||||
"sshecret_backend_url": backend_url,
|
||||
"backend_token": backend_token,
|
||||
"secret_key": secret_key,
|
||||
"listen_address": "127.0.0.1",
|
||||
"port": port,
|
||||
"database": str(admin_db.absolute()),
|
||||
"password_manager_directory": str(admin_work_path.absolute()),
|
||||
}
|
||||
)
|
||||
admin_app = create_admin_app(admin_settings)
|
||||
config = uvicorn.Config(app=admin_app, port=port, loop="asyncio")
|
||||
server = uvicorn.Server(config=config)
|
||||
server_task = asyncio.create_task(server.serve())
|
||||
await asyncio.sleep(0.1)
|
||||
admin_url = f"http://127.0.0.1:{port}"
|
||||
admin_password = secrets.token_urlsafe(10)
|
||||
create_test_admin_user(admin_settings, "test", admin_password)
|
||||
await asyncio.sleep(0.1)
|
||||
yield (admin_url, ("test", admin_password))
|
||||
server.should_exit = True
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="ssh_server", loop_scope=LOOP_SCOPE)
|
||||
async def start_ssh_server(test_ports: TestPorts, backend_server: tuple[str, str]):
|
||||
"""Run ssh server."""
|
||||
backend_url, backend_token = backend_server
|
||||
port = test_ports.sshd
|
||||
with in_tempdir() as ssh_workdir:
|
||||
create_sshd_server_key(ssh_workdir)
|
||||
sshd_server_settings = ServerSettings.model_validate(
|
||||
{
|
||||
"sshecret_backend_url": backend_url,
|
||||
"backend_token": backend_token,
|
||||
"listen_address": "",
|
||||
"port": port,
|
||||
"registration": {"enabled": True, "allow_from": "0.0.0.0/0"},
|
||||
"enable_ping_command": True,
|
||||
}
|
||||
)
|
||||
|
||||
ssh_server = await start_sshecret_sshd(sshd_server_settings)
|
||||
await asyncio.sleep(0.1)
|
||||
print(f"Started sshd on port {port}")
|
||||
yield port
|
||||
|
||||
ssh_server.close()
|
||||
await ssh_server.wait_closed()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope=TEST_SCOPE, name="backend_client", loop_scope=LOOP_SCOPE)
|
||||
async def create_backend_http_client(backend_server: tuple[str, str]):
|
||||
"""Create a test client."""
|
||||
backend_url, backend_token = backend_server
|
||||
print(f"Creating backend client towards {backend_url}")
|
||||
async with httpx.AsyncClient(
|
||||
base_url=backend_url, headers={"X-API-Token": backend_token}
|
||||
) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(name="backend_api")
|
||||
async def get_test_backend_api(backend_server: tuple[str, str]) -> SshecretBackend:
|
||||
"""Get the backend API."""
|
||||
backend_url, backend_token = backend_server
|
||||
return SshecretBackend(backend_url, backend_token)
|
||||
|
||||
|
||||
@pytest.fixture(scope=TEST_SCOPE)
|
||||
def ssh_command_runner(ssh_server: int, tmp_path: Path):
|
||||
"""Run a single command on the ssh server."""
|
||||
port = ssh_server
|
||||
|
||||
async def run_command_as(test_client: ClientData, command: str):
|
||||
private_key_file = tmp_path / f"id_{test_client.name}"
|
||||
write_private_key(test_client.private_key, private_key_file)
|
||||
|
||||
conn = await asyncssh.connect(
|
||||
"127.0.0.1",
|
||||
port=port,
|
||||
username=test_client.name,
|
||||
client_keys=[str(private_key_file)],
|
||||
known_hosts=None,
|
||||
)
|
||||
try:
|
||||
result = await conn.run(command)
|
||||
return result
|
||||
finally:
|
||||
conn.close()
|
||||
await conn.wait_closed()
|
||||
|
||||
return run_command_as
|
||||
|
||||
|
||||
@pytest.fixture(name="ssh_session", scope=TEST_SCOPE)
|
||||
def create_ssh_session(ssh_server: int, tmp_path: Path):
|
||||
"""Create a ssh Session."""
|
||||
port = ssh_server
|
||||
|
||||
@asynccontextmanager
|
||||
async def run_process(test_client: ClientData, command: str):
|
||||
private_key_file = tmp_path / f"id_{test_client.name}"
|
||||
write_private_key(test_client.private_key, private_key_file)
|
||||
conn = await asyncssh.connect(
|
||||
"127.0.0.1",
|
||||
port=port,
|
||||
username=test_client.name,
|
||||
client_keys=[str(private_key_file)],
|
||||
known_hosts=None,
|
||||
)
|
||||
try:
|
||||
async with conn.create_process(command) as process:
|
||||
yield process
|
||||
finally:
|
||||
conn.close()
|
||||
await conn.wait_closed()
|
||||
|
||||
return run_process
|
||||
41
tests/integration/helpers.py
Normal file
41
tests/integration/helpers.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Helper functions."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from sqlmodel import Session, create_engine
|
||||
from sshecret.crypto import generate_private_key, write_private_key
|
||||
from sshecret_admin.auth.authentication import hash_password
|
||||
from sshecret_admin.auth.models import User, init_db
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
def create_test_admin_user(settings: AdminServerSettings, username: str, password: str) -> None:
|
||||
"""Create a test admin user."""
|
||||
hashed_password = hash_password(password)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
user = User(username=username, hashed_password=hashed_password)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
|
||||
def create_sshd_server_key(sshd_path: Path) -> Path:
|
||||
"""Create a ssh key at a general"""
|
||||
server_file = sshd_path / "ssh_host_key"
|
||||
private_key = generate_private_key()
|
||||
write_private_key(private_key, server_file)
|
||||
return server_file
|
||||
|
||||
|
||||
@contextmanager
|
||||
def in_tempdir() -> Iterator[Path]:
|
||||
"""Run in a temporary directory."""
|
||||
curdir = os.getcwd()
|
||||
with tempfile.TemporaryDirectory() as temp_directory:
|
||||
temp_path = Path(temp_directory)
|
||||
os.chdir(temp_directory)
|
||||
yield temp_path
|
||||
os.chdir(curdir)
|
||||
222
tests/integration/test_admin_api.py
Normal file
222
tests/integration/test_admin_api.py
Normal file
@ -0,0 +1,222 @@
|
||||
"""Tests of the admin interface."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
import pytest
|
||||
|
||||
import httpx
|
||||
|
||||
from sshecret.backend import Client
|
||||
|
||||
from sshecret.crypto import generate_private_key, generate_public_key_string
|
||||
|
||||
from .types import AdminServer
|
||||
|
||||
|
||||
def make_test_key() -> str:
|
||||
"""Generate a test key."""
|
||||
private_key = generate_private_key()
|
||||
return generate_public_key_string(private_key.public_key())
|
||||
|
||||
|
||||
class BaseAdminTests:
|
||||
"""Base admin test class."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def http_client(
|
||||
self, admin_server: AdminServer, authenticate: bool = True
|
||||
) -> AsyncIterator[httpx.AsyncClient]:
|
||||
"""Run a client towards the admin rest api."""
|
||||
admin_url, credentials = admin_server
|
||||
username, password = credentials
|
||||
headers: dict[str, str] | None = None
|
||||
if authenticate:
|
||||
async with httpx.AsyncClient(base_url=admin_url) as client:
|
||||
|
||||
response = await client.post(
|
||||
"api/v1/token", data={"username": username, "password": password}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
token = data["access_token"]
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async with httpx.AsyncClient(base_url=admin_url, headers=headers) as client:
|
||||
yield client
|
||||
|
||||
async def create_client(
|
||||
self,
|
||||
admin_server: AdminServer,
|
||||
name: str,
|
||||
public_key: str | None = None,
|
||||
) -> Client:
|
||||
"""Create a client."""
|
||||
if not public_key:
|
||||
public_key = make_test_key()
|
||||
|
||||
new_client = {
|
||||
"name": name,
|
||||
"public_key": public_key,
|
||||
"sources": ["192.0.2.0/24"],
|
||||
}
|
||||
|
||||
async with self.http_client(admin_server, True) as http_client:
|
||||
response = await http_client.post("api/v1/clients/", json=new_client)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
client = Client.model_validate(data)
|
||||
|
||||
return client
|
||||
|
||||
|
||||
class TestAdminAPI(BaseAdminTests):
|
||||
"""Tests of the Admin REST API."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(
|
||||
self, admin_server: tuple[str, tuple[str, str]]
|
||||
) -> None:
|
||||
"""Test admin login."""
|
||||
async with self.http_client(admin_server, False) as client:
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_login(self, admin_server: AdminServer) -> None:
|
||||
"""Test admin login."""
|
||||
|
||||
async with self.http_client(admin_server, False) as client:
|
||||
resp = await client.get("api/v1/clients/")
|
||||
assert resp.status_code == 401
|
||||
|
||||
async with self.http_client(admin_server, True) as client:
|
||||
resp = await client.get("api/v1/clients/")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
class TestAdminApiClients(BaseAdminTests):
|
||||
"""Test client routes."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_client(self, admin_server: AdminServer) -> None:
|
||||
"""Test create_client."""
|
||||
client = await self.create_client(admin_server, "testclient")
|
||||
|
||||
assert client.id is not None
|
||||
assert client.name == "testclient"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_clients(self, admin_server: AdminServer) -> None:
|
||||
"""Test get_clients."""
|
||||
|
||||
client_names = ["test-db", "test-app", "test-www"]
|
||||
for name in client_names:
|
||||
await self.create_client(admin_server, name)
|
||||
async with self.http_client(admin_server) as http_client:
|
||||
resp = await http_client.get("api/v1/clients/")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 3
|
||||
for entry in data:
|
||||
assert isinstance(entry, dict)
|
||||
client_name = entry.get("name")
|
||||
assert client_name in client_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_client(self, admin_server: AdminServer) -> None:
|
||||
"""Test delete_client."""
|
||||
await self.create_client(admin_server, name="testclient")
|
||||
async with self.http_client(admin_server) as http_client:
|
||||
resp = await http_client.get("api/v1/clients/")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 1
|
||||
assert data[0]["name"] == "testclient"
|
||||
|
||||
resp = await http_client.delete("/api/v1/clients/testclient")
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = await http_client.get("api/v1/clients/")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 0
|
||||
|
||||
|
||||
class TestAdminApiSecrets(BaseAdminTests):
|
||||
"""Test secret management."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_secret(self, admin_server: AdminServer) -> None:
|
||||
"""Test add_secret."""
|
||||
await self.create_client(admin_server, name="testclient")
|
||||
async with self.http_client(admin_server) as http_client:
|
||||
data = {
|
||||
"name": "testsecret",
|
||||
"clients": ["testclient"],
|
||||
"value": "secretstring",
|
||||
}
|
||||
resp = await http_client.post("api/v1/secrets/", json=data)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secret(self, admin_server: AdminServer) -> None:
|
||||
"""Test get_secret."""
|
||||
await self.test_add_secret(admin_server)
|
||||
async with self.http_client(admin_server) as http_client:
|
||||
resp = await http_client.get("api/v1/secrets/testsecret")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, dict)
|
||||
assert data["name"] == "testsecret"
|
||||
assert data["secret"] == "secretstring"
|
||||
assert "testclient" in data["clients"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_secret_auto(self, admin_server: AdminServer) -> None:
|
||||
"""Test adding a secret with an auto-generated value."""
|
||||
await self.create_client(admin_server, name="testclient")
|
||||
async with self.http_client(admin_server) as http_client:
|
||||
data = {
|
||||
"name": "testsecret",
|
||||
"clients": ["testclient"],
|
||||
"value": {"auto_generate": True, "length": 17},
|
||||
}
|
||||
resp = await http_client.post("api/v1/secrets/", json=data)
|
||||
assert resp.status_code == 200
|
||||
resp = await http_client.get("api/v1/secrets/testsecret")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, dict)
|
||||
assert data["name"] == "testsecret"
|
||||
assert len(data["secret"]) == 17
|
||||
assert "testclient" in data["clients"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_secret(self, admin_server: AdminServer) -> None:
|
||||
"""Test updating secrets."""
|
||||
await self.test_add_secret_auto(admin_server)
|
||||
async with self.http_client(admin_server) as http_client:
|
||||
resp = await http_client.put(
|
||||
"api/v1/secrets/testsecret",
|
||||
json={"value": "secret"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
resp = await http_client.get("api/v1/secrets/testsecret")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["secret"] == "secret"
|
||||
|
||||
resp = await http_client.put(
|
||||
"api/v1/secrets/testsecret",
|
||||
json={"value": {"auto_generate": True, "length": 16}},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
resp = await http_client.get("api/v1/secrets/testsecret")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data["secret"]) == 16
|
||||
59
tests/integration/test_backend.py
Normal file
59
tests/integration/test_backend.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""Test backend.
|
||||
|
||||
These tests just ensure that the backend works well enough for us to run the
|
||||
rest of the tests.
|
||||
|
||||
"""
|
||||
import pytest
|
||||
import httpx
|
||||
from sshecret.backend import SshecretBackend
|
||||
from .clients import create_test_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_healthcheck(backend_client: httpx.AsyncClient) -> None:
|
||||
"""Test healthcheck command."""
|
||||
resp = await backend_client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"status": "LIVE"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_client(backend_api: SshecretBackend) -> None:
|
||||
"""Test creating a client."""
|
||||
test_client = create_test_client("test")
|
||||
await backend_api.create_client("test", test_client.public_key, "A test client")
|
||||
|
||||
# fetch the list of clients.
|
||||
|
||||
clients = await backend_api.get_clients()
|
||||
assert clients is not None
|
||||
|
||||
assert len(clients) == 1
|
||||
|
||||
assert clients[0].name == "test"
|
||||
|
||||
assert clients[0].public_key == test_client.public_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_secret(backend_api: SshecretBackend) -> None:
|
||||
"""Test creating secrets."""
|
||||
test_client = create_test_client("test")
|
||||
await backend_api.create_client("test", test_client.public_key, "A test client")
|
||||
|
||||
await backend_api.create_client_secret("test", "mysecret", "encrypted_secret")
|
||||
|
||||
secrets = await backend_api.get_secrets()
|
||||
assert len(secrets) == 1
|
||||
assert secrets[0].name == "mysecret"
|
||||
|
||||
|
||||
secret_to_client = await backend_api.get_secret("mysecret")
|
||||
assert secret_to_client is not None
|
||||
|
||||
assert secret_to_client.name == "mysecret"
|
||||
assert "test" in secret_to_client.clients
|
||||
|
||||
secret = await backend_api.get_client_secret("test", "mysecret")
|
||||
|
||||
assert secret is not None
|
||||
assert secret == "encrypted_secret"
|
||||
139
tests/integration/test_sshd.py
Normal file
139
tests/integration/test_sshd.py
Normal file
@ -0,0 +1,139 @@
|
||||
"""Tests where the sshd is the main consumer.
|
||||
|
||||
This essentially also tests parts of the admin API.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator
|
||||
import os
|
||||
import httpx
|
||||
|
||||
import pytest
|
||||
from sshecret.crypto import decode_string
|
||||
from sshecret.backend.api import SshecretBackend
|
||||
|
||||
from .clients import create_test_client, ClientData
|
||||
|
||||
from .types import CommandRunner, ProcessRunner
|
||||
|
||||
|
||||
class TestSshd:
|
||||
"""Class based tests.
|
||||
|
||||
This allows us to create small helpers.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_secret(
|
||||
self, backend_api: SshecretBackend, ssh_command_runner: CommandRunner
|
||||
) -> None:
|
||||
"""Test get secret flow."""
|
||||
test_client = create_test_client("testclient")
|
||||
await backend_api.create_client(
|
||||
"testclient", test_client.public_key, "A test client"
|
||||
)
|
||||
await backend_api.create_client_secret("testclient", "testsecret", "bogus")
|
||||
response = await ssh_command_runner(test_client, "get_secret testsecret")
|
||||
assert response.exit_status == 0
|
||||
assert response.stdout is not None
|
||||
assert isinstance(response.stdout, str)
|
||||
assert response.stdout.rstrip() == "bogus"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register(
|
||||
self, backend_api: SshecretBackend, ssh_session: ProcessRunner
|
||||
) -> None:
|
||||
"""Test registration."""
|
||||
await self.register_client("new_client", ssh_session)
|
||||
# Check that the client is created.
|
||||
clients = await backend_api.get_clients()
|
||||
assert len(clients) == 1
|
||||
|
||||
client = clients[0]
|
||||
assert client.name == "new_client"
|
||||
|
||||
async def register_client(
|
||||
self, name: str, ssh_session: ProcessRunner
|
||||
) -> ClientData:
|
||||
"""Register client."""
|
||||
test_client = create_test_client(name)
|
||||
async with ssh_session(test_client, "register") as session:
|
||||
maxlines = 10
|
||||
linenum = 0
|
||||
found = False
|
||||
while linenum < maxlines:
|
||||
line = await session.stdout.readline()
|
||||
if "Enter public key" in line:
|
||||
found = True
|
||||
break
|
||||
assert found is True
|
||||
session.stdin.write(test_client.public_key + "\n")
|
||||
|
||||
result = await session.stdout.readline()
|
||||
assert "OK" in result
|
||||
await session.wait()
|
||||
return test_client
|
||||
|
||||
|
||||
class TestSshdIntegration(TestSshd):
|
||||
"""Integration tests."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end(
|
||||
self,
|
||||
backend_api: SshecretBackend,
|
||||
admin_server: tuple[str, tuple[str, str]],
|
||||
ssh_session: ProcessRunner,
|
||||
ssh_command_runner: CommandRunner,
|
||||
) -> None:
|
||||
"""Test end to end."""
|
||||
test_client = await self.register_client("myclient", ssh_session)
|
||||
url, credentials = admin_server
|
||||
username, password = credentials
|
||||
async with self.admin_client(url, username, password) as http_client:
|
||||
resp = await http_client.get("api/v1/clients/")
|
||||
assert resp.status_code == 200
|
||||
clients = resp.json()
|
||||
assert len(clients) == 1
|
||||
assert clients[0]["name"] == "myclient"
|
||||
|
||||
create_model = {
|
||||
"name": "mysecret",
|
||||
"clients": ["myclient"],
|
||||
"value": "mypassword",
|
||||
}
|
||||
resp = await http_client.post("api/v1/secrets/", json=create_model)
|
||||
assert resp.status_code == 200
|
||||
|
||||
# Login via ssh to fetch the decrypted value.
|
||||
ssh_output = await ssh_command_runner(test_client, "get_secret mysecret")
|
||||
assert ssh_output.stdout is not None
|
||||
assert isinstance(ssh_output.stdout, str)
|
||||
encrypted = ssh_output.stdout.rstrip()
|
||||
decrypted = decode_string(encrypted, test_client.private_key)
|
||||
assert decrypted == "mypassword"
|
||||
|
||||
async def login(self, url: str, username: str, password: str) -> str:
|
||||
"""Login and get token."""
|
||||
api_url = os.path.join(url, "api/v1", "token")
|
||||
client = httpx.AsyncClient()
|
||||
|
||||
response = await client.post(
|
||||
api_url, data={"username": username, "password": password}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "access_token" in data
|
||||
assert isinstance(data["access_token"], str)
|
||||
return str(data["access_token"])
|
||||
|
||||
@asynccontextmanager
|
||||
async def admin_client(
|
||||
self, url: str, username: str, password: str
|
||||
) -> AsyncIterator[httpx.AsyncClient]:
|
||||
"""Create an admin client."""
|
||||
token = await self.login(url, username, password)
|
||||
async with httpx.AsyncClient(
|
||||
base_url=url, headers={"Authorization": f"Bearer {token}"}
|
||||
) as client:
|
||||
yield client
|
||||
30
tests/integration/types.py
Normal file
30
tests/integration/types.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""Typings."""
|
||||
import asyncssh
|
||||
|
||||
from typing import Any, AsyncContextManager, Protocol
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Callable, Awaitable
|
||||
|
||||
from .clients import ClientData
|
||||
|
||||
|
||||
PortFactory = Callable[[], int]
|
||||
|
||||
AdminServer = tuple[str, tuple[str, str]]
|
||||
|
||||
@dataclass
|
||||
class TestPorts:
|
||||
"""Test port dataclass."""
|
||||
|
||||
backend: int
|
||||
admin: int
|
||||
sshd: int
|
||||
|
||||
|
||||
CommandRunner = Callable[[ClientData, str], Awaitable[asyncssh.SSHCompletedProcess]]
|
||||
|
||||
class ProcessRunner(Protocol):
|
||||
"""Process runner typing."""
|
||||
|
||||
def __call__(self, test_client: ClientData, command: str) -> AsyncContextManager[asyncssh.SSHClientProcess[Any]]:
|
||||
...
|
||||
0
tests/packages/__init__.py
Normal file
0
tests/packages/__init__.py
Normal file
0
tests/packages/backend/__init__.py
Normal file
0
tests/packages/backend/__init__.py
Normal file
1
tests/packages/sshd/__init__.py
Normal file
1
tests/packages/sshd/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
@ -1,11 +1,15 @@
|
||||
import asyncio
|
||||
from pydantic import IPvAnyNetwork
|
||||
import pytest
|
||||
import uuid
|
||||
import asyncssh
|
||||
import tempfile
|
||||
from contextlib import asynccontextmanager
|
||||
import pytest_asyncio
|
||||
from pytest import FixtureRequest
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
from ipaddress import IPv4Network, IPv6Network, ip_network
|
||||
|
||||
from sshecret_sshd.ssh_server import run_ssh_server
|
||||
from sshecret_sshd.settings import ClientRegistrationSettings
|
||||
@ -31,14 +35,16 @@ def client_registry() -> ClientRegistry:
|
||||
) -> str:
|
||||
private_key = asyncssh.generate_private_key("ssh-rsa")
|
||||
public_key = private_key.export_public_key()
|
||||
clients[name] = ClientKey(name, private_key, public_key.decode().rstrip())
|
||||
clients[name] = ClientKey(
|
||||
name, private_key, public_key.decode().rstrip(), policies
|
||||
)
|
||||
secrets.update({(name, s): f"mocked-secret-{s}" for s in (secret_names or [])})
|
||||
return clients[name]
|
||||
|
||||
return {"clients": clients, "secrets": secrets, "add_client": add_client}
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
backend = MagicMock()
|
||||
clients_data = client_registry["clients"]
|
||||
@ -47,13 +53,16 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
async def get_client(name: str) -> Client | None:
|
||||
client_key = clients_data.get(name)
|
||||
if client_key:
|
||||
policies = [IPv4Network("0.0.0.0/0"), IPv6Network("::/0")]
|
||||
if client_key.policies:
|
||||
policies = [ip_network(network) for network in client_key.policies]
|
||||
response_model = Client(
|
||||
id=uuid.uuid4(),
|
||||
name=name,
|
||||
description=f"Mock client {name}",
|
||||
public_key=client_key.public_key,
|
||||
secrets=[s for (c, s) in secrets_data if c == name],
|
||||
policies=[IPv4Network("0.0.0.0/0"), IPv6Network("::/0")],
|
||||
policies=policies,
|
||||
)
|
||||
return response_model
|
||||
return None
|
||||
@ -79,32 +88,46 @@ async def mock_backend(client_registry: ClientRegistry) -> MagicMock:
|
||||
continue
|
||||
secrets_data[(name, secret_name)] = secret
|
||||
|
||||
async def write_audit(*args, **kwargs):
|
||||
"""Write audit mock."""
|
||||
return None
|
||||
|
||||
backend.get_client = AsyncMock(side_effect=get_client)
|
||||
backend.get_client_secret = AsyncMock(side_effect=get_client_secret)
|
||||
backend.create_client = AsyncMock(side_effect=create_client)
|
||||
|
||||
# Make sure backend.audit(...) returns the audit mock
|
||||
audit = MagicMock()
|
||||
audit.write = MagicMock()
|
||||
audit.write_async = AsyncMock(side_effect=write_audit)
|
||||
backend.audit = MagicMock(return_value=audit)
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def ssh_server(
|
||||
mock_backend: MagicMock, unused_tcp_port: int
|
||||
request: FixtureRequest,
|
||||
mock_backend: MagicMock,
|
||||
unused_tcp_port: int,
|
||||
) -> SshServerFixtureFun:
|
||||
port = unused_tcp_port
|
||||
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||
key_str = private_key.export_private_key()
|
||||
registration_mark = request.node.get_closest_marker("enable_registration")
|
||||
registration_enabled = registration_mark is not None
|
||||
registration_source_mark = request.node.get_closest_marker("registration_sources")
|
||||
allowed_from: list[IPvAnyNetwork] = []
|
||||
if registration_source_mark:
|
||||
for network in registration_source_mark.args:
|
||||
allowed_from.append(ip_network(network))
|
||||
else:
|
||||
allowed_from = [IPv4Network("0.0.0.0/0")]
|
||||
with tempfile.NamedTemporaryFile("w+", delete=True) as key_file:
|
||||
key_file.write(key_str.decode())
|
||||
key_file.flush()
|
||||
|
||||
registration_settings = ClientRegistrationSettings(
|
||||
enabled=True, allow_from=[IPv4Network("0.0.0.0/0")]
|
||||
enabled=registration_enabled,
|
||||
allow_from=allowed_from,
|
||||
)
|
||||
server = await run_ssh_server(
|
||||
backend=mock_backend,
|
||||
155
tests/packages/sshd/test_errors.py
Normal file
155
tests/packages/sshd/test_errors.py
Normal file
@ -0,0 +1,155 @@
|
||||
"""Test various exceptions and error conditions."""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
import asyncssh
|
||||
import pytest
|
||||
|
||||
from .types import ClientRegistry, CommandRunner, ProcessRunner, SshServerFixture
|
||||
|
||||
|
||||
class BaseSshTests:
|
||||
"""Base test class."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def unregistered_client(self, username: str, port: int):
|
||||
"""Generate SSH session as an uregistered client."""
|
||||
private_key = asyncssh.generate_private_key("ssh-rsa")
|
||||
conn = await asyncssh.connect(
|
||||
"127.0.0.1",
|
||||
port=port,
|
||||
username=username,
|
||||
client_keys=[private_key],
|
||||
known_hosts=None,
|
||||
)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
await conn.wait_closed()
|
||||
|
||||
@asynccontextmanager
|
||||
async def ssh_connection(
|
||||
self, username: str, port: int, private_key: asyncssh.SSHKey
|
||||
):
|
||||
"""Generate SSH session as a client with an ed25519 key."""
|
||||
# private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||
conn = await asyncssh.connect(
|
||||
"127.0.0.1",
|
||||
port=port,
|
||||
username=username,
|
||||
client_keys=[private_key],
|
||||
known_hosts=None,
|
||||
)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
await conn.wait_closed()
|
||||
|
||||
|
||||
class TestRegistrationErrors(BaseSshTests):
|
||||
"""Test class for errors related to registartion."""
|
||||
|
||||
@pytest.mark.enable_registration(True)
|
||||
@pytest.mark.registration_sources("192.0.2.0/24")
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_client_invalid_source(
|
||||
self, ssh_server: SshServerFixture
|
||||
) -> None:
|
||||
"""Test client registration from a network that's not permitted."""
|
||||
_, port = ssh_server
|
||||
with pytest.raises(asyncssh.misc.PermissionDenied):
|
||||
async with self.unregistered_client("stranger", port) as conn:
|
||||
async with conn.create_process("register") as process:
|
||||
stdout, stderr = process.collect_output()
|
||||
print(f"{stdout=!r}\n{stderr=!r}")
|
||||
if isinstance(stdout, str):
|
||||
assert "Enter public key" not in stdout
|
||||
result = await process.wait()
|
||||
assert result.exit_status == 1
|
||||
|
||||
@pytest.mark.enable_registration(True)
|
||||
@pytest.mark.registration_sources("127.0.0.1")
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_key_type(self, ssh_server: SshServerFixture) -> None:
|
||||
"""Test registration with an unsupported key."""
|
||||
_, port = ssh_server
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||
public_key = private_key.export_public_key().decode().rstrip() + "\n"
|
||||
|
||||
async with self.ssh_connection("stranger", port, private_key) as conn:
|
||||
async with conn.create_process("register") as process:
|
||||
output = await process.stdout.readline()
|
||||
assert "Enter public key" in output
|
||||
stdout, stderr = await process.communicate(public_key)
|
||||
print(f"{stdout=!r}, {stderr=!r}")
|
||||
assert stderr == "Error: Invalid key type: Only RSA keys are supported."
|
||||
result = await process.wait()
|
||||
assert result.exit_status == 1
|
||||
|
||||
@pytest.mark.enable_registration(True)
|
||||
@pytest.mark.registration_sources("127.0.0.1")
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_key(self, ssh_server: SshServerFixture) -> None:
|
||||
"""Test registration with a bogus string as key.."""
|
||||
_, port = ssh_server
|
||||
private_key = asyncssh.generate_private_key("ssh-ed25519")
|
||||
public_key = f"ssh-test {'A' * 544}\n"
|
||||
|
||||
async with self.ssh_connection("stranger", port, private_key) as conn:
|
||||
async with conn.create_process("register") as process:
|
||||
output = await process.stdout.readline()
|
||||
assert "Enter public key" in output
|
||||
stdout, stderr = await process.communicate(public_key)
|
||||
print(f"{stdout=!r}, {stderr=!r}")
|
||||
assert stderr == "Error: Invalid key type: Only RSA keys are supported."
|
||||
result = await process.wait()
|
||||
assert result.exit_status == 1
|
||||
|
||||
|
||||
class TestCommandErrors(BaseSshTests):
|
||||
"""Tests various errors around commands."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_command(
|
||||
self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||
) -> None:
|
||||
"""Test sending an invalid command."""
|
||||
await client_registry["add_client"]("test")
|
||||
|
||||
result = await ssh_command_runner("test", "cat /etc/passwd")
|
||||
|
||||
assert result.exit_status == 1
|
||||
stderr = result.stderr or ""
|
||||
assert stderr == "Error: Unsupported command."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_command(
|
||||
self, ssh_server: SshServerFixture, client_registry: ClientRegistry
|
||||
) -> None:
|
||||
"""Test sending no command."""
|
||||
await client_registry["add_client"]("test")
|
||||
_, port = ssh_server
|
||||
client_key = client_registry["clients"]["test"]
|
||||
async with self.ssh_connection("test", port, client_key.private_key) as conn:
|
||||
async with conn.create_process() as process:
|
||||
stdout, stderr = await process.communicate()
|
||||
print(f"{stdout=!r}, {stderr=!r}")
|
||||
assert stderr == "Error: No command was received from the client."
|
||||
result = await process.wait()
|
||||
assert result.exit_status == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_client_connection(
|
||||
self, ssh_command_runner: CommandRunner, client_registry: ClientRegistry
|
||||
) -> None:
|
||||
"""Test client that is not permitted to connect."""
|
||||
await client_registry["add_client"](
|
||||
"test-client",
|
||||
["mysecret"],
|
||||
["192.0.2.0/24"],
|
||||
)
|
||||
|
||||
with pytest.raises(asyncssh.misc.PermissionDenied):
|
||||
await ssh_command_runner("test-client", "get_secret mysecret")
|
||||
@ -5,6 +5,7 @@ import pytest
|
||||
from .types import ClientRegistry, CommandRunner, ProcessRunner
|
||||
|
||||
|
||||
@pytest.mark.enable_registration(True)
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_client(
|
||||
ssh_session: ProcessRunner,
|
||||
@ -31,6 +31,8 @@ class ClientKey:
|
||||
name: str
|
||||
private_key: asyncssh.SSHKey
|
||||
public_key: str
|
||||
policies: list[str] | None = None
|
||||
|
||||
|
||||
|
||||
class AddClientFun(Protocol):
|
||||
102
uv.lock
generated
102
uv.lock
generated
@ -6,7 +6,6 @@ members = [
|
||||
"sshecret",
|
||||
"sshecret-admin",
|
||||
"sshecret-backend",
|
||||
"sshecret-client",
|
||||
"sshecret-sshd",
|
||||
]
|
||||
|
||||
@ -213,6 +212,35 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b2/0b/ab3ce2b27dd74b6a6703065bd304ea8211ff4de3b1c304446ed95234177b/construct_typing-0.6.2-py3-none-any.whl", hash = "sha256:ebea6989ac622d0c4eb457092cef0c7bfbcfa110bd018670fea7064d0bc09e47", size = 23298 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/19/4f/2251e65033ed2ce1e68f00f91a0294e0f80c80ae8c3ebbe2f12828c4cd53/coverage-7.8.0.tar.gz", hash = "sha256:7a3d62b3b03b4b6fd41a085f3574874cf946cb4604d2b4d3e8dca8cd570ca501", size = 811872 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/21/87e9b97b568e223f3438d93072479c2f36cc9b3f6b9f7094b9d50232acc0/coverage-7.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5ac46d0c2dd5820ce93943a501ac5f6548ea81594777ca585bf002aa8854cacd", size = 211708 },
|
||||
{ url = "https://files.pythonhosted.org/packages/75/be/882d08b28a0d19c9c4c2e8a1c6ebe1f79c9c839eb46d4fca3bd3b34562b9/coverage-7.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:771eb7587a0563ca5bb6f622b9ed7f9d07bd08900f7589b4febff05f469bea00", size = 211981 },
|
||||
{ url = "https://files.pythonhosted.org/packages/7a/1d/ce99612ebd58082fbe3f8c66f6d8d5694976c76a0d474503fa70633ec77f/coverage-7.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42421e04069fb2cbcbca5a696c4050b84a43b05392679d4068acbe65449b5c64", size = 245495 },
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/8d/6115abe97df98db6b2bd76aae395fcc941d039a7acd25f741312ced9a78f/coverage-7.8.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:554fec1199d93ab30adaa751db68acec2b41c5602ac944bb19187cb9a41a8067", size = 242538 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/74/2f8cc196643b15bc096d60e073691dadb3dca48418f08bc78dd6e899383e/coverage-7.8.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aaeb00761f985007b38cf463b1d160a14a22c34eb3f6a39d9ad6fc27cb73008", size = 244561 },
|
||||
{ url = "https://files.pythonhosted.org/packages/22/70/c10c77cd77970ac965734fe3419f2c98665f6e982744a9bfb0e749d298f4/coverage-7.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:581a40c7b94921fffd6457ffe532259813fc68eb2bdda60fa8cc343414ce3733", size = 244633 },
|
||||
{ url = "https://files.pythonhosted.org/packages/38/5a/4f7569d946a07c952688debee18c2bb9ab24f88027e3d71fd25dbc2f9dca/coverage-7.8.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f319bae0321bc838e205bf9e5bc28f0a3165f30c203b610f17ab5552cff90323", size = 242712 },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/a1/03a43b33f50475a632a91ea8c127f7e35e53786dbe6781c25f19fd5a65f8/coverage-7.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04bfec25a8ef1c5f41f5e7e5c842f6b615599ca8ba8391ec33a9290d9d2db3a3", size = 244000 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/89/ab6c43b1788a3128e4d1b7b54214548dcad75a621f9d277b14d16a80d8a1/coverage-7.8.0-cp313-cp313-win32.whl", hash = "sha256:dd19608788b50eed889e13a5d71d832edc34fc9dfce606f66e8f9f917eef910d", size = 214195 },
|
||||
{ url = "https://files.pythonhosted.org/packages/12/12/6bf5f9a8b063d116bac536a7fb594fc35cb04981654cccb4bbfea5dcdfa0/coverage-7.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:a9abbccd778d98e9c7e85038e35e91e67f5b520776781d9a1e2ee9d400869487", size = 214998 },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/e6/1e9df74ef7a1c983a9c7443dac8aac37a46f1939ae3499424622e72a6f78/coverage-7.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:18c5ae6d061ad5b3e7eef4363fb27a0576012a7447af48be6c75b88494c6cf25", size = 212541 },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/51/c32174edb7ee49744e2e81c4b1414ac9df3dacfcb5b5f273b7f285ad43f6/coverage-7.8.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:95aa6ae391a22bbbce1b77ddac846c98c5473de0372ba5c463480043a07bff42", size = 212767 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/8f/f454cbdb5212f13f29d4a7983db69169f1937e869a5142bce983ded52162/coverage-7.8.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e013b07ba1c748dacc2a80e69a46286ff145935f260eb8c72df7185bf048f502", size = 256997 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/74/2bf9e78b321216d6ee90a81e5c22f912fc428442c830c4077b4a071db66f/coverage-7.8.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d766a4f0e5aa1ba056ec3496243150698dc0481902e2b8559314368717be82b1", size = 252708 },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/4d/50d7eb1e9a6062bee6e2f92e78b0998848a972e9afad349b6cdde6fa9e32/coverage-7.8.0-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad80e6b4a0c3cb6f10f29ae4c60e991f424e6b14219d46f1e7d442b938ee68a4", size = 255046 },
|
||||
{ url = "https://files.pythonhosted.org/packages/40/9e/71fb4e7402a07c4198ab44fc564d09d7d0ffca46a9fb7b0a7b929e7641bd/coverage-7.8.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b87eb6fc9e1bb8f98892a2458781348fa37e6925f35bb6ceb9d4afd54ba36c73", size = 256139 },
|
||||
{ url = "https://files.pythonhosted.org/packages/49/1a/78d37f7a42b5beff027e807c2843185961fdae7fe23aad5a4837c93f9d25/coverage-7.8.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:d1ba00ae33be84066cfbe7361d4e04dec78445b2b88bdb734d0d1cbab916025a", size = 254307 },
|
||||
{ url = "https://files.pythonhosted.org/packages/58/e9/8fb8e0ff6bef5e170ee19d59ca694f9001b2ec085dc99b4f65c128bb3f9a/coverage-7.8.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f3c38e4e5ccbdc9198aecc766cedbb134b2d89bf64533973678dfcf07effd883", size = 255116 },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/b0/d968ecdbe6fe0a863de7169bbe9e8a476868959f3af24981f6a10d2b6924/coverage-7.8.0-cp313-cp313t-win32.whl", hash = "sha256:379fe315e206b14e21db5240f89dc0774bdd3e25c3c58c2c733c99eca96f1ada", size = 214909 },
|
||||
{ url = "https://files.pythonhosted.org/packages/87/e9/d6b7ef9fecf42dfb418d93544af47c940aa83056c49e6021a564aafbc91f/coverage-7.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:2e4b6b87bb0c846a9315e3ab4be2d52fac905100565f4b92f02c445c8799e257", size = 216068 },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/f1/4da7717f0063a222db253e7121bd6a56f6fb1ba439dcc36659088793347c/coverage-7.8.0-py3-none-any.whl", hash = "sha256:dbf364b4c5e7bae9250528167dfe40219b62e2d573c854d74be213e1e52069f7", size = 203435 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
version = "44.0.2"
|
||||
@ -554,20 +582,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paramiko"
|
||||
version = "3.5.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "bcrypt" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "pynacl" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7d/15/ad6ce226e8138315f2451c2aeea985bf35ee910afb477bae7477dc3a8f3b/paramiko-3.5.1.tar.gz", hash = "sha256:b2c665bc45b2b215bd7d7f039901b14b067da00f3a11e6640995fd58f2664822", size = 1566110 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/15/f8/c7bd0ef12954a81a1d3cea60a13946bd9a49a0036a5927770c461eade7ae/paramiko-3.5.1-py3-none-any.whl", hash = "sha256:43b9a0501fc2b5e70680388d9346cf252cfb7d00b0667c39e80eb43a408b8f61", size = 227298 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "passlib"
|
||||
version = "1.7.4"
|
||||
@ -726,26 +740,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/04/6cf0687780c68e7fb0525e7210ec5477987c0481904f600c2e5d81bbb7dd/pykeepass-4.1.1.post1-py3-none-any.whl", hash = "sha256:4cfd54f376cb1f58dd8f11fbe7923282bc7dd97ffdf1bb622004a6e718bfe379", size = 55584 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pynacl"
|
||||
version = "1.5.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "cffi" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a7/22/27582568be639dfe22ddb3902225f91f2f17ceff88ce80e4db396c8986da/PyNaCl-1.5.0.tar.gz", hash = "sha256:8ac7448f09ab85811607bdd21ec2464495ac8b7c66d146bf545b0f08fb9220ba", size = 3392854 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ce/75/0b8ede18506041c0bf23ac4d8e2971b4161cd6ce630b177d0a08eb0d8857/PyNaCl-1.5.0-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:401002a4aaa07c9414132aaed7f6836ff98f59277a234704ff66878c2ee4a0d1", size = 349920 },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/bb/fddf10acd09637327a97ef89d2a9d621328850a72f1fdc8c08bdf72e385f/PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:52cb72a79269189d4e0dc537556f4740f7f0a9ec41c1322598799b0bdad4ef92", size = 601722 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5d/70/87a065c37cca41a75f2ce113a5a2c2aa7533be648b184ade58971b5f7ccc/PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a36d4a9dda1f19ce6e03c9a784a2921a4b726b02e1c736600ca9c22029474394", size = 680087 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ee/87/f1bb6a595f14a327e8285b9eb54d41fef76c585a0edef0a45f6fc95de125/PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:0c84947a22519e013607c9be43706dd42513f9e6ae5d39d3613ca1e142fba44d", size = 856678 },
|
||||
{ url = "https://files.pythonhosted.org/packages/66/28/ca86676b69bf9f90e710571b67450508484388bfce09acf8a46f0b8c785f/PyNaCl-1.5.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06b8f6fa7f5de8d5d2f7573fe8c863c051225a27b61e6860fd047b1775807858", size = 1133660 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/85/c262db650e86812585e2bc59e497a8f59948a005325a11bbbc9ecd3fe26b/PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:a422368fc821589c228f4c49438a368831cb5bbc0eab5ebe1d7fac9dded6567b", size = 663824 },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/1a/cc308a884bd299b651f1633acb978e8596c71c33ca85e9dc9fa33a5399b9/PyNaCl-1.5.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:61f642bf2378713e2c2e1de73444a3778e5f0a38be6fee0fe532fe30060282ff", size = 1117912 },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/2d/b7df6ddb0c2a33afdb358f8af6ea3b8c4d1196ca45497dd37a56f0c122be/PyNaCl-1.5.0-cp36-abi3-win32.whl", hash = "sha256:e46dae94e34b085175f8abb3b0aaa7da40767865ac82c928eeb9e57e1ea8a543", size = 204624 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/22/d3db169895faaf3e2eda892f005f433a62db2decbcfbc2f61e6517adfa87/PyNaCl-1.5.0-cp36-abi3-win_amd64.whl", hash = "sha256:20f42270d27e1b6a29f54032090b972d97f0a1b0948cc52392041ef7831fee93", size = 212141 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyotp"
|
||||
version = "2.9.0"
|
||||
@ -791,6 +785,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/7f/338843f449ace853647ace35870874f69a764d251872ed1b4de9f234822c/pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0", size = 19694 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "6.1.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "coverage" },
|
||||
{ name = "pytest" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/25/69/5f1e57f6c5a39f81411b550027bf72842c4567ff5fd572bed1edc9e4b5d9/pytest_cov-6.1.1.tar.gz", hash = "sha256:46935f7aaefba760e716c2ebfbe1c216240b9592966e7da99ea8292d4d3e2a0a", size = 66857 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/28/d0/def53b4a790cfb21483016430ed828f64830dd981ebe1089971cd10cab25/pytest_cov-6.1.1-py3-none-any.whl", hash = "sha256:bddf29ed2d0ab6f4df17b4c55b0a657287db8684af9c42ea546b21b1041b3dde", size = 23841 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-dotenv"
|
||||
version = "1.0.1"
|
||||
@ -950,8 +957,10 @@ dev = [
|
||||
{ name = "python-dotenv" },
|
||||
]
|
||||
test = [
|
||||
{ name = "coverage" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "robotframework" },
|
||||
]
|
||||
|
||||
@ -979,8 +988,10 @@ dev = [
|
||||
{ name = "python-dotenv", specifier = ">=1.0.1" },
|
||||
]
|
||||
test = [
|
||||
{ name = "coverage", specifier = ">=7.8.0" },
|
||||
{ name = "pytest", specifier = ">=8.3.5" },
|
||||
{ name = "pytest-asyncio", specifier = ">=0.26.0" },
|
||||
{ name = "pytest-cov", specifier = ">=6.1.1" },
|
||||
{ name = "robotframework", specifier = ">=7.2.2" },
|
||||
]
|
||||
|
||||
@ -1056,27 +1067,6 @@ requires-dist = [
|
||||
{ name = "sshecret", editable = "." },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sshecret-client"
|
||||
version = "0.1.0"
|
||||
source = { virtual = "packages/sshecret_client" }
|
||||
dependencies = [
|
||||
{ name = "asyncssh" },
|
||||
{ name = "click" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "paramiko" },
|
||||
{ name = "sshecret" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "asyncssh", specifier = ">=2.20.0" },
|
||||
{ name = "click", specifier = ">=8.1.8" },
|
||||
{ name = "cryptography", specifier = ">=44.0.2" },
|
||||
{ name = "paramiko", specifier = ">=3.5.1" },
|
||||
{ name = "sshecret", editable = "." },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sshecret-sshd"
|
||||
version = "0.1.0"
|
||||
|
||||
Reference in New Issue
Block a user