Check in current backend
This commit is contained in:
@ -1,11 +1,20 @@
|
|||||||
|
import os
|
||||||
from logging.config import fileConfig
|
from logging.config import fileConfig
|
||||||
|
|
||||||
from sqlalchemy import engine_from_config
|
from sqlalchemy import engine_from_config
|
||||||
from sqlalchemy import pool
|
from sqlalchemy import pool
|
||||||
|
from sqlmodel import create_engine
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from sshecret_backend.models import *
|
from sshecret_backend.models import *
|
||||||
|
|
||||||
|
def get_database_url() -> str:
|
||||||
|
"""Get database URL."""
|
||||||
|
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
|
||||||
|
return f"sqlite:///{db_file}"
|
||||||
|
return "sqlite:///sshecret.db"
|
||||||
|
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
# access to the values within the .ini file in use.
|
# access to the values within the .ini file in use.
|
||||||
config = context.config
|
config = context.config
|
||||||
@ -40,7 +49,7 @@ def run_migrations_offline() -> None:
|
|||||||
script output.
|
script output.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
url = config.get_main_option("sqlalchemy.url")
|
url = get_database_url()
|
||||||
context.configure(
|
context.configure(
|
||||||
url=url,
|
url=url,
|
||||||
target_metadata=target_metadata,
|
target_metadata=target_metadata,
|
||||||
@ -59,11 +68,7 @@ def run_migrations_online() -> None:
|
|||||||
and associate a connection with the context.
|
and associate a connection with the context.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
connectable = engine_from_config(
|
connectable = create_engine(get_database_url())
|
||||||
config.get_section(config.config_ini_section, {}),
|
|
||||||
prefix="sqlalchemy.",
|
|
||||||
poolclass=pool.NullPool,
|
|
||||||
)
|
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
with connectable.connect() as connection:
|
||||||
context.configure(
|
context.configure(
|
||||||
|
|||||||
@ -1,33 +0,0 @@
|
|||||||
"""Initial model
|
|
||||||
|
|
||||||
Revision ID: a0befb5a74a0
|
|
||||||
Revises:
|
|
||||||
Create Date: 2025-04-28 21:18:59.069323
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
import sqlmodel
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = 'a0befb5a74a0'
|
|
||||||
down_revision: Union[str, None] = None
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
pass
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
pass
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
@ -1,33 +0,0 @@
|
|||||||
"""Add subsystem to auditlog
|
|
||||||
|
|
||||||
Revision ID: f30e413c5757
|
|
||||||
Revises: a0befb5a74a0
|
|
||||||
Create Date: 2025-04-28 21:21:20.103423
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
import sqlmodel
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = 'f30e413c5757'
|
|
||||||
down_revision: Union[str, None] = 'a0befb5a74a0'
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.add_column('auditlog', sa.Column('subsystem', sqlmodel.sql.sqltypes.AutoString(), nullable=True))
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_column('auditlog', 'subsystem')
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
@ -14,8 +14,13 @@ dependencies = [
|
|||||||
"pytest>=8.3.5",
|
"pytest>=8.3.5",
|
||||||
"python-multipart>=0.0.20",
|
"python-multipart>=0.0.20",
|
||||||
"sqlmodel>=0.0.24",
|
"sqlmodel>=0.0.24",
|
||||||
|
"sshecret",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
sshecret = { workspace = true }
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
sshecret-backend = "sshecret_backend.cli:cli"
|
sshecret-backend = "sshecret_backend.cli:cli"
|
||||||
|
|
||||||
@ -26,3 +31,9 @@ build-backend = "hatchling.build"
|
|||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
log_cli = true
|
log_cli = true
|
||||||
log_cli_level = "INFO"
|
log_cli_level = "INFO"
|
||||||
|
|
||||||
|
[tool.pyright]
|
||||||
|
venvPath = "../.."
|
||||||
|
venv = ".venv"
|
||||||
|
strict = ["**/*.py"]
|
||||||
|
pythonVersion = "3.13"
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from typing import Annotated
|
|||||||
|
|
||||||
from sshecret_backend.models import AuditLog
|
from sshecret_backend.models import AuditLog
|
||||||
from sshecret_backend.types import DBSessionDep
|
from sshecret_backend.types import DBSessionDep
|
||||||
from sshecret_backend import audit
|
|
||||||
from sshecret_backend.view_models import AuditInfo
|
from sshecret_backend.view_models import AuditInfo
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
|||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from sshecret_backend.models import Client, ClientAccessPolicy
|
from sshecret_backend.models import ClientAccessPolicy
|
||||||
from sshecret_backend.view_models import (
|
from sshecret_backend.view_models import (
|
||||||
ClientPolicyView,
|
ClientPolicyView,
|
||||||
ClientPolicyUpdate,
|
ClientPolicyUpdate,
|
||||||
|
|||||||
@ -6,12 +6,12 @@ from pathlib import Path
|
|||||||
from typing import cast
|
from typing import cast
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import click
|
import click
|
||||||
from sqlmodel import Session, create_engine, select
|
from sqlmodel import Session, col, func, select
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
from .db import create_api_token
|
from .db import get_engine, create_api_token
|
||||||
|
|
||||||
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient
|
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db
|
||||||
from .settings import BackendSettings
|
from .settings import BackendSettings
|
||||||
|
|
||||||
DEFAULT_LISTEN = "127.0.0.1"
|
DEFAULT_LISTEN = "127.0.0.1"
|
||||||
@ -21,6 +21,23 @@ WORKDIR = Path(os.getcwd())
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
def generate_token(settings: BackendSettings) -> str:
|
||||||
|
"""Generate a token."""
|
||||||
|
engine = get_engine(settings.db_url)
|
||||||
|
init_db(engine)
|
||||||
|
with Session(engine) as session:
|
||||||
|
token = create_api_token(session, True)
|
||||||
|
return token
|
||||||
|
|
||||||
|
def count_tokens(settings: BackendSettings) -> int:
|
||||||
|
"""Count the amount of tokens created."""
|
||||||
|
engine = get_engine(settings.db_url)
|
||||||
|
init_db(engine)
|
||||||
|
with Session(engine) as session:
|
||||||
|
count = session.exec(select(func.count("*")).select_from(APIClient)).one()
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@click.option("--database", help="Path to the sqlite database file.")
|
@click.option("--database", help="Path to the sqlite database file.")
|
||||||
@ -28,11 +45,19 @@ load_dotenv()
|
|||||||
def cli(ctx: click.Context, database: str) -> None:
|
def cli(ctx: click.Context, database: str) -> None:
|
||||||
"""CLI group."""
|
"""CLI group."""
|
||||||
if database:
|
if database:
|
||||||
# Hopefully it's enough to set the environment variable as so.
|
settings = BackendSettings(database=str(Path(database).absolute()))
|
||||||
settings = BackendSettings(db_url=f"sqlite:///{Path(database).absolute()}")
|
|
||||||
else:
|
else:
|
||||||
settings = BackendSettings()
|
settings = BackendSettings()
|
||||||
|
|
||||||
|
|
||||||
|
if settings.generate_initial_tokens:
|
||||||
|
if count_tokens(settings) == 0:
|
||||||
|
click.echo("Creating initial tokens for admin and sshd.")
|
||||||
|
admin_token = generate_token(settings)
|
||||||
|
sshd_token = generate_token(settings)
|
||||||
|
click.echo(f"Admin token: {admin_token}")
|
||||||
|
click.echo(f"SSHD token: {sshd_token}")
|
||||||
|
|
||||||
ctx.obj = settings
|
ctx.obj = settings
|
||||||
|
|
||||||
|
|
||||||
@ -41,9 +66,7 @@ def cli(ctx: click.Context, database: str) -> None:
|
|||||||
def cli_generate_token(ctx: click.Context) -> None:
|
def cli_generate_token(ctx: click.Context) -> None:
|
||||||
"""Generate a token."""
|
"""Generate a token."""
|
||||||
settings = cast(BackendSettings, ctx.obj)
|
settings = cast(BackendSettings, ctx.obj)
|
||||||
engine = create_engine(settings.db_url)
|
token = generate_token(settings)
|
||||||
with Session(engine) as session:
|
|
||||||
token = create_api_token(session, True)
|
|
||||||
click.echo("Generated api token:")
|
click.echo("Generated api token:")
|
||||||
click.echo(token)
|
click.echo(token)
|
||||||
|
|
||||||
@ -61,7 +84,7 @@ def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
|||||||
def cli_repl(ctx: click.Context) -> None:
|
def cli_repl(ctx: click.Context) -> None:
|
||||||
"""Run an interactive console."""
|
"""Run an interactive console."""
|
||||||
settings = cast(BackendSettings, ctx.obj)
|
settings = cast(BackendSettings, ctx.obj)
|
||||||
engine = create_engine(settings.db_url)
|
engine = get_engine(settings.db_url, True)
|
||||||
|
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
locals = {
|
locals = {
|
||||||
|
|||||||
@ -34,9 +34,8 @@ def setup_database(
|
|||||||
return engine, get_db_session
|
return engine, get_db_session
|
||||||
|
|
||||||
|
|
||||||
def get_engine(filename: Path, echo: bool = False) -> Engine:
|
def get_engine(url: URL, echo: bool = False) -> Engine:
|
||||||
"""Initialize the engine."""
|
"""Initialize the engine."""
|
||||||
url = URL.create(drivername="sqlite", database=str(filename.absolute()))
|
|
||||||
engine = create_engine(url, echo=echo)
|
engine = create_engine(url, echo=echo)
|
||||||
with engine.connect() as connection:
|
with engine.connect() as connection:
|
||||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
||||||
|
|||||||
@ -11,7 +11,7 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlmodel import Field, Relationship, SQLModel
|
from sqlmodel import JSON, Column, DateTime, Field, Relationship, SQLModel
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -98,14 +98,13 @@ class AuditLog(SQLModel, table=True):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||||
subsystem: str | None = None
|
subsystem: str
|
||||||
object: str | None = None
|
message: str
|
||||||
object_id: str | None = None
|
|
||||||
operation: str
|
operation: str
|
||||||
client_id: uuid.UUID | None = None
|
client_id: uuid.UUID | None = None
|
||||||
client_name: str | None = None
|
client_name: str | None = None
|
||||||
message: str
|
|
||||||
origin: str | None = None
|
origin: str | None = None
|
||||||
|
Field(default=None, sa_column=Column(JSON))
|
||||||
|
|
||||||
timestamp: datetime | None = Field(
|
timestamp: datetime | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@ -131,5 +130,5 @@ class APIClient(SQLModel, table=True):
|
|||||||
|
|
||||||
def init_db(engine: sa.Engine) -> None:
|
def init_db(engine: sa.Engine) -> None:
|
||||||
"""Create database."""
|
"""Create database."""
|
||||||
LOG.info("Starting init_db")
|
LOG.info("Running init_db")
|
||||||
SQLModel.metadata.create_all(engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
|
|||||||
@ -1,21 +1,49 @@
|
|||||||
"""Settings management."""
|
"""Settings management."""
|
||||||
|
|
||||||
from pydantic import Field
|
from pathlib import Path
|
||||||
|
from typing import Annotated, Any
|
||||||
|
from pydantic import Field, field_validator
|
||||||
from pydantic_settings import (
|
from pydantic_settings import (
|
||||||
BaseSettings,
|
BaseSettings,
|
||||||
SettingsConfigDict,
|
SettingsConfigDict,
|
||||||
|
ForceDecode,
|
||||||
)
|
)
|
||||||
|
from sqlalchemy import URL
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_DATABASE = "sqlite:///sshecret.db"
|
DEFAULT_DATABASE = "sshecret.db"
|
||||||
|
|
||||||
|
|
||||||
class BackendSettings(BaseSettings):
|
class BackendSettings(BaseSettings):
|
||||||
"""Backend settings."""
|
"""Backend settings."""
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".backend.env", env_prefix="sshecret_")
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=".backend.env", env_prefix="sshecret_backend_"
|
||||||
|
)
|
||||||
|
|
||||||
db_url: str = Field(default=DEFAULT_DATABASE)
|
database: str = Field(default=DEFAULT_DATABASE)
|
||||||
|
generate_initial_tokens: Annotated[bool, ForceDecode] = Field(default=False)
|
||||||
|
|
||||||
|
@field_validator("generate_initial_tokens", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def cast_bool(cls, value: Any) -> bool:
|
||||||
|
"""Ensure we catch the boolean."""
|
||||||
|
if isinstance(value, str):
|
||||||
|
if value.lower() in ("1", "true", "on"):
|
||||||
|
return True
|
||||||
|
if value.lower() in ("0", "false", "off"):
|
||||||
|
return False
|
||||||
|
return bool(value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db_url(self) -> URL:
|
||||||
|
"""Construct database url."""
|
||||||
|
return URL.create(drivername="sqlite", database=self.database)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db_exists(self) -> bool:
|
||||||
|
"""Check if databatase exists."""
|
||||||
|
return Path(self.database).exists()
|
||||||
|
|
||||||
|
|
||||||
def get_settings() -> BackendSettings:
|
def get_settings() -> BackendSettings:
|
||||||
|
|||||||
Reference in New Issue
Block a user