Check in current backend

This commit is contained in:
2025-05-04 09:20:11 +02:00
parent 15952c5dd2
commit 3719a2611d
10 changed files with 93 additions and 95 deletions

View File

@ -1,11 +1,20 @@
import os
from logging.config import fileConfig from logging.config import fileConfig
from sqlalchemy import engine_from_config from sqlalchemy import engine_from_config
from sqlalchemy import pool from sqlalchemy import pool
from sqlmodel import create_engine
from alembic import context from alembic import context
from sshecret_backend.models import * from sshecret_backend.models import *
def get_database_url() -> str:
"""Get database URL."""
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
return f"sqlite:///{db_file}"
return "sqlite:///sshecret.db"
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
config = context.config config = context.config
@ -40,7 +49,7 @@ def run_migrations_offline() -> None:
script output. script output.
""" """
url = config.get_main_option("sqlalchemy.url") url = get_database_url()
context.configure( context.configure(
url=url, url=url,
target_metadata=target_metadata, target_metadata=target_metadata,
@ -59,11 +68,7 @@ def run_migrations_online() -> None:
and associate a connection with the context. and associate a connection with the context.
""" """
connectable = engine_from_config( connectable = create_engine(get_database_url())
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(

View File

@ -1,33 +0,0 @@
"""Initial model
Revision ID: a0befb5a74a0
Revises:
Create Date: 2025-04-28 21:18:59.069323
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = 'a0befb5a74a0'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###

View File

@ -1,33 +0,0 @@
"""Add subsystem to auditlog
Revision ID: f30e413c5757
Revises: a0befb5a74a0
Create Date: 2025-04-28 21:21:20.103423
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
# revision identifiers, used by Alembic.
revision: str = 'f30e413c5757'
down_revision: Union[str, None] = 'a0befb5a74a0'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('auditlog', sa.Column('subsystem', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('auditlog', 'subsystem')
# ### end Alembic commands ###

View File

@ -14,8 +14,13 @@ dependencies = [
"pytest>=8.3.5", "pytest>=8.3.5",
"python-multipart>=0.0.20", "python-multipart>=0.0.20",
"sqlmodel>=0.0.24", "sqlmodel>=0.0.24",
"sshecret",
] ]
[tool.uv.sources]
sshecret = { workspace = true }
[project.scripts] [project.scripts]
sshecret-backend = "sshecret_backend.cli:cli" sshecret-backend = "sshecret_backend.cli:cli"
@ -26,3 +31,9 @@ build-backend = "hatchling.build"
[tool.pytest.ini_options] [tool.pytest.ini_options]
log_cli = true log_cli = true
log_cli_level = "INFO" log_cli_level = "INFO"
[tool.pyright]
venvPath = "../.."
venv = ".venv"
strict = ["**/*.py"]
pythonVersion = "3.13"

View File

@ -11,7 +11,6 @@ from typing import Annotated
from sshecret_backend.models import AuditLog from sshecret_backend.models import AuditLog
from sshecret_backend.types import DBSessionDep from sshecret_backend.types import DBSessionDep
from sshecret_backend import audit
from sshecret_backend.view_models import AuditInfo from sshecret_backend.view_models import AuditInfo

View File

@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select from sqlmodel import Session, select
from typing import Annotated from typing import Annotated
from sshecret_backend.models import Client, ClientAccessPolicy from sshecret_backend.models import ClientAccessPolicy
from sshecret_backend.view_models import ( from sshecret_backend.view_models import (
ClientPolicyView, ClientPolicyView,
ClientPolicyUpdate, ClientPolicyUpdate,

View File

@ -6,12 +6,12 @@ from pathlib import Path
from typing import cast from typing import cast
from dotenv import load_dotenv from dotenv import load_dotenv
import click import click
from sqlmodel import Session, create_engine, select from sqlmodel import Session, col, func, select
import uvicorn import uvicorn
from .db import create_api_token from .db import get_engine, create_api_token
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db
from .settings import BackendSettings from .settings import BackendSettings
DEFAULT_LISTEN = "127.0.0.1" DEFAULT_LISTEN = "127.0.0.1"
@ -21,6 +21,23 @@ WORKDIR = Path(os.getcwd())
load_dotenv() load_dotenv()
def generate_token(settings: BackendSettings) -> str:
"""Generate a token."""
engine = get_engine(settings.db_url)
init_db(engine)
with Session(engine) as session:
token = create_api_token(session, True)
return token
def count_tokens(settings: BackendSettings) -> int:
"""Count the amount of tokens created."""
engine = get_engine(settings.db_url)
init_db(engine)
with Session(engine) as session:
count = session.exec(select(func.count("*")).select_from(APIClient)).one()
return count
@click.group() @click.group()
@click.option("--database", help="Path to the sqlite database file.") @click.option("--database", help="Path to the sqlite database file.")
@ -28,11 +45,19 @@ load_dotenv()
def cli(ctx: click.Context, database: str) -> None: def cli(ctx: click.Context, database: str) -> None:
"""CLI group.""" """CLI group."""
if database: if database:
# Hopefully it's enough to set the environment variable as so. settings = BackendSettings(database=str(Path(database).absolute()))
settings = BackendSettings(db_url=f"sqlite:///{Path(database).absolute()}")
else: else:
settings = BackendSettings() settings = BackendSettings()
if settings.generate_initial_tokens:
if count_tokens(settings) == 0:
click.echo("Creating initial tokens for admin and sshd.")
admin_token = generate_token(settings)
sshd_token = generate_token(settings)
click.echo(f"Admin token: {admin_token}")
click.echo(f"SSHD token: {sshd_token}")
ctx.obj = settings ctx.obj = settings
@ -41,9 +66,7 @@ def cli(ctx: click.Context, database: str) -> None:
def cli_generate_token(ctx: click.Context) -> None: def cli_generate_token(ctx: click.Context) -> None:
"""Generate a token.""" """Generate a token."""
settings = cast(BackendSettings, ctx.obj) settings = cast(BackendSettings, ctx.obj)
engine = create_engine(settings.db_url) token = generate_token(settings)
with Session(engine) as session:
token = create_api_token(session, True)
click.echo("Generated api token:") click.echo("Generated api token:")
click.echo(token) click.echo(token)
@ -61,7 +84,7 @@ def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
def cli_repl(ctx: click.Context) -> None: def cli_repl(ctx: click.Context) -> None:
"""Run an interactive console.""" """Run an interactive console."""
settings = cast(BackendSettings, ctx.obj) settings = cast(BackendSettings, ctx.obj)
engine = create_engine(settings.db_url) engine = get_engine(settings.db_url, True)
with Session(engine) as session: with Session(engine) as session:
locals = { locals = {

View File

@ -34,9 +34,8 @@ def setup_database(
return engine, get_db_session return engine, get_db_session
def get_engine(filename: Path, echo: bool = False) -> Engine: def get_engine(url: URL, echo: bool = False) -> Engine:
"""Initialize the engine.""" """Initialize the engine."""
url = URL.create(drivername="sqlite", database=str(filename.absolute()))
engine = create_engine(url, echo=echo) engine = create_engine(url, echo=echo)
with engine.connect() as connection: with engine.connect() as connection:
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only

View File

@ -11,7 +11,7 @@ import logging
import uuid import uuid
from datetime import datetime from datetime import datetime
import sqlalchemy as sa import sqlalchemy as sa
from sqlmodel import Field, Relationship, SQLModel from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -98,14 +98,13 @@ class AuditLog(SQLModel, table=True):
""" """
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
subsystem: str | None = None subsystem: str
object: str | None = None message: str
object_id: str | None = None
operation: str operation: str
client_id: uuid.UUID | None = None client_id: uuid.UUID | None = None
client_name: str | None = None client_name: str | None = None
message: str
origin: str | None = None origin: str | None = None
Field(default=None, sa_column=Column(JSON))
timestamp: datetime | None = Field( timestamp: datetime | None = Field(
default=None, default=None,
@ -131,5 +130,5 @@ class APIClient(SQLModel, table=True):
def init_db(engine: sa.Engine) -> None: def init_db(engine: sa.Engine) -> None:
"""Create database.""" """Create database."""
LOG.info("Starting init_db") LOG.info("Running init_db")
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)

View File

@ -1,21 +1,49 @@
"""Settings management.""" """Settings management."""
from pydantic import Field from pathlib import Path
from typing import Annotated, Any
from pydantic import Field, field_validator
from pydantic_settings import ( from pydantic_settings import (
BaseSettings, BaseSettings,
SettingsConfigDict, SettingsConfigDict,
ForceDecode,
) )
from sqlalchemy import URL
DEFAULT_DATABASE = "sqlite:///sshecret.db" DEFAULT_DATABASE = "sshecret.db"
class BackendSettings(BaseSettings): class BackendSettings(BaseSettings):
"""Backend settings.""" """Backend settings."""
model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_") model_config = SettingsConfigDict(
env_file=".backend.env", env_prefix="sshecret_backend_"
)
db_url: str = Field(default=DEFAULT_DATABASE) database: str = Field(default=DEFAULT_DATABASE)
generate_initial_tokens: Annotated[bool, ForceDecode] = Field(default=False)
@field_validator("generate_initial_tokens", mode="before")
@classmethod
def cast_bool(cls, value: Any) -> bool:
"""Ensure we catch the boolean."""
if isinstance(value, str):
if value.lower() in ("1", "true", "on"):
return True
if value.lower() in ("0", "false", "off"):
return False
return bool(value)
@property
def db_url(self) -> URL:
"""Construct database url."""
return URL.create(drivername="sqlite", database=self.database)
@property
def db_exists(self) -> bool:
"""Check if databatase exists."""
return Path(self.database).exists()
def get_settings() -> BackendSettings: def get_settings() -> BackendSettings: