Migrate from sqlmodel to pure sqlalchemy

This commit is contained in:
2025-05-18 22:13:07 +02:00
parent 061a52c90a
commit a0adf281b5
12 changed files with 68 additions and 52 deletions

View File

@ -5,7 +5,7 @@ import logging
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session from sqlalchemy.orm import Session
from sshecret_admin.auth import Token, authenticate_user, create_access_token from sshecret_admin.auth import Token, authenticate_user, create_access_token
from sshecret_admin.core.dependencies import AdminDependencies from sshecret_admin.core.dependencies import AdminDependencies

View File

@ -8,7 +8,8 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlmodel import Session, select from sqlalchemy import select
from sqlalchemy.orm import Session
from sshecret_admin.services.admin_backend import AdminBackend from sshecret_admin.services.admin_backend import AdminBackend
from sshecret_admin.core.dependencies import BaseDependencies, AdminDependencies from sshecret_admin.core.dependencies import BaseDependencies, AdminDependencies
@ -40,7 +41,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
if not token_data: if not token_data:
raise credentials_exception raise credentials_exception
user = session.exec( user = session.scalars(
select(User).where(User.username == token_data.username) select(User).where(User.username == token_data.username)
).first() ).first()
if not user: if not user:
@ -59,7 +60,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
session: Annotated[Session, Depends(dependencies.get_db_session)] session: Annotated[Session, Depends(dependencies.get_db_session)]
): ):
"""Get admin backend API.""" """Get admin backend API."""
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first() password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
if not password_db: if not password_db:
raise HTTPException( raise HTTPException(
500, detail="Error: The password manager has not yet been set up." 500, detail="Error: The password manager has not yet been set up."

View File

@ -6,9 +6,11 @@ from typing import cast, Any
import bcrypt import bcrypt
import jwt import jwt
from sqlmodel import Session, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from sshecret_admin.core.settings import AdminServerSettings from sshecret_admin.core.settings import AdminServerSettings
from .models import User, TokenData from .models import User, TokenData
from .exceptions import AuthenticationFailedError from .exceptions import AuthenticationFailedError
@ -72,7 +74,7 @@ def check_password(plain_password: str, hashed_password: str) -> None:
def authenticate_user(session: Session, username: str, password: str) -> User | None: def authenticate_user(session: Session, username: str, password: str) -> User | None:
"""Authenticate user.""" """Authenticate user."""
user = session.exec(select(User).where(User.username == username)).first() user = session.scalars(select(User).where(User.username == username)).first()
if not user: if not user:
return None return None
if not verify_password(password, user.hashed_password): if not verify_password(password, user.hashed_password):

View File

@ -1,8 +1,11 @@
"""Models for authentication.""" """Models for authentication."""
from datetime import datetime from datetime import datetime
import uuid
import sqlalchemy as sa import sqlalchemy as sa
from sqlmodel import SQLModel, Field from pydantic import BaseModel
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
JWT_ALGORITHM = "HS256" JWT_ALGORITHM = "HS256"
@ -12,59 +15,65 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_HOURS = 6 REFRESH_TOKEN_EXPIRE_HOURS = 6
class User(SQLModel, table=True): class Base(DeclarativeBase):
pass
class User(Base):
"""Users.""" """Users."""
username: str = Field(unique=True, primary_key=True) __tablename__: str = "user"
hashed_password: str
disabled: bool = Field(default=False) id: Mapped[uuid.UUID] = mapped_column(
created_at: datetime | None = Field( sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
default=None, )
sa_type=sa.DateTime(timezone=True), username: Mapped[str] = mapped_column(sa.String)
sa_column_kwargs={"server_default": sa.func.now()}, hashed_password: Mapped[str] = mapped_column(sa.String)
nullable=False, disabled: Mapped[bool] = mapped_column(sa.BOOLEAN, default=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
) )
class PasswordDB(SQLModel, table=True): class PasswordDB(Base):
"""Password database.""" """Password database."""
id: int | None = Field(default=None, primary_key=True) __tablename__: str = "password_db"
encrypted_password: str
created_at: datetime | None = Field( id: Mapped[int] = mapped_column(sa.INT, primary_key=True)
default=None, encrypted_password: Mapped[str] = mapped_column(sa.String)
sa_type=sa.DateTime(timezone=True),
sa_column_kwargs={"server_default": sa.func.now()}, created_at: Mapped[datetime] = mapped_column(
nullable=False, sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
) )
updated_at: datetime | None = Field( updated_at: Mapped[datetime | None] = mapped_column(
default=None, sa.DateTime(timezone=True),
sa_type=sa.DateTime(timezone=True), server_default=sa.func.now(),
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()}, onupdate=sa.func.now(),
) )
class TokenData(BaseModel):
class TokenData(SQLModel):
"""Token data.""" """Token data."""
username: str | None = None username: str | None = None
class Token(SQLModel): class Token(BaseModel):
access_token: str access_token: str
token_type: str token_type: str
class LoginError(SQLModel): class LoginError(BaseModel):
"""Login Error model.""" """Login Error model."""
# TODO: Remove this. # TODO: Remove this.
title: str title: str
message: str message: str
def init_db(engine: sa.Engine) -> None: def init_db(engine: sa.Engine) -> None:
"""Create database.""" """Create database."""
SQLModel.metadata.create_all(engine) Base.metadata.create_all(engine)

View File

@ -12,7 +12,8 @@ from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, RedirectResponse from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from sqlmodel import Session, select from sqlalchemy import select
from sqlalchemy.orm import Session
from sshecret_admin import api, frontend from sshecret_admin import api, frontend
from sshecret_admin.auth.models import PasswordDB, init_db from sshecret_admin.auth.models import PasswordDB, init_db
from sshecret_admin.core.db import setup_database from sshecret_admin.core.db import setup_database
@ -50,7 +51,7 @@ def create_admin_app(
settings=settings, regenerate=False settings=settings, regenerate=False
) )
with Session(engine) as session: with Session(engine) as session:
existing_password = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first() existing_password = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
if not encr_master_password: if not encr_master_password:
if existing_password: if existing_password:

