Check in current backend

This commit is contained in:
2025-05-04 09:20:11 +02:00
parent 15952c5dd2
commit 3719a2611d
10 changed files with 93 additions and 95 deletions

View File

@ -11,7 +11,6 @@ from typing import Annotated
from sshecret_backend.models import AuditLog
from sshecret_backend.types import DBSessionDep
from sshecret_backend import audit
from sshecret_backend.view_models import AuditInfo

View File

@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from sqlmodel import Session, select
from typing import Annotated
from sshecret_backend.models import Client, ClientAccessPolicy
from sshecret_backend.models import ClientAccessPolicy
from sshecret_backend.view_models import (
ClientPolicyView,
ClientPolicyUpdate,

View File

@ -6,12 +6,12 @@ from pathlib import Path
from typing import cast
from dotenv import load_dotenv
import click
from sqlmodel import Session, create_engine, select
from sqlmodel import Session, col, func, select
import uvicorn
from .db import create_api_token
from .db import get_engine, create_api_token
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db
from .settings import BackendSettings
DEFAULT_LISTEN = "127.0.0.1"
@ -21,6 +21,23 @@ WORKDIR = Path(os.getcwd())
load_dotenv()
def generate_token(settings: BackendSettings) -> str:
"""Generate a token."""
engine = get_engine(settings.db_url)
init_db(engine)
with Session(engine) as session:
token = create_api_token(session, True)
return token
def count_tokens(settings: BackendSettings) -> int:
"""Count the amount of tokens created."""
engine = get_engine(settings.db_url)
init_db(engine)
with Session(engine) as session:
count = session.exec(select(func.count("*")).select_from(APIClient)).one()
return count
@click.group()
@click.option("--database", help="Path to the sqlite database file.")
@ -28,11 +45,19 @@ load_dotenv()
def cli(ctx: click.Context, database: str) -> None:
"""CLI group."""
if database:
# Hopefully it's enough to set the environment variable as so.
settings = BackendSettings(db_url=f"sqlite:///{Path(database).absolute()}")
settings = BackendSettings(database=str(Path(database).absolute()))
else:
settings = BackendSettings()
if settings.generate_initial_tokens:
if count_tokens(settings) == 0:
click.echo("Creating initial tokens for admin and sshd.")
admin_token = generate_token(settings)
sshd_token = generate_token(settings)
click.echo(f"Admin token: {admin_token}")
click.echo(f"SSHD token: {sshd_token}")
ctx.obj = settings
@ -41,9 +66,7 @@ def cli(ctx: click.Context, database: str) -> None:
def cli_generate_token(ctx: click.Context) -> None:
"""Generate a token."""
settings = cast(BackendSettings, ctx.obj)
engine = create_engine(settings.db_url)
with Session(engine) as session:
token = create_api_token(session, True)
token = generate_token(settings)
click.echo("Generated api token:")
click.echo(token)
@ -61,7 +84,7 @@ def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
def cli_repl(ctx: click.Context) -> None:
"""Run an interactive console."""
settings = cast(BackendSettings, ctx.obj)
engine = create_engine(settings.db_url)
engine = get_engine(settings.db_url, True)
with Session(engine) as session:
locals = {

View File

@ -34,9 +34,8 @@ def setup_database(
return engine, get_db_session
def get_engine(filename: Path, echo: bool = False) -> Engine:
def get_engine(url: URL, echo: bool = False) -> Engine:
"""Initialize the engine."""
url = URL.create(drivername="sqlite", database=str(filename.absolute()))
engine = create_engine(url, echo=echo)
with engine.connect() as connection:
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only

View File

@ -11,7 +11,7 @@ import logging
import uuid
from datetime import datetime
import sqlalchemy as sa
from sqlmodel import Field, Relationship, SQLModel
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel
LOG = logging.getLogger(__name__)
@ -98,14 +98,13 @@ class AuditLog(SQLModel, table=True):
"""
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
subsystem: str | None = None
object: str | None = None
object_id: str | None = None
subsystem: str
message: str
operation: str
client_id: uuid.UUID | None = None
client_name: str | None = None
message: str
origin: str | None = None
Field(default=None, sa_column=Column(JSON))
timestamp: datetime | None = Field(
default=None,
@ -131,5 +130,5 @@ class APIClient(SQLModel, table=True):
def init_db(engine: sa.Engine) -> None:
"""Create database."""
LOG.info("Starting init_db")
LOG.info("Running init_db")
SQLModel.metadata.create_all(engine)

View File

@ -1,21 +1,49 @@
"""Settings management."""
from pydantic import Field
from pathlib import Path
from typing import Annotated, Any
from pydantic import Field, field_validator
from pydantic_settings import (
BaseSettings,
SettingsConfigDict,
ForceDecode,
)
from sqlalchemy import URL
DEFAULT_DATABASE = "sqlite:///sshecret.db"
DEFAULT_DATABASE = "sshecret.db"
class BackendSettings(BaseSettings):
"""Backend settings."""
model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_")
model_config = SettingsConfigDict(
env_file=".backend.env", env_prefix="sshecret_backend_"
)
db_url: str = Field(default=DEFAULT_DATABASE)
database: str = Field(default=DEFAULT_DATABASE)
generate_initial_tokens: Annotated[bool, ForceDecode] = Field(default=False)
@field_validator("generate_initial_tokens", mode="before")
@classmethod
def cast_bool(cls, value: Any) -> bool:
"""Ensure we catch the boolean."""
if isinstance(value, str):
if value.lower() in ("1", "true", "on"):
return True
if value.lower() in ("0", "false", "off"):
return False
return bool(value)
@property
def db_url(self) -> URL:
"""Construct database url."""
return URL.create(drivername="sqlite", database=self.database)
@property
def db_exists(self) -> bool:
"""Check if databatase exists."""
return Path(self.database).exists()
def get_settings() -> BackendSettings: