Compare commits

...

3 Commits

Author SHA1 Message Date
f10ae027e5 Add alembic migrations 2025-05-18 22:20:01 +02:00
b8cae28888 Change name of default database 2025-05-18 22:19:49 +02:00
a0adf281b5 Migrate from sqlmodel to pure sqlalchemy 2025-05-18 22:13:07 +02:00
18 changed files with 350 additions and 53 deletions

View File

@ -0,0 +1,119 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on windows to provide an os agnostic path
script_location = migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = sqlite:///sshecret_admin.db
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

View File

@ -0,0 +1 @@
Generic single-database configuration.

View File

@ -0,0 +1,86 @@
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from sshecret_admin.auth.models import Base
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
def get_database_url() -> str | None:
"""Get database URL."""
if db_file := os.getenv("SSHECRET_ADMIN_DATABASE"):
return f"sqlite:///{db_file}"
return config.get_main_option("sqlalchemy.url")
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = get_database_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,47 @@
"""Create initial migration
Revision ID: 2a5a599271aa
Revises:
Create Date: 2025-05-18 22:19:03.739902
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '2a5a599271aa'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('password_db',
sa.Column('id', sa.INTEGER(), nullable=False),
sa.Column('encrypted_password', sa.String(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('user',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('username', sa.String(), nullable=False),
sa.Column('hashed_password', sa.String(), nullable=False),
sa.Column('disabled', sa.BOOLEAN(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('user')
op.drop_table('password_db')
# ### end Alembic commands ###

View File

@ -5,7 +5,7 @@ import logging
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session
from sqlalchemy.orm import Session
from sshecret_admin.auth import Token, authenticate_user, create_access_token
from sshecret_admin.core.dependencies import AdminDependencies

View File

@ -8,7 +8,8 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from sshecret_admin.services.admin_backend import AdminBackend
from sshecret_admin.core.dependencies import BaseDependencies, AdminDependencies
@ -40,7 +41,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
if not token_data:
raise credentials_exception
user = session.exec(
user = session.scalars(
select(User).where(User.username == token_data.username)
).first()
if not user:
@ -59,7 +60,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
session: Annotated[Session, Depends(dependencies.get_db_session)]
):
"""Get admin backend API."""
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
if not password_db:
raise HTTPException(
500, detail="Error: The password manager has not yet been set up."

View File

@ -6,9 +6,11 @@ from typing import cast, Any
import bcrypt
import jwt
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from sshecret_admin.core.settings import AdminServerSettings
from .models import User, TokenData
from .exceptions import AuthenticationFailedError
@ -72,7 +74,7 @@ def check_password(plain_password: str, hashed_password: str) -> None:
def authenticate_user(session: Session, username: str, password: str) -> User | None:
"""Authenticate user."""
user = session.exec(select(User).where(User.username == username)).first()
user = session.scalars(select(User).where(User.username == username)).first()
if not user:
return None
if not verify_password(password, user.hashed_password):

View File

@ -1,8 +1,11 @@
"""Models for authentication."""
from datetime import datetime
import uuid
import sqlalchemy as sa
from sqlmodel import SQLModel, Field
from pydantic import BaseModel
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
JWT_ALGORITHM = "HS256"
@ -12,59 +15,65 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_HOURS = 6
class User(SQLModel, table=True):
class Base(DeclarativeBase):
pass
class User(Base):
"""Users."""
username: str = Field(unique=True, primary_key=True)
hashed_password: str
disabled: bool = Field(default=False)
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
__tablename__: str = "user"
id: Mapped[uuid.UUID] = mapped_column(
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
)
username: Mapped[str] = mapped_column(sa.String)
hashed_password: Mapped[str] = mapped_column(sa.String)
disabled: Mapped[bool] = mapped_column(sa.BOOLEAN, default=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
class PasswordDB(SQLModel, table=True):
class PasswordDB(Base):
"""Password database."""
id: int | None = Field(default=None, primary_key=True)
encrypted_password: str
__tablename__: str = "password_db"
created_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()},
nullable=False,
id: Mapped[int] = mapped_column(sa.INT, primary_key=True)
encrypted_password: Mapped[str] = mapped_column(sa.String)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
)
updated_at: datetime | None = Field(
default=None,
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
updated_at: Mapped[datetime | None] = mapped_column(
sa.DateTime(timezone=True),
server_default=sa.func.now(),
onupdate=sa.func.now(),
)
class TokenData(SQLModel):
class TokenData(BaseModel):
"""Token data."""
username: str | None = None
class Token(SQLModel):
class Token(BaseModel):
access_token: str
token_type: str
class LoginError(SQLModel):
class LoginError(BaseModel):
"""Login Error model."""
# TODO: Remove this.
title: str
message: str
def init_db(engine: sa.Engine) -> None:
"""Create database."""
SQLModel.metadata.create_all(engine)
Base.metadata.create_all(engine)

View File

@ -12,7 +12,8 @@ from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from sshecret_admin import api, frontend
from sshecret_admin.auth.models import PasswordDB, init_db
from sshecret_admin.core.db import setup_database
@ -50,7 +51,7 @@ def create_admin_app(
settings=settings, regenerate=False
)
with Session(engine) as session:
existing_password = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
existing_password = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
if not encr_master_password:
if existing_password:

View File

@ -9,7 +9,8 @@ from typing import Any, cast
import click
import uvicorn
from pydantic import ValidationError
from sqlmodel import Session, create_engine, select
from sqlalchemy import select, create_engine
from sqlalchemy.orm import Session
from sshecret_admin.auth.authentication import hash_password
from sshecret_admin.auth.models import PasswordDB, User, init_db
from sshecret_admin.core.settings import AdminServerSettings
@ -80,7 +81,7 @@ def cli_change_user_passwd(ctx: click.Context, username: str, password: str) ->
engine = create_engine(settings.admin_db)
init_db(engine)
with Session(engine) as session:
user = session.exec(select(User).where(User.username == username)).first()
user = session.scalars(select(User).where(User.username == username)).first()
if not user:
raise click.ClickException(f"Error: No such user, {username}.")
new_passwd_hash = hash_password(password)
@ -100,7 +101,7 @@ def cli_delete_user(ctx: click.Context, username: str) -> None:
engine = create_engine(settings.admin_db)
init_db(engine)
with Session(engine) as session:
user = session.exec(select(User).where(User.username == username)).first()
user = session.scalars(select(User).where(User.username == username)).first()
if not user:
raise click.ClickException(f"Error: No such user, {username}.")
@ -142,7 +143,7 @@ def cli_repl(ctx: click.Context) -> None:
engine = create_engine(settings.admin_db)
init_db(engine)
with Session(engine) as session:
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
if not password_db:
raise click.ClickException(

View File

@ -2,17 +2,17 @@
from collections.abc import Generator, Callable
from sqlmodel import Session, create_engine
import sqlalchemy as sa
from sqlalchemy.orm import Session
from sqlalchemy.engine import URL
from sqlalchemy import create_engine, Engine
def setup_database(
db_url: URL | str,
) -> tuple[sa.Engine, Callable[[], Generator[Session, None, None]]]:
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
"""Setup database."""
engine = create_engine(db_url, echo=False)
engine = create_engine(db_url, echo=False, future=True)
def get_db_session() -> Generator[Session, None, None]:
"""Get DB Session."""

View File

@ -1,10 +1,10 @@
"""Common type definitions."""
from collections.abc import AsyncGenerator, Callable, Generator
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from dataclasses import dataclass
from typing import Awaitable, Self
from typing import Self
from sqlmodel import Session
from sqlalchemy.orm import Session
from sshecret_admin.auth import User
from sshecret_admin.services import AdminBackend
from sshecret_admin.core.settings import AdminServerSettings

View File

@ -8,7 +8,7 @@ from sqlalchemy import URL
DEFAULT_LISTEN_PORT = 8822
DEFAULT_DATABASE = "ssh_admin.db"
DEFAULT_DATABASE = "sshecret_admin.db"
class AdminServerSettings(BaseSettings):

View File

@ -4,9 +4,9 @@ from dataclasses import dataclass
from collections.abc import Callable, Awaitable
from typing import Self
from sqlalchemy.orm import Session
from jinja2_fragments.fastapi import Jinja2Blocks
from fastapi import Request
from sqlmodel import Session
from sshecret_admin.core.dependencies import AdminDep, BaseDependencies

View File

@ -11,7 +11,8 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from jinja2_fragments.fastapi import Jinja2Blocks
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from starlette.datastructures import URL
@ -46,7 +47,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
session: Annotated[Session, Depends(dependencies.get_db_session)]
):
"""Get admin backend API."""
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
if not password_db:
raise HTTPException(
500, detail="Error: The password manager has not yet been set up."
@ -62,7 +63,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
token_data = decode_token(dependencies.settings, token)
if not token_data:
return None
user = session.exec(
user = session.scalars(
select(User).where(User.username == token_data.username)
).first()
if not user or user.disabled:

View File

@ -7,7 +7,8 @@ from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request, Response, status
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session
from sqlalchemy.orm import Session
from sshecret_admin.services import AdminBackend
from starlette.datastructures import URL

View File

@ -2,8 +2,8 @@
from collections.abc import AsyncGenerator, Callable, Generator, Awaitable
from sqlalchemy.orm import Session
from fastapi import Request
from sqlmodel import Session
from sshecret_admin.admin_backend import AdminBackend
from sshecret_admin.auth_models import User