Compare commits
2 Commits
f10ae027e5
...
5865cc450f
| Author | SHA1 | Date | |
|---|---|---|---|
| 5865cc450f | |||
| fc0c3fb950 |
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from .authentication import (
|
from .authentication import (
|
||||||
authenticate_user,
|
authenticate_user,
|
||||||
|
authenticate_user_async,
|
||||||
create_access_token,
|
create_access_token,
|
||||||
create_refresh_token,
|
create_refresh_token,
|
||||||
check_password,
|
check_password,
|
||||||
@ -16,6 +17,7 @@ __all__ = [
|
|||||||
"Token",
|
"Token",
|
||||||
"User",
|
"User",
|
||||||
"authenticate_user",
|
"authenticate_user",
|
||||||
|
"authenticate_user_async",
|
||||||
"check_password",
|
"check_password",
|
||||||
"create_access_token",
|
"create_access_token",
|
||||||
"create_refresh_token",
|
"create_refresh_token",
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import bcrypt
|
|||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sshecret_admin.core.settings import AdminServerSettings
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
|
|
||||||
@ -72,6 +73,16 @@ def check_password(plain_password: str, hashed_password: str) -> None:
|
|||||||
raise AuthenticationFailedError()
|
raise AuthenticationFailedError()
|
||||||
|
|
||||||
|
|
||||||
|
async def authenticate_user_async(session: AsyncSession, username: str, password: str) -> User | None:
|
||||||
|
"""Authenticate user async."""
|
||||||
|
user = (await session.scalars(select(User).where(User.username == username))).first()
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
if not verify_password(password, user.hashed_password):
|
||||||
|
return None
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(session: Session, username: str, password: str) -> User | None:
|
def authenticate_user(session: Session, username: str, password: str) -> User | None:
|
||||||
"""Authenticate user."""
|
"""Authenticate user."""
|
||||||
user = session.scalars(select(User).where(User.username == username)).first()
|
user = session.scalars(select(User).where(User.username == username)).first()
|
||||||
|
|||||||
@ -1,11 +1,15 @@
|
|||||||
"""Database setup."""
|
"""Database setup."""
|
||||||
|
|
||||||
from collections.abc import Generator, Callable
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from collections.abc import AsyncIterator, Generator, Callable
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy.engine import URL
|
from sqlalchemy.engine import URL
|
||||||
from sqlalchemy import create_engine, Engine
|
from sqlalchemy import create_engine, Engine
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine, AsyncEngine, AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
|
||||||
def setup_database(
|
def setup_database(
|
||||||
db_url: URL | str,
|
db_url: URL | str,
|
||||||
@ -20,3 +24,43 @@ def setup_database(
|
|||||||
yield session
|
yield session
|
||||||
|
|
||||||
return engine, get_db_session
|
return engine, get_db_session
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseSessionManager:
|
||||||
|
def __init__(self, host: URL | str, **engine_kwargs: str):
|
||||||
|
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
|
||||||
|
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
if self._engine is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
await self._engine.dispose()
|
||||||
|
|
||||||
|
self._engine = None
|
||||||
|
self._sessionmaker = None
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def connect(self) -> AsyncIterator[AsyncConnection]:
|
||||||
|
if self._engine is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
|
||||||
|
async with self._engine.begin() as connection:
|
||||||
|
try:
|
||||||
|
yield connection
|
||||||
|
except Exception:
|
||||||
|
await connection.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def session(self) -> AsyncIterator[AsyncSession]:
|
||||||
|
if self._sessionmaker is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
|
||||||
|
session = self._sessionmaker()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sshecret_admin.auth import User
|
from sshecret_admin.auth import User
|
||||||
from sshecret_admin.services import AdminBackend
|
from sshecret_admin.services import AdminBackend
|
||||||
|
|||||||
@ -31,3 +31,8 @@ class AdminServerSettings(BaseSettings):
|
|||||||
def admin_db(self) -> URL:
|
def admin_db(self) -> URL:
|
||||||
"""Construct database url."""
|
"""Construct database url."""
|
||||||
return URL.create(drivername="sqlite", database=self.database)
|
return URL.create(drivername="sqlite", database=self.database)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def async_db_url(self) -> URL:
|
||||||
|
"""Construct database url with sync handling."""
|
||||||
|
return URL.create(drivername="sqlite+aiosqlite", database=self.database)
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
"""Frontend dependencies."""
|
"""Frontend dependencies."""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from collections.abc import Callable, Awaitable
|
from collections.abc import AsyncGenerator, Callable, Awaitable
|
||||||
from typing import Self
|
from typing import Self
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
||||||
@ -14,6 +15,7 @@ from sshecret_admin.auth.models import User
|
|||||||
|
|
||||||
UserTokenDep = Callable[[Request, Session], Awaitable[User]]
|
UserTokenDep = Callable[[Request, Session], Awaitable[User]]
|
||||||
UserLoginDep = Callable[[Request, Session], Awaitable[bool]]
|
UserLoginDep = Callable[[Request, Session], Awaitable[bool]]
|
||||||
|
AsyncSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -25,6 +27,7 @@ class FrontendDependencies(BaseDependencies):
|
|||||||
get_user_from_access_token: UserTokenDep
|
get_user_from_access_token: UserTokenDep
|
||||||
get_user_from_refresh_token: UserTokenDep
|
get_user_from_refresh_token: UserTokenDep
|
||||||
get_login_status: UserLoginDep
|
get_login_status: UserLoginDep
|
||||||
|
get_async_session: AsyncSessionDep
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
@ -35,6 +38,7 @@ class FrontendDependencies(BaseDependencies):
|
|||||||
get_user_from_access_token: UserTokenDep,
|
get_user_from_access_token: UserTokenDep,
|
||||||
get_user_from_refresh_token: UserTokenDep,
|
get_user_from_refresh_token: UserTokenDep,
|
||||||
get_login_status: UserLoginDep,
|
get_login_status: UserLoginDep,
|
||||||
|
get_async_session: AsyncSessionDep
|
||||||
) -> Self:
|
) -> Self:
|
||||||
"""Create from base dependencies."""
|
"""Create from base dependencies."""
|
||||||
return cls(
|
return cls(
|
||||||
@ -45,4 +49,5 @@ class FrontendDependencies(BaseDependencies):
|
|||||||
get_user_from_access_token=get_user_from_access_token,
|
get_user_from_access_token=get_user_from_access_token,
|
||||||
get_user_from_refresh_token=get_user_from_refresh_token,
|
get_user_from_refresh_token=get_user_from_refresh_token,
|
||||||
get_login_status=get_login_status,
|
get_login_status=get_login_status,
|
||||||
|
get_async_session=get_async_session,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from starlette.datastructures import URL
|
|||||||
from sshecret_admin.auth import PasswordDB, User, decode_token
|
from sshecret_admin.auth import PasswordDB, User, decode_token
|
||||||
from sshecret_admin.core.dependencies import BaseDependencies
|
from sshecret_admin.core.dependencies import BaseDependencies
|
||||||
from sshecret_admin.services.admin_backend import AdminBackend
|
from sshecret_admin.services.admin_backend import AdminBackend
|
||||||
|
from sshecret_admin.core.db import DatabaseSessionManager
|
||||||
|
|
||||||
from .dependencies import FrontendDependencies
|
from .dependencies import FrontendDependencies
|
||||||
from .exceptions import RedirectException
|
from .exceptions import RedirectException
|
||||||
@ -47,7 +48,9 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
session: Annotated[Session, Depends(dependencies.get_db_session)]
|
session: Annotated[Session, Depends(dependencies.get_db_session)]
|
||||||
):
|
):
|
||||||
"""Get admin backend API."""
|
"""Get admin backend API."""
|
||||||
password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
password_db = session.scalars(
|
||||||
|
select(PasswordDB).where(PasswordDB.id == 1)
|
||||||
|
).first()
|
||||||
if not password_db:
|
if not password_db:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
500, detail="Error: The password manager has not yet been set up."
|
500, detail="Error: The password manager has not yet been set up."
|
||||||
@ -116,6 +119,12 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def get_async_session():
|
||||||
|
"""Get async session."""
|
||||||
|
sessionmanager = DatabaseSessionManager(dependencies.settings.async_db_url)
|
||||||
|
async with sessionmanager.session() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
view_dependencies = FrontendDependencies.create(
|
view_dependencies = FrontendDependencies.create(
|
||||||
dependencies,
|
dependencies,
|
||||||
get_admin_backend,
|
get_admin_backend,
|
||||||
@ -123,6 +132,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
|||||||
get_user_from_access_token,
|
get_user_from_access_token,
|
||||||
get_user_from_refresh_token,
|
get_user_from_refresh_token,
|
||||||
get_login_status,
|
get_login_status,
|
||||||
|
get_async_session,
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(audit.create_router(view_dependencies))
|
app.include_router(audit.create_router(view_dependencies))
|
||||||
|
|||||||
@ -8,13 +8,13 @@ from fastapi import APIRouter, Depends, Query, Request, Response, status
|
|||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sshecret_admin.services import AdminBackend
|
from sshecret_admin.services import AdminBackend
|
||||||
from starlette.datastructures import URL
|
from starlette.datastructures import URL
|
||||||
|
|
||||||
from sshecret_admin.auth import (
|
from sshecret_admin.auth import (
|
||||||
User,
|
User,
|
||||||
authenticate_user,
|
authenticate_user_async,
|
||||||
create_access_token,
|
create_access_token,
|
||||||
create_refresh_token,
|
create_refresh_token,
|
||||||
)
|
)
|
||||||
@ -80,7 +80,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
|||||||
async def login_user(
|
async def login_user(
|
||||||
request: Request,
|
request: Request,
|
||||||
response: Response,
|
response: Response,
|
||||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
|
||||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
next: Annotated[str, Query()] = "/dashboard",
|
next: Annotated[str, Query()] = "/dashboard",
|
||||||
@ -100,7 +100,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
user = authenticate_user(session, form_data.username, form_data.password)
|
user = await authenticate_user_async(session, form_data.username, form_data.password)
|
||||||
login_failed = RedirectException(
|
login_failed = RedirectException(
|
||||||
to=URL("/login").include_query_params(
|
to=URL("/login").include_query_params(
|
||||||
error_title="Login Error", error_message="Invalid username or password"
|
error_title="Login Error", error_message="Invalid username or password"
|
||||||
|
|||||||
@ -3,18 +3,18 @@
|
|||||||
# pyright: reportUnusedFunction=false
|
# pyright: reportUnusedFunction=false
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Sequence
|
from typing import Any
|
||||||
from typing import Any, cast
|
from fastapi import APIRouter, Depends, Request
|
||||||
from fastapi import APIRouter, Depends, Request, Query
|
|
||||||
from pydantic import BaseModel, Field, TypeAdapter
|
from pydantic import BaseModel, Field, TypeAdapter
|
||||||
from sqlalchemy import select, func, and_
|
from sqlalchemy import select, func, and_
|
||||||
from sqlalchemy.orm import InstrumentedAttribute, Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import InstrumentedAttribute
|
||||||
from sqlalchemy.sql.expression import ColumnExpressionArgument
|
from sqlalchemy.sql.expression import ColumnExpressionArgument
|
||||||
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from sshecret_backend.models import AuditLog, Operation, SubSystem
|
from sshecret_backend.models import AuditLog, Operation, SubSystem
|
||||||
from sshecret_backend.types import DBSessionDep
|
from sshecret_backend.types import AsyncDBSessionDep
|
||||||
from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult
|
from sshecret_backend.view_models import AuditInfo, AuditView, AuditListResult
|
||||||
|
|
||||||
|
|
||||||
@ -58,24 +58,23 @@ class AuditFilter(BaseModel):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
def get_audit_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||||
"""Construct audit sub-api."""
|
"""Construct audit sub-api."""
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.get("/audit/", response_model=AuditListResult)
|
@router.get("/audit/", response_model=AuditListResult)
|
||||||
async def get_audit_logs(
|
async def get_audit_logs(
|
||||||
request: Request,
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
|
||||||
filters: Annotated[AuditFilter, Depends()],
|
filters: Annotated[AuditFilter, Depends()],
|
||||||
) -> AuditListResult:
|
) -> AuditListResult:
|
||||||
"""Get audit logs."""
|
"""Get audit logs."""
|
||||||
# audit.audit_access_audit_log(session, request)
|
# audit.audit_access_audit_log(session, request)
|
||||||
|
|
||||||
total = session.scalars(
|
total = (await session.scalars(
|
||||||
select(func.count("*"))
|
select(func.count("*"))
|
||||||
.select_from(AuditLog)
|
.select_from(AuditLog)
|
||||||
.where(and_(True, *filters.filter_mapping))
|
.where(and_(True, *filters.filter_mapping))
|
||||||
).one()
|
)).one()
|
||||||
|
|
||||||
remaining = total - filters.offset
|
remaining = total - filters.offset
|
||||||
statement = (
|
statement = (
|
||||||
@ -87,7 +86,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
LogAdapt = TypeAdapter(list[AuditView])
|
LogAdapt = TypeAdapter(list[AuditView])
|
||||||
results = session.scalars(statement).all()
|
results = (await session.scalars(statement)).all()
|
||||||
entries = LogAdapt.validate_python(results, from_attributes=True)
|
entries = LogAdapt.validate_python(results, from_attributes=True)
|
||||||
return AuditListResult(
|
return AuditListResult(
|
||||||
results=entries,
|
results=entries,
|
||||||
@ -97,24 +96,23 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
|
|
||||||
@router.post("/audit/")
|
@router.post("/audit/")
|
||||||
async def add_audit_log(
|
async def add_audit_log(
|
||||||
request: Request,
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
|
||||||
entry: AuditView,
|
entry: AuditView,
|
||||||
) -> AuditView:
|
) -> AuditView:
|
||||||
"""Add entry to audit log."""
|
"""Add entry to audit log."""
|
||||||
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
|
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
|
||||||
session.add(audit_log)
|
session.add(audit_log)
|
||||||
session.commit()
|
await session.commit()
|
||||||
return AuditView.model_validate(audit_log, from_attributes=True)
|
return AuditView.model_validate(audit_log, from_attributes=True)
|
||||||
|
|
||||||
@router.get("/audit/info")
|
@router.get("/audit/info")
|
||||||
async def get_audit_info(
|
async def get_audit_info(
|
||||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
session: Annotated[AsyncSession, Depends(get_db_session)]
|
||||||
) -> AuditInfo:
|
) -> AuditInfo:
|
||||||
"""Get audit info."""
|
"""Get audit info."""
|
||||||
audit_count = session.scalars(
|
audit_count = (await session.scalars(
|
||||||
select(func.count("*")).select_from(AuditLog)
|
select(func.count("*")).select_from(AuditLog)
|
||||||
).one()
|
)).one()
|
||||||
return AuditInfo(entries=audit_count)
|
return AuditInfo(entries=audit_count)
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
|||||||
@ -4,14 +4,14 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from typing import Annotated, Any, Self, TypeVar, cast
|
from typing import Annotated, Any, Self, TypeVar, cast
|
||||||
|
|
||||||
from sqlalchemy import select, func
|
from sqlalchemy import select, func
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.sql import Select
|
from sqlalchemy.sql import Select
|
||||||
from sshecret_backend.types import DBSessionDep
|
from sshecret_backend.types import AsyncDBSessionDep
|
||||||
from sshecret_backend.models import Client, ClientSecret
|
from sshecret_backend.models import Client, ClientSecret
|
||||||
from sshecret_backend.view_models import (
|
from sshecret_backend.view_models import (
|
||||||
ClientCreate,
|
ClientCreate,
|
||||||
@ -20,7 +20,7 @@ from sshecret_backend.view_models import (
|
|||||||
ClientUpdate,
|
ClientUpdate,
|
||||||
)
|
)
|
||||||
from sshecret_backend import audit
|
from sshecret_backend import audit
|
||||||
from .common import get_client_by_id_or_name
|
from .common import get_client_by_id_or_name, client_with_relationships
|
||||||
|
|
||||||
|
|
||||||
class ClientListParams(BaseModel):
|
class ClientListParams(BaseModel):
|
||||||
@ -74,30 +74,30 @@ def filter_client_statement(
|
|||||||
return statement.limit(params.limit).offset(params.offset)
|
return statement.limit(params.limit).offset(params.offset)
|
||||||
|
|
||||||
|
|
||||||
def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
def get_clients_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||||
"""Construct clients sub-api."""
|
"""Construct clients sub-api."""
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.get("/clients/")
|
@router.get("/clients/")
|
||||||
async def get_clients(
|
async def get_clients(
|
||||||
filter_query: Annotated[ClientListParams, Query()],
|
filter_query: Annotated[ClientListParams, Query()],
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> ClientQueryResult:
|
) -> ClientQueryResult:
|
||||||
"""Get clients."""
|
"""Get clients."""
|
||||||
# Get total results first
|
# Get total results first
|
||||||
count_statement = select(func.count("*")).select_from(Client)
|
count_statement = select(func.count("*")).select_from(Client)
|
||||||
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
||||||
|
|
||||||
total_results = session.scalars(count_statement).one()
|
total_results = (await session.scalars(count_statement)).one()
|
||||||
|
|
||||||
statement = filter_client_statement(select(Client), filter_query, False)
|
statement = filter_client_statement(client_with_relationships(), filter_query, False)
|
||||||
|
|
||||||
results = session.scalars(statement)
|
results = await session.scalars(statement)
|
||||||
remainder = total_results - filter_query.offset - filter_query.limit
|
remainder = total_results - filter_query.offset - filter_query.limit
|
||||||
if remainder < 0:
|
if remainder < 0:
|
||||||
remainder = 0
|
remainder = 0
|
||||||
|
|
||||||
clients = list(results)
|
clients = list(results.all())
|
||||||
clients_view = ClientView.from_client_list(clients)
|
clients_view = ClientView.from_client_list(clients)
|
||||||
return ClientQueryResult(
|
return ClientQueryResult(
|
||||||
clients=clients_view,
|
clients=clients_view,
|
||||||
@ -108,7 +108,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
@router.get("/clients/{name}")
|
@router.get("/clients/{name}")
|
||||||
async def get_client(
|
async def get_client(
|
||||||
name: str,
|
name: str,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> ClientView:
|
) -> ClientView:
|
||||||
"""Fetch a client."""
|
"""Fetch a client."""
|
||||||
client = await get_client_by_id_or_name(session, name)
|
client = await get_client_by_id_or_name(session, name)
|
||||||
@ -122,7 +122,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
async def delete_client(
|
async def delete_client(
|
||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Delete a client."""
|
"""Delete a client."""
|
||||||
client = await get_client_by_id_or_name(session, name)
|
client = await get_client_by_id_or_name(session, name)
|
||||||
@ -131,15 +131,15 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
status_code=404, detail="Cannot find a client with the given name."
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
)
|
)
|
||||||
|
|
||||||
session.delete(client)
|
await session.delete(client)
|
||||||
session.commit()
|
await session.commit()
|
||||||
audit.audit_delete_client(session, request, client)
|
await audit.audit_delete_client(session, request, client)
|
||||||
|
|
||||||
@router.post("/clients/")
|
@router.post("/clients/")
|
||||||
async def create_client(
|
async def create_client(
|
||||||
request: Request,
|
request: Request,
|
||||||
client: ClientCreate,
|
client: ClientCreate,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> ClientView:
|
) -> ClientView:
|
||||||
"""Create client."""
|
"""Create client."""
|
||||||
existing = await get_client_by_id_or_name(session, client.name)
|
existing = await get_client_by_id_or_name(session, client.name)
|
||||||
@ -148,9 +148,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
|
|
||||||
db_client = client.to_client()
|
db_client = client.to_client()
|
||||||
session.add(db_client)
|
session.add(db_client)
|
||||||
session.commit()
|
await session.commit()
|
||||||
session.refresh(db_client)
|
await session.refresh(db_client)
|
||||||
audit.audit_create_client(session, request, db_client)
|
db_client = await get_client_by_id_or_name(session, client.name)
|
||||||
|
if not db_client:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Could not create the client.")
|
||||||
|
await audit.audit_create_client(session, request, db_client)
|
||||||
return ClientView.from_client(db_client)
|
return ClientView.from_client(db_client)
|
||||||
|
|
||||||
@router.post("/clients/{name}/public-key")
|
@router.post("/clients/{name}/public-key")
|
||||||
@ -158,7 +161,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
client_update: ClientUpdate,
|
client_update: ClientUpdate,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> ClientView:
|
) -> ClientView:
|
||||||
"""Change the public key of a client.
|
"""Change the public key of a client.
|
||||||
|
|
||||||
@ -170,17 +173,16 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
status_code=404, detail="Cannot find a client with the given name."
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
)
|
)
|
||||||
client.public_key = client_update.public_key
|
client.public_key = client_update.public_key
|
||||||
for secret in session.scalars(
|
matching_secrets = await session.scalars(select(ClientSecret).where(ClientSecret.client_id == client.id))
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
for secret in matching_secrets.all():
|
||||||
).all():
|
|
||||||
LOG.debug("Invalidated secret %s", secret.id)
|
LOG.debug("Invalidated secret %s", secret.id)
|
||||||
secret.invalidated = True
|
secret.invalidated = True
|
||||||
secret.client_id = None
|
secret.client_id = None
|
||||||
|
|
||||||
session.add(client)
|
session.add(client)
|
||||||
session.refresh(client)
|
await session.refresh(client)
|
||||||
session.commit()
|
await session.commit()
|
||||||
audit.audit_invalidate_secrets(session, request, client)
|
await audit.audit_invalidate_secrets(session, request, client)
|
||||||
|
|
||||||
return ClientView.from_client(client)
|
return ClientView.from_client(client)
|
||||||
|
|
||||||
@ -189,7 +191,7 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
client_update: ClientCreate,
|
client_update: ClientCreate,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> ClientView:
|
) -> ClientView:
|
||||||
"""Change the public key of a client.
|
"""Change the public key of a client.
|
||||||
|
|
||||||
@ -205,19 +207,20 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
public_key_updated = False
|
public_key_updated = False
|
||||||
if client_update.public_key != client.public_key:
|
if client_update.public_key != client.public_key:
|
||||||
public_key_updated = True
|
public_key_updated = True
|
||||||
for secret in session.scalars(
|
client_secrets = await session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||||
).all():
|
)
|
||||||
|
for secret in client_secrets.all():
|
||||||
LOG.debug("Invalidated secret %s", secret.id)
|
LOG.debug("Invalidated secret %s", secret.id)
|
||||||
secret.invalidated = True
|
secret.invalidated = True
|
||||||
secret.client_id = None
|
secret.client_id = None
|
||||||
|
|
||||||
session.add(client)
|
session.add(client)
|
||||||
session.commit()
|
await session.commit()
|
||||||
session.refresh(client)
|
await session.refresh(client)
|
||||||
audit.audit_update_client(session, request, client)
|
await audit.audit_update_client(session, request, client)
|
||||||
if public_key_updated:
|
if public_key_updated:
|
||||||
audit.audit_invalidate_secrets(session, request, client)
|
await audit.audit_invalidate_secrets(session, request, client)
|
||||||
|
|
||||||
return ClientView.from_client(client)
|
return ClientView.from_client(client)
|
||||||
|
|
||||||
|
|||||||
@ -3,9 +3,11 @@
|
|||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
from sqlalchemy import Select
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from sshecret_backend.models import Client
|
from sshecret_backend.models import Client
|
||||||
|
|
||||||
@ -17,20 +19,38 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
|||||||
stored_bytes = stored_hash.encode("utf-8")
|
stored_bytes = stored_hash.encode("utf-8")
|
||||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||||
|
|
||||||
|
async def reload_client_with_relationships(session: AsyncSession, client: Client) -> Client:
|
||||||
|
"""Reload a client from the database."""
|
||||||
|
session.expunge(client)
|
||||||
|
stmt = (
|
||||||
|
select(Client)
|
||||||
|
.options(selectinload(Client.policies), selectinload(Client.secrets))
|
||||||
|
.where(Client.id == client.id)
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
|
||||||
|
def client_with_relationships() -> Select[tuple[Client]]:
|
||||||
|
"""Base select statement for client with relationships."""
|
||||||
|
return select(Client).options(
|
||||||
|
selectinload(Client.secrets),
|
||||||
|
selectinload(Client.policies),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_client_by_name(session: AsyncSession, name: str) -> Client | None:
|
||||||
"""Get client by name."""
|
"""Get client by name."""
|
||||||
client_filter = select(Client).where(Client.name == name)
|
client_filter = client_with_relationships().where(Client.name == name)
|
||||||
client_results = session.scalars(client_filter)
|
client_results = await session.execute(client_filter)
|
||||||
return client_results.first()
|
return client_results.scalars().first()
|
||||||
|
|
||||||
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
async def get_client_by_id(session: AsyncSession, id: uuid.UUID) -> Client | None:
|
||||||
"""Get client by name."""
|
"""Get client by ID."""
|
||||||
client_filter = select(Client).where(Client.id == id)
|
client_filter = client_with_relationships().where(Client.id == id)
|
||||||
client_results = session.scalars(client_filter)
|
client_results = await session.execute(client_filter)
|
||||||
return client_results.first()
|
return client_results.scalars().first()
|
||||||
|
|
||||||
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
async def get_client_by_id_or_name(session: AsyncSession, id_or_name: str) -> Client | None:
|
||||||
"""Get client either by id or name."""
|
"""Get client either by id or name."""
|
||||||
if RE_UUID.match(id_or_name):
|
if RE_UUID.match(id_or_name):
|
||||||
id = uuid.UUID(id_or_name)
|
id = uuid.UUID(id_or_name)
|
||||||
|
|||||||
@ -5,7 +5,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from sshecret_backend.models import ClientAccessPolicy
|
from sshecret_backend.models import ClientAccessPolicy
|
||||||
@ -13,21 +13,21 @@ from sshecret_backend.view_models import (
|
|||||||
ClientPolicyView,
|
ClientPolicyView,
|
||||||
ClientPolicyUpdate,
|
ClientPolicyUpdate,
|
||||||
)
|
)
|
||||||
from sshecret_backend.types import DBSessionDep
|
from sshecret_backend.types import AsyncDBSessionDep
|
||||||
from sshecret_backend import audit
|
from sshecret_backend import audit
|
||||||
from .common import get_client_by_id_or_name
|
from .common import get_client_by_id_or_name, reload_client_with_relationships
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
def get_policy_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||||
"""Construct clients sub-api."""
|
"""Construct clients sub-api."""
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.get("/clients/{name}/policies/")
|
@router.get("/clients/{name}/policies/")
|
||||||
async def get_client_policies(
|
async def get_client_policies(
|
||||||
name: str, session: Annotated[Session, Depends(get_db_session)]
|
name: str, session: Annotated[AsyncSession, Depends(get_db_session)]
|
||||||
) -> ClientPolicyView:
|
) -> ClientPolicyView:
|
||||||
"""Get client policies."""
|
"""Get client policies."""
|
||||||
client = await get_client_by_id_or_name(session, name)
|
client = await get_client_by_id_or_name(session, name)
|
||||||
@ -43,7 +43,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
policy_update: ClientPolicyUpdate,
|
policy_update: ClientPolicyUpdate,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> ClientPolicyView:
|
) -> ClientPolicyView:
|
||||||
"""Update client policies.
|
"""Update client policies.
|
||||||
|
|
||||||
@ -55,28 +55,31 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
status_code=404, detail="Cannot find a client with the given name."
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
)
|
)
|
||||||
# Remove old policies.
|
# Remove old policies.
|
||||||
policies = session.scalars(
|
policies = await session.scalars(
|
||||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||||
).all()
|
)
|
||||||
deleted_policies: list[ClientAccessPolicy] = []
|
deleted_policies: list[ClientAccessPolicy] = []
|
||||||
added_policies: list[ClientAccessPolicy] = []
|
added_policies: list[ClientAccessPolicy] = []
|
||||||
for policy in policies:
|
for policy in policies.all():
|
||||||
session.delete(policy)
|
await session.delete(policy)
|
||||||
deleted_policies.append(policy)
|
deleted_policies.append(policy)
|
||||||
|
|
||||||
|
LOG.debug("Updating client policies with: %r", policy_update.sources)
|
||||||
for source in policy_update.sources:
|
for source in policy_update.sources:
|
||||||
LOG.debug("Source %r", source)
|
LOG.debug("Source %r", source)
|
||||||
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
|
policy = ClientAccessPolicy(source=str(source), client_id=client.id)
|
||||||
session.add(policy)
|
session.add(policy)
|
||||||
added_policies.append(policy)
|
added_policies.append(policy)
|
||||||
|
|
||||||
session.commit()
|
await session.flush()
|
||||||
session.refresh(client)
|
await session.commit()
|
||||||
|
|
||||||
|
client = await reload_client_with_relationships(session, client)
|
||||||
for policy in deleted_policies:
|
for policy in deleted_policies:
|
||||||
audit.audit_remove_policy(session, request, client, policy)
|
await audit.audit_remove_policy(session, request, client, policy)
|
||||||
|
|
||||||
for policy in added_policies:
|
for policy in added_policies:
|
||||||
audit.audit_update_policy(session, request, client, policy)
|
await audit.audit_update_policy(session, request, client, policy)
|
||||||
|
|
||||||
return ClientPolicyView.from_client(client)
|
return ClientPolicyView.from_client(client)
|
||||||
|
|
||||||
|
|||||||
@ -6,9 +6,11 @@ import logging
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from sshecret_backend.models import Client, ClientSecret
|
from sshecret_backend.models import Client, ClientSecret
|
||||||
from sshecret_backend.view_models import (
|
from sshecret_backend.view_models import (
|
||||||
ClientReference,
|
ClientReference,
|
||||||
@ -19,7 +21,7 @@ from sshecret_backend.view_models import (
|
|||||||
ClientSecretResponse,
|
ClientSecretResponse,
|
||||||
)
|
)
|
||||||
from sshecret_backend import audit
|
from sshecret_backend import audit
|
||||||
from sshecret_backend.types import DBSessionDep
|
from sshecret_backend.types import AsyncDBSessionDep
|
||||||
from .common import get_client_by_id_or_name
|
from .common import get_client_by_id_or_name
|
||||||
|
|
||||||
|
|
||||||
@ -27,7 +29,7 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
async def lookup_client_secret(
|
async def lookup_client_secret(
|
||||||
session: Session, client: Client, name: str
|
session: AsyncSession, client: Client, name: str
|
||||||
) -> ClientSecret | None:
|
) -> ClientSecret | None:
|
||||||
"""Look up a secret for a client."""
|
"""Look up a secret for a client."""
|
||||||
statement = (
|
statement = (
|
||||||
@ -35,11 +37,11 @@ async def lookup_client_secret(
|
|||||||
.where(ClientSecret.client_id == client.id)
|
.where(ClientSecret.client_id == client.id)
|
||||||
.where(ClientSecret.name == name)
|
.where(ClientSecret.name == name)
|
||||||
)
|
)
|
||||||
results = session.scalars(statement)
|
results = await session.scalars(statement)
|
||||||
return results.first()
|
return results.first()
|
||||||
|
|
||||||
|
|
||||||
def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
def get_secrets_api(get_db_session: AsyncDBSessionDep) -> APIRouter:
|
||||||
"""Construct clients sub-api."""
|
"""Construct clients sub-api."""
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -48,7 +50,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
client_secret: ClientSecretPublic,
|
client_secret: ClientSecretPublic,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add secret to a client."""
|
"""Add secret to a client."""
|
||||||
client = await get_client_by_id_or_name(session, name)
|
client = await get_client_by_id_or_name(session, name)
|
||||||
@ -69,9 +71,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
name=client_secret.name, client_id=client.id, secret=client_secret.secret
|
name=client_secret.name, client_id=client.id, secret=client_secret.secret
|
||||||
)
|
)
|
||||||
session.add(db_secret)
|
session.add(db_secret)
|
||||||
session.commit()
|
await session.commit()
|
||||||
session.refresh(db_secret)
|
await session.refresh(db_secret)
|
||||||
audit.audit_create_secret(session, request, client, db_secret)
|
await audit.audit_create_secret(session, request, client, db_secret)
|
||||||
|
|
||||||
@router.put("/clients/{name}/secrets/{secret_name}")
|
@router.put("/clients/{name}/secrets/{secret_name}")
|
||||||
async def update_client_secret(
|
async def update_client_secret(
|
||||||
@ -79,7 +81,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
name: str,
|
name: str,
|
||||||
secret_name: str,
|
secret_name: str,
|
||||||
secret_data: BodyValue,
|
secret_data: BodyValue,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> ClientSecretResponse:
|
) -> ClientSecretResponse:
|
||||||
"""Update a client secret.
|
"""Update a client secret.
|
||||||
|
|
||||||
@ -96,9 +98,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
existing_secret.secret = secret_data.value
|
existing_secret.secret = secret_data.value
|
||||||
|
|
||||||
session.add(existing_secret)
|
session.add(existing_secret)
|
||||||
session.commit()
|
await session.commit()
|
||||||
session.refresh(existing_secret)
|
await session.refresh(existing_secret)
|
||||||
audit.audit_update_secret(session, request, client, existing_secret)
|
await audit.audit_update_secret(session, request, client, existing_secret)
|
||||||
return ClientSecretResponse.from_client_secret(existing_secret)
|
return ClientSecretResponse.from_client_secret(existing_secret)
|
||||||
|
|
||||||
db_secret = ClientSecret(
|
db_secret = ClientSecret(
|
||||||
@ -107,9 +109,9 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
secret=secret_data.value,
|
secret=secret_data.value,
|
||||||
)
|
)
|
||||||
session.add(db_secret)
|
session.add(db_secret)
|
||||||
session.commit()
|
await session.commit()
|
||||||
session.refresh(db_secret)
|
await session.refresh(db_secret)
|
||||||
audit.audit_create_secret(session, request, client, db_secret)
|
await audit.audit_create_secret(session, request, client, db_secret)
|
||||||
return ClientSecretResponse.from_client_secret(db_secret)
|
return ClientSecretResponse.from_client_secret(db_secret)
|
||||||
|
|
||||||
@router.get("/clients/{name}/secrets/{secret_name}")
|
@router.get("/clients/{name}/secrets/{secret_name}")
|
||||||
@ -117,7 +119,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
secret_name: str,
|
secret_name: str,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> ClientSecretResponse:
|
) -> ClientSecretResponse:
|
||||||
"""Get a client secret."""
|
"""Get a client secret."""
|
||||||
client = await get_client_by_id_or_name(session, name)
|
client = await get_client_by_id_or_name(session, name)
|
||||||
@ -133,7 +135,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
)
|
)
|
||||||
|
|
||||||
response_model = ClientSecretResponse.from_client_secret(secret)
|
response_model = ClientSecretResponse.from_client_secret(secret)
|
||||||
audit.audit_access_secret(session, request, client, secret)
|
await audit.audit_access_secret(session, request, client, secret)
|
||||||
return response_model
|
return response_model
|
||||||
|
|
||||||
@router.delete("/clients/{name}/secrets/{secret_name}")
|
@router.delete("/clients/{name}/secrets/{secret_name}")
|
||||||
@ -141,7 +143,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
secret_name: str,
|
secret_name: str,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Delete a secret."""
|
"""Delete a secret."""
|
||||||
client = await get_client_by_id_or_name(session, name)
|
client = await get_client_by_id_or_name(session, name)
|
||||||
@ -156,17 +158,20 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
status_code=404, detail="Cannot find a secret with the given name."
|
status_code=404, detail="Cannot find a secret with the given name."
|
||||||
)
|
)
|
||||||
|
|
||||||
session.delete(secret)
|
await session.delete(secret)
|
||||||
session.commit()
|
await session.commit()
|
||||||
audit.audit_delete_secret(session, request, client, secret)
|
await audit.audit_delete_secret(session, request, client, secret)
|
||||||
|
|
||||||
@router.get("/secrets/")
|
@router.get("/secrets/")
|
||||||
async def get_secret_map(
|
async def get_secret_map(
|
||||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> list[ClientSecretList]:
|
) -> list[ClientSecretList]:
|
||||||
"""Get a list of all secrets and which clients have them."""
|
"""Get a list of all secrets and which clients have them."""
|
||||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
client_secrets = await session.scalars(
|
||||||
|
select(ClientSecret).options(selectinload(ClientSecret.client))
|
||||||
|
)
|
||||||
|
for client_secret in client_secrets.all():
|
||||||
if not client_secret.client:
|
if not client_secret.client:
|
||||||
if client_secret.name not in client_secret_map:
|
if client_secret.name not in client_secret_map:
|
||||||
client_secret_map[client_secret.name] = []
|
client_secret_map[client_secret.name] = []
|
||||||
@ -177,35 +182,45 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
ClientSecretList(name=secret_name, clients=clients)
|
ClientSecretList(name=secret_name, clients=clients)
|
||||||
for secret_name, clients in client_secret_map.items()
|
for secret_name, clients in client_secret_map.items()
|
||||||
]
|
]
|
||||||
|
|
||||||
@router.get("/secrets/detailed/")
|
@router.get("/secrets/detailed/")
|
||||||
async def get_detailed_secret_map(
|
async def get_detailed_secret_map(
|
||||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> list[ClientSecretDetailList]:
|
) -> list[ClientSecretDetailList]:
|
||||||
"""Get a list of all secrets and which clients have them."""
|
"""Get a list of all secrets and which clients have them."""
|
||||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
all_client_secrets = await session.execute(
|
||||||
|
select(ClientSecret).options(selectinload(ClientSecret.client))
|
||||||
|
)
|
||||||
|
for client_secret in all_client_secrets.scalars().all():
|
||||||
if client_secret.name not in client_secrets:
|
if client_secret.name not in client_secrets:
|
||||||
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
client_secrets[client_secret.name] = ClientSecretDetailList(
|
||||||
|
name=client_secret.name
|
||||||
|
)
|
||||||
client_secrets[client_secret.name].ids.append(str(client_secret.id))
|
client_secrets[client_secret.name].ids.append(str(client_secret.id))
|
||||||
if not client_secret.client:
|
if not client_secret.client:
|
||||||
continue
|
continue
|
||||||
client_secrets[client_secret.name].clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
|
client_secrets[client_secret.name].clients.append(
|
||||||
|
ClientReference(
|
||||||
|
id=str(client_secret.client.id), name=client_secret.client.name
|
||||||
|
)
|
||||||
|
)
|
||||||
# `audit.audit_client_secret_list(session, request)
|
# `audit.audit_client_secret_list(session, request)
|
||||||
return list(client_secrets.values())
|
return list(client_secrets.values())
|
||||||
|
|
||||||
|
|
||||||
@router.get("/secrets/{name}")
|
@router.get("/secrets/{name}")
|
||||||
async def get_secret_clients(
|
async def get_secret_clients(
|
||||||
request: Request,
|
|
||||||
name: str,
|
name: str,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> ClientSecretList:
|
) -> ClientSecretList:
|
||||||
"""Get a list of which clients has a named secret."""
|
"""Get a list of which clients has a named secret."""
|
||||||
clients: list[str] = []
|
clients: list[str] = []
|
||||||
for client_secret in session.scalars(
|
client_secrets = await session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.name == name)
|
select(ClientSecret)
|
||||||
).all():
|
.options(selectinload(ClientSecret.client))
|
||||||
|
.where(ClientSecret.name == name)
|
||||||
|
)
|
||||||
|
for client_secret in client_secrets.all():
|
||||||
if not client_secret.client:
|
if not client_secret.client:
|
||||||
continue
|
continue
|
||||||
clients.append(client_secret.client.name)
|
clients.append(client_secret.client.name)
|
||||||
@ -214,19 +229,23 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
|
|
||||||
@router.get("/secrets/{name}/detailed")
|
@router.get("/secrets/{name}/detailed")
|
||||||
async def get_secret_clients_detailed(
|
async def get_secret_clients_detailed(
|
||||||
request: Request,
|
|
||||||
name: str,
|
name: str,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> ClientSecretDetailList:
|
) -> ClientSecretDetailList:
|
||||||
"""Get a list of which clients has a named secret."""
|
"""Get a list of which clients has a named secret."""
|
||||||
detail_list = ClientSecretDetailList(name=name)
|
detail_list = ClientSecretDetailList(name=name)
|
||||||
for client_secret in session.scalars(
|
client_secrets = await session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.name == name)
|
select(ClientSecret).where(ClientSecret.name == name)
|
||||||
).all():
|
)
|
||||||
|
for client_secret in client_secrets.all():
|
||||||
if not client_secret.client:
|
if not client_secret.client:
|
||||||
continue
|
continue
|
||||||
detail_list.ids.append(str(client_secret.id))
|
detail_list.ids.append(str(client_secret.id))
|
||||||
detail_list.clients.append(ClientReference(id=str(client_secret.client.id), name=client_secret.client.name))
|
detail_list.clients.append(
|
||||||
|
ClientReference(
|
||||||
|
id=str(client_secret.client.id), name=client_secret.client.name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return detail_list
|
return detail_list
|
||||||
|
|
||||||
|
|||||||
@ -13,30 +13,32 @@ from fastapi.encoders import jsonable_encoder
|
|||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from sqlalchemy import Engine
|
from sqlalchemy import Engine
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||||
|
|
||||||
|
|
||||||
from .models import init_db
|
from .models import init_db_async
|
||||||
from .backend_api import get_backend_api
|
from .backend_api import get_backend_api
|
||||||
from .db import setup_database
|
from .db import setup_database, get_async_engine
|
||||||
|
|
||||||
from .settings import BackendSettings
|
from .settings import BackendSettings
|
||||||
from .types import DBSessionDep
|
from .types import AsyncDBSessionDep
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI:
|
def init_backend_app(settings: BackendSettings) -> FastAPI:
|
||||||
"""Initialize backend app."""
|
"""Initialize backend app."""
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_app: FastAPI):
|
async def lifespan(_app: FastAPI):
|
||||||
"""Create database before starting the server."""
|
"""Create database before starting the server."""
|
||||||
LOG.debug("Running lifespan")
|
LOG.debug("Running lifespan")
|
||||||
init_db(engine)
|
engine = get_async_engine(settings.async_db_url)
|
||||||
|
await init_db_async(engine)
|
||||||
yield
|
yield
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
app.include_router(get_backend_api(get_db_session))
|
app.include_router(get_backend_api(settings))
|
||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
async def validation_exception_handler(
|
async def validation_exception_handler(
|
||||||
@ -60,6 +62,4 @@ def init_backend_app(engine: Engine, get_db_session: DBSessionDep) -> FastAPI:
|
|||||||
def create_backend_app(settings: BackendSettings) -> FastAPI:
|
def create_backend_app(settings: BackendSettings) -> FastAPI:
|
||||||
"""Create the backend app."""
|
"""Create the backend app."""
|
||||||
|
|
||||||
engine, get_db_session = setup_database(settings.db_url)
|
return init_backend_app(settings)
|
||||||
|
|
||||||
return init_backend_app(engine, get_db_session)
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
|
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
|
||||||
|
|
||||||
@ -17,8 +17,8 @@ def _get_origin(request: Request) -> str | None:
|
|||||||
return origin
|
return origin
|
||||||
|
|
||||||
|
|
||||||
def _write_audit_log(
|
async def _write_audit_log(
|
||||||
session: Session, request: Request, entry: AuditLog, commit: bool = True
|
session: AsyncSession, request: Request, entry: AuditLog, commit: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Write the audit log."""
|
"""Write the audit log."""
|
||||||
origin = _get_origin(request)
|
origin = _get_origin(request)
|
||||||
@ -26,11 +26,11 @@ def _write_audit_log(
|
|||||||
entry.subsystem = SubSystem.BACKEND
|
entry.subsystem = SubSystem.BACKEND
|
||||||
session.add(entry)
|
session.add(entry)
|
||||||
if commit:
|
if commit:
|
||||||
session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
def audit_create_client(
|
async def audit_create_client(
|
||||||
session: Session, request: Request, client: Client, commit: bool = True
|
session: AsyncSession, request: Request, client: Client, commit: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Log the creation of a client."""
|
"""Log the creation of a client."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
@ -39,11 +39,11 @@ def audit_create_client(
|
|||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Client Created",
|
message="Client Created",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
await _write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_delete_client(
|
async def audit_delete_client(
|
||||||
session: Session, request: Request, client: Client, commit: bool = True
|
session: AsyncSession, request: Request, client: Client, commit: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Log the creation of a client."""
|
"""Log the creation of a client."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
@ -52,11 +52,11 @@ def audit_delete_client(
|
|||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Client deleted",
|
message="Client deleted",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
await _write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_create_secret(
|
async def audit_create_secret(
|
||||||
session: Session,
|
session: AsyncSession,
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Client,
|
client: Client,
|
||||||
secret: ClientSecret,
|
secret: ClientSecret,
|
||||||
@ -71,11 +71,11 @@ def audit_create_secret(
|
|||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Added secret to client",
|
message="Added secret to client",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
await _write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_remove_policy(
|
async def audit_remove_policy(
|
||||||
session: Session,
|
session: AsyncSession,
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Client,
|
client: Client,
|
||||||
policy: ClientAccessPolicy,
|
policy: ClientAccessPolicy,
|
||||||
@ -90,11 +90,11 @@ def audit_remove_policy(
|
|||||||
message="Deleted client policy",
|
message="Deleted client policy",
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
await _write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_update_policy(
|
async def audit_update_policy(
|
||||||
session: Session,
|
session: AsyncSession,
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Client,
|
client: Client,
|
||||||
policy: ClientAccessPolicy,
|
policy: ClientAccessPolicy,
|
||||||
@ -109,11 +109,11 @@ def audit_update_policy(
|
|||||||
message="Updated client policy",
|
message="Updated client policy",
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
await _write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_update_client(
|
async def audit_update_client(
|
||||||
session: Session,
|
session: AsyncSession,
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Client,
|
client: Client,
|
||||||
commit: bool = True,
|
commit: bool = True,
|
||||||
@ -125,11 +125,11 @@ def audit_update_client(
|
|||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Client data updated",
|
message="Client data updated",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
await _write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_update_secret(
|
async def audit_update_secret(
|
||||||
session: Session,
|
session: AsyncSession,
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Client,
|
client: Client,
|
||||||
secret: ClientSecret,
|
secret: ClientSecret,
|
||||||
@ -144,11 +144,11 @@ def audit_update_secret(
|
|||||||
secret_id=secret.id,
|
secret_id=secret.id,
|
||||||
message="Secret value updated",
|
message="Secret value updated",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
await _write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_invalidate_secrets(
|
async def audit_invalidate_secrets(
|
||||||
session: Session,
|
session: AsyncSession,
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Client,
|
client: Client,
|
||||||
commit: bool = True,
|
commit: bool = True,
|
||||||
@ -160,11 +160,11 @@ def audit_invalidate_secrets(
|
|||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
message="Client public-key changed. All secrets invalidated.",
|
message="Client public-key changed. All secrets invalidated.",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
await _write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_delete_secret(
|
async def audit_delete_secret(
|
||||||
session: Session,
|
session: AsyncSession,
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Client,
|
client: Client,
|
||||||
secret: ClientSecret,
|
secret: ClientSecret,
|
||||||
@ -179,11 +179,11 @@ def audit_delete_secret(
|
|||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
message="Secret removed from client",
|
message="Secret removed from client",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
await _write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_access_secrets(
|
async def audit_access_secrets(
|
||||||
session: Session,
|
session: AsyncSession,
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Client,
|
client: Client,
|
||||||
secrets: Sequence[ClientSecret] | None = None,
|
secrets: Sequence[ClientSecret] | None = None,
|
||||||
@ -194,19 +194,20 @@ def audit_access_secrets(
|
|||||||
With no secrets provided, all secrets of the client will be resolved.
|
With no secrets provided, all secrets of the client will be resolved.
|
||||||
"""
|
"""
|
||||||
if not secrets:
|
if not secrets:
|
||||||
secrets = session.scalars(
|
secrets_q = await session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||||
).all()
|
)
|
||||||
|
secrets = secrets_q.all()
|
||||||
|
|
||||||
for secret in secrets:
|
for secret in secrets:
|
||||||
audit_access_secret(session, request, client, secret, False)
|
await audit_access_secret(session, request, client, secret, False)
|
||||||
|
|
||||||
if commit:
|
if commit:
|
||||||
session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
def audit_access_secret(
|
async def audit_access_secret(
|
||||||
session: Session,
|
session: AsyncSession,
|
||||||
request: Request,
|
request: Request,
|
||||||
client: Client,
|
client: Client,
|
||||||
secret: ClientSecret,
|
secret: ClientSecret,
|
||||||
@ -221,15 +222,15 @@ def audit_access_secret(
|
|||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
await _write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
def audit_client_secret_list(
|
async def audit_client_secret_list(
|
||||||
session: Session, request: Request, commit: bool = True
|
session: AsyncSession, request: Request, commit: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Audit a list of all secrets."""
|
"""Audit a list of all secrets."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation=Operation.READ,
|
operation=Operation.READ,
|
||||||
message="All secret names and their clients was viewed",
|
message="All secret names and their clients was viewed",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
await _write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|||||||
@ -1,17 +1,19 @@
|
|||||||
"""Backend API."""
|
"""Backend API."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
import logging
|
import logging
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sshecret_backend.db import DatabaseSessionManager
|
||||||
|
from sshecret_backend.settings import BackendSettings
|
||||||
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
||||||
from .auth import verify_token
|
from .auth import verify_token
|
||||||
from .models import (
|
from .models import (
|
||||||
APIClient,
|
APIClient,
|
||||||
)
|
)
|
||||||
from .types import DBSessionDep
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -19,20 +21,25 @@ API_VERSION = "v1"
|
|||||||
|
|
||||||
|
|
||||||
def get_backend_api(
|
def get_backend_api(
|
||||||
get_db_session: DBSessionDep,
|
settings: BackendSettings,
|
||||||
) -> APIRouter:
|
) -> APIRouter:
|
||||||
"""Construct backend API."""
|
"""Construct backend API."""
|
||||||
|
|
||||||
|
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
sessionmanager = DatabaseSessionManager(settings.async_db_url)
|
||||||
|
async with sessionmanager.session() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
async def validate_token(
|
async def validate_token(
|
||||||
x_api_token: Annotated[str, Header()],
|
x_api_token: Annotated[str, Header()],
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[AsyncSession, Depends(get_db_session)],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Validate token."""
|
"""Validate token."""
|
||||||
LOG.debug("Validating token %s", x_api_token)
|
LOG.debug("Validating token %s", x_api_token)
|
||||||
statement = select(APIClient)
|
statement = select(APIClient)
|
||||||
results = session.scalars(statement)
|
results = await session.scalars(statement)
|
||||||
valid = False
|
valid = False
|
||||||
for result in results:
|
for result in results.all():
|
||||||
if verify_token(x_api_token, result.token):
|
if verify_token(x_api_token, result.token):
|
||||||
valid = True
|
valid = True
|
||||||
LOG.debug("Token is valid")
|
LOG.debug("Token is valid")
|
||||||
|
|||||||
@ -4,10 +4,11 @@ import logging
|
|||||||
import secrets
|
import secrets
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
from collections.abc import Generator, Callable
|
from collections.abc import AsyncIterator, Generator, Callable
|
||||||
from typing import Literal
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any, Literal
|
||||||
from sqlalchemy import create_engine, Engine, event, select
|
from sqlalchemy import create_engine, Engine, event, select
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession, async_sessionmaker, create_async_engine, AsyncEngine
|
||||||
|
|
||||||
from sqlalchemy.orm import sessionmaker, Session
|
from sqlalchemy.orm import sessionmaker, Session
|
||||||
|
|
||||||
@ -20,6 +21,47 @@ from .models import APIClient, SubSystem
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseSessionManager:
|
||||||
|
def __init__(self, host: URL, **engine_kwargs: str):
|
||||||
|
self._engine: AsyncEngine | None = get_async_engine(host)
|
||||||
|
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
if self._engine is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
await self._engine.dispose()
|
||||||
|
|
||||||
|
self._engine = None
|
||||||
|
self._sessionmaker = None
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def connect(self) -> AsyncIterator[AsyncConnection]:
|
||||||
|
if self._engine is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
|
||||||
|
async with self._engine.begin() as connection:
|
||||||
|
try:
|
||||||
|
yield connection
|
||||||
|
except Exception:
|
||||||
|
await connection.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def session(self) -> AsyncIterator[AsyncSession]:
|
||||||
|
if self._sessionmaker is None:
|
||||||
|
raise Exception("DatabaseSessionManager is not initialized")
|
||||||
|
|
||||||
|
session = self._sessionmaker()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
except Exception:
|
||||||
|
await session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
def setup_database(
|
def setup_database(
|
||||||
db_url: URL,
|
db_url: URL,
|
||||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||||
@ -39,9 +81,10 @@ def setup_database(
|
|||||||
return engine, get_db_session
|
return engine, get_db_session
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_engine(url: URL, echo: bool = False) -> Engine:
|
def get_engine(url: URL, echo: bool = False) -> Engine:
|
||||||
"""Initialize the engine."""
|
"""Initialize the engine."""
|
||||||
engine = create_engine(url, echo=echo, future=True)
|
engine = create_engine(url, echo=echo)
|
||||||
if url.drivername.startswith("sqlite"):
|
if url.drivername.startswith("sqlite"):
|
||||||
|
|
||||||
@event.listens_for(engine, "connect")
|
@event.listens_for(engine, "connect")
|
||||||
@ -55,12 +98,11 @@ def get_engine(url: URL, echo: bool = False) -> Engine:
|
|||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine:
|
def get_async_engine(url: URL, echo: bool = False, **engine_kwargs: str) -> AsyncEngine:
|
||||||
"""Get an async engine."""
|
"""Get an async engine."""
|
||||||
engine = create_async_engine(url, echo=echo, future=True)
|
engine = create_async_engine(url, echo=echo, **engine_kwargs)
|
||||||
if url.drivername.startswith("sqlite+"):
|
if url.drivername.startswith("sqlite+"):
|
||||||
|
@event.listens_for(engine.sync_engine, "connect")
|
||||||
@event.listens_for(engine, "connect")
|
|
||||||
def set_sqlite_pragma(
|
def set_sqlite_pragma(
|
||||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@ -13,6 +13,7 @@ import uuid
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
|
||||||
@ -186,3 +187,9 @@ class AuditLog(Base):
|
|||||||
def init_db(engine: sa.Engine) -> None:
|
def init_db(engine: sa.Engine) -> None:
|
||||||
"""Initialize database."""
|
"""Initialize database."""
|
||||||
Base.metadata.create_all(engine)
|
Base.metadata.create_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
async def init_db_async(engine: AsyncEngine) -> None:
|
||||||
|
"""Initialize database."""
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
"""Common type definitions."""
|
"""Common type definitions."""
|
||||||
|
|
||||||
from collections.abc import Callable, Generator
|
from collections.abc import AsyncGenerator, Callable, Generator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
|
||||||
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||||
|
|
||||||
|
AsyncDBSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||||
|
|||||||
Reference in New Issue
Block a user