Implement oidc login
This commit is contained in:
@ -14,6 +14,8 @@ from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from sshecret_admin import api, frontend
|
||||
from sshecret_admin.auth.models import PasswordDB, init_db
|
||||
from sshecret_admin.core.db import setup_database
|
||||
@ -28,9 +30,7 @@ LOG = logging.getLogger(__name__)
|
||||
# dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
||||
def setup_frontend(
|
||||
app: FastAPI, dependencies: BaseDependencies
|
||||
) -> None:
|
||||
def setup_frontend(app: FastAPI, dependencies: BaseDependencies) -> None:
|
||||
"""Setup frontend."""
|
||||
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
|
||||
static_path = script_path.parent / "static"
|
||||
@ -51,15 +51,21 @@ def create_admin_app(
|
||||
settings=settings, regenerate=False
|
||||
)
|
||||
with Session(engine) as session:
|
||||
existing_password = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||
existing_password = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
).first()
|
||||
|
||||
if not encr_master_password:
|
||||
if existing_password:
|
||||
LOG.info("Master password already defined.")
|
||||
return
|
||||
# Looks like we have to regenerate it
|
||||
LOG.warning("Master password was set, but not saved to the database. Regenerating it.")
|
||||
encr_master_password = setup_master_password(settings=settings, regenerate=True)
|
||||
LOG.warning(
|
||||
"Master password was set, but not saved to the database. Regenerating it."
|
||||
)
|
||||
encr_master_password = setup_master_password(
|
||||
settings=settings, regenerate=True
|
||||
)
|
||||
|
||||
assert encr_master_password is not None
|
||||
|
||||
@ -76,6 +82,7 @@ def create_admin_app(
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(
|
||||
@ -95,7 +102,6 @@ def create_admin_app(
|
||||
return response
|
||||
return RedirectResponse(url=str(exc.to))
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def get_health() -> JSONResponse:
|
||||
"""Provide simple health check."""
|
||||
@ -105,7 +111,6 @@ def create_admin_app(
|
||||
|
||||
dependencies = BaseDependencies(settings, get_db_session)
|
||||
|
||||
|
||||
app.include_router(api.create_api_router(dependencies))
|
||||
if with_frontend:
|
||||
setup_frontend(app, dependencies)
|
||||
|
||||
@ -12,7 +12,7 @@ from pydantic import ValidationError
|
||||
from sqlalchemy import select, create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
from sshecret_admin.auth.authentication import hash_password
|
||||
from sshecret_admin.auth.models import PasswordDB, User, init_db
|
||||
from sshecret_admin.auth.models import AuthProvider, PasswordDB, User, init_db
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
|
||||
@ -28,10 +28,15 @@ LOG.addHandler(handler)
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def create_user(session: Session, username: str, password: str) -> None:
|
||||
def create_user(session: Session, username: str, email: str, password: str) -> None:
|
||||
"""Create a user."""
|
||||
hashed_password = hash_password(password)
|
||||
user = User(username=username, hashed_password=hashed_password)
|
||||
user = User(
|
||||
username=username,
|
||||
email=email,
|
||||
hashed_password=hashed_password,
|
||||
provider=AuthProvider.LOCAL,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
@ -58,15 +63,18 @@ def cli(ctx: click.Context, debug: bool) -> None:
|
||||
|
||||
@cli.command("adduser")
|
||||
@click.argument("username")
|
||||
@click.argument("email")
|
||||
@click.password_option()
|
||||
@click.pass_context
|
||||
def cli_create_user(ctx: click.Context, username: str, password: str) -> None:
|
||||
def cli_create_user(
|
||||
ctx: click.Context, username: str, email: str, password: str
|
||||
) -> None:
|
||||
"""Create user."""
|
||||
settings = cast(AdminServerSettings, ctx.obj)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
create_user(session, username, password)
|
||||
create_user(session, username, email, password)
|
||||
|
||||
click.echo("User created.")
|
||||
|
||||
@ -143,7 +151,9 @@ def cli_repl(ctx: click.Context) -> None:
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||
password_db = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
).first()
|
||||
|
||||
if not password_db:
|
||||
raise click.ClickException(
|
||||
|
||||
@ -8,7 +8,13 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy import create_engine, Engine
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine, AsyncEngine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncConnection,
|
||||
create_async_engine,
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
)
|
||||
|
||||
|
||||
def setup_database(
|
||||
@ -16,7 +22,7 @@ def setup_database(
|
||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||
"""Setup database."""
|
||||
|
||||
engine = create_engine(db_url, echo=False, future=True)
|
||||
engine = create_engine(db_url, echo=True, future=True)
|
||||
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
"""Get DB Session."""
|
||||
@ -29,7 +35,11 @@ def setup_database(
|
||||
class DatabaseSessionManager:
|
||||
def __init__(self, host: URL | str, **engine_kwargs: str):
|
||||
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
|
||||
async_sessionmaker(
|
||||
autocommit=False, bind=self._engine, expire_on_commit=False
|
||||
)
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
if self._engine is None:
|
||||
|
||||
@ -4,7 +4,6 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.services import AdminBackend
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Main server app."""
|
||||
|
||||
import sys
|
||||
import click
|
||||
from pydantic import ValidationError
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""SSH Server settings."""
|
||||
|
||||
from pathlib import Path
|
||||
from pydantic import AnyHttpUrl, Field
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from sqlalchemy import URL
|
||||
|
||||
@ -11,11 +11,23 @@ DEFAULT_LISTEN_PORT = 8822
|
||||
DEFAULT_DATABASE = "sshecret_admin.db"
|
||||
|
||||
|
||||
class OidcSettings(BaseModel):
|
||||
"""OIDC settings."""
|
||||
|
||||
name: str
|
||||
config_url: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
|
||||
|
||||
class AdminServerSettings(BaseSettings):
|
||||
"""Server Settings."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".admin.env", env_prefix="sshecret_admin_", secrets_dir="/var/run"
|
||||
env_file=".admin.env",
|
||||
env_prefix="sshecret_admin_",
|
||||
secrets_dir="/var/run",
|
||||
env_nested_delimiter="__",
|
||||
)
|
||||
|
||||
backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url")
|
||||
@ -26,6 +38,7 @@ class AdminServerSettings(BaseSettings):
|
||||
database: str = Field(default=DEFAULT_DATABASE)
|
||||
debug: bool = False
|
||||
password_manager_directory: Path | None = None
|
||||
oidc: OidcSettings | None = None
|
||||
|
||||
@property
|
||||
def admin_db(self) -> URL:
|
||||
|
||||
Reference in New Issue
Block a user