View File

@ -9,7 +9,8 @@ from typing import Any, cast
import click import click
import uvicorn import uvicorn
from pydantic import ValidationError from pydantic import ValidationError
from sqlmodel import Session, create_engine, select from sqlalchemy import select, create_engine
from sqlalchemy.orm import Session
from sshecret_admin.auth.authentication import hash_password from sshecret_admin.auth.authentication import hash_password
from sshecret_admin.auth.models import PasswordDB, User, init_db from sshecret_admin.auth.models import PasswordDB, User, init_db
from sshecret_admin.core.settings import AdminServerSettings from sshecret_admin.core.settings import AdminServerSettings
@ -80,7 +81,7 @@ def cli_change_user_passwd(ctx: click.Context, username: str, password: str) ->
engine = create_engine(settings.admin_db) engine = create_engine(settings.admin_db)
init_db(engine) init_db(engine)
with Session(engine) as session: with Session(engine) as session:
user = session.exec(select(User).where(User.username == username)).first() user = session.scalars(select(User).where(User.username == username)).first()
if not user: if not user:
raise click.ClickException(f"Error: No such user, {username}.") raise click.ClickException(f"Error: No such user, {username}.")
new_passwd_hash = hash_password(password) new_passwd_hash = hash_password(password)
@ -100,7 +101,7 @@ def cli_delete_user(ctx: click.Context, username: str) -> None:
engine = create_engine(settings.admin_db) engine = create_engine(settings.admin_db)
init_db(engine) init_db(engine)
with Session(engine) as session: with Session(engine) as session:
user = session.exec(select(User).where(User.username == username)).first() user = session.scalars(select(User).where(User.username == username)).first()
if not user: if not user:
raise click.ClickException(f"Error: No such user, {username}.") raise click.ClickException(f"Error: No such user, {username}.")
@ -142,7 +143,7 @@ def cli_repl(ctx: click.Context) -> None:
engine = create_engine(settings.admin_db) engine = create_engine(settings.admin_db)
init_db(engine) init_db(engine)
with Session(engine) as session: with Session(engine) as session:
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first() password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
if not password_db: if not password_db:
raise click.ClickException( raise click.ClickException(

View File

@ -2,17 +2,17 @@
from collections.abc import Generator, Callable from collections.abc import Generator, Callable
from sqlmodel import Session, create_engine from sqlalchemy.orm import Session
import sqlalchemy as sa
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
from sqlalchemy import create_engine, Engine
def setup_database( def setup_database(
db_url: URL | str, db_url: URL | str,
) -> tuple[sa.Engine, Callable[[], Generator[Session, None, None]]]: ) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
"""Setup database.""" """Setup database."""
engine = create_engine(db_url, echo=False) engine = create_engine(db_url, echo=False, future=True)
def get_db_session() -> Generator[Session, None, None]: def get_db_session() -> Generator[Session, None, None]:
"""Get DB Session.""" """Get DB Session."""

View File

@ -1,10 +1,10 @@
"""Common type definitions.""" """Common type definitions."""
from collections.abc import AsyncGenerator, Callable, Generator from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from dataclasses import dataclass from dataclasses import dataclass
from typing import Awaitable, Self from typing import Self
from sqlmodel import Session from sqlalchemy.orm import Session
from sshecret_admin.auth import User from sshecret_admin.auth import User
from sshecret_admin.services import AdminBackend from sshecret_admin.services import AdminBackend
from sshecret_admin.core.settings import AdminServerSettings from sshecret_admin.core.settings import AdminServerSettings

View File

@ -4,9 +4,9 @@ from dataclasses import dataclass
from collections.abc import Callable, Awaitable from collections.abc import Callable, Awaitable
from typing import Self from typing import Self
from sqlalchemy.orm import Session
from jinja2_fragments.fastapi import Jinja2Blocks from jinja2_fragments.fastapi import Jinja2Blocks
from fastapi import Request from fastapi import Request
from sqlmodel import Session
from sshecret_admin.core.dependencies import AdminDep, BaseDependencies from sshecret_admin.core.dependencies import AdminDep, BaseDependencies

View File

@ -11,7 +11,8 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from jinja2_fragments.fastapi import Jinja2Blocks from jinja2_fragments.fastapi import Jinja2Blocks
from sqlmodel import Session, select from sqlalchemy import select
from sqlalchemy.orm import Session
from starlette.datastructures import URL from starlette.datastructures import URL
@ -46,7 +47,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
session: Annotated[Session, Depends(dependencies.get_db_session)] session: Annotated[Session, Depends(dependencies.get_db_session)]
): ):
"""Get admin backend API.""" """Get admin backend API."""
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first() password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
if not password_db: if not password_db:
raise HTTPException( raise HTTPException(
500, detail="Error: The password manager has not yet been set up." 500, detail="Error: The password manager has not yet been set up."
@ -62,7 +63,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
token_data = decode_token(dependencies.settings, token) token_data = decode_token(dependencies.settings, token)
if not token_data: if not token_data:
return None return None
user = session.exec( user = session.scalars(
select(User).where(User.username == token_data.username) select(User).where(User.username == token_data.username)
).first() ).first()
if not user or user.disabled: if not user or user.disabled:

View File

@ -7,7 +7,8 @@ from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request, Response, status from fastapi import APIRouter, Depends, Query, Request, Response, status
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import Session
from sqlalchemy.orm import Session
from sshecret_admin.services import AdminBackend from sshecret_admin.services import AdminBackend
from starlette.datastructures import URL from starlette.datastructures import URL

View File

@ -2,8 +2,8 @@
from collections.abc import AsyncGenerator, Callable, Generator, Awaitable from collections.abc import AsyncGenerator, Callable, Generator, Awaitable
from sqlalchemy.orm import Session
from fastapi import Request from fastapi import Request
from sqlmodel import Session
from sshecret_admin.admin_backend import AdminBackend from sshecret_admin.admin_backend import AdminBackend
from sshecret_admin.auth_models import User from sshecret_admin.auth_models import User