215 lines
6.6 KiB
Python
215 lines
6.6 KiB
Python
"""Authentication related views factory."""
|
|
|
|
# pyright: reportUnusedFunction=false
|
|
import logging
|
|
from pydantic import BaseModel
|
|
from typing import Annotated
|
|
from fastapi import APIRouter, Depends, Query, Request, Response, status
|
|
from fastapi.responses import RedirectResponse
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sshecret_admin.services import AdminBackend
|
|
from starlette.datastructures import URL
|
|
|
|
from sshecret_admin.auth import (
|
|
IdentityClaims,
|
|
authenticate_user_async,
|
|
create_access_token,
|
|
create_refresh_token,
|
|
)
|
|
|
|
from sshecret.backend.models import Operation
|
|
|
|
from ..dependencies import FrontendDependencies
|
|
from ..exceptions import RedirectException
|
|
|
|
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)
|
|
|
|
|
|
class LoginError(BaseModel):
    """Error shown on the login page.

    Instances are passed to the ``login.html`` template as ``login_error``.
    """

    # Short heading for the error.
    title: str
    # Human-readable explanation of what went wrong.
    message: str
|
|
|
|
|
|
class OidcLogin(BaseModel):
    """OIDC state handed to the login template as ``oidc``.

    By default OIDC login is switched off and no provider is named.
    """

    # Whether the OIDC login option should be offered.
    enabled: bool = False
    # Display name of the configured OIDC provider, if any.
    provider_name: str | None = None
|
|
|
|
|
|
async def audit_login_failure(
    admin: AdminBackend, username: str, request: Request
) -> None:
    """Record a failed login attempt in the audit log.

    Args:
        admin: Backend used to write the audit entry.
        username: The username that was submitted with the failed attempt.
        request: Incoming request; its client host is logged as the origin.

    The entry is written with ``Operation.DENY``. The origin falls back to
    ``"UNKNOWN"`` when the request has no client address (or an empty host).
    """
    client = request.client
    await admin.write_audit_message(
        operation=Operation.DENY,
        message="Login failed",
        origin=(client.host if client else None) or "UNKNOWN",
        username=username,
    )
|
|
|
|
|
|
def _safe_redirect_target(
    target: str, request: Request, default: str = "/dashboard"
) -> str:
    """Return ``target`` if it is safe to redirect to, otherwise ``default``.

    ``target`` comes straight from the ``next`` query parameter, so it is
    untrusted input. Redirecting to it blindly turns the login and refresh
    endpoints into an open redirect (``?next=https://evil.example`` or the
    scheme-relative ``?next=//evil.example``). Accepted targets are:

    * references without a scheme or host (e.g. ``/dashboard``), and
    * absolute ``http``/``https`` URLs whose host matches the current request.

    Args:
        target: Requested redirect destination.
        request: Current request; its ``url.netloc`` is the only allowed host.
        default: Destination used when ``target`` is empty or rejected.

    Returns:
        ``target`` unchanged when accepted, else ``default``.
    """
    from urllib.parse import urlsplit

    if not target:
        return default
    # Browsers strip or reinterpret control characters inside URLs
    # ("/\t/evil.example" can become "//evil.example"); reject them outright.
    if any(ord(char) < 0x20 or ord(char) == 0x7F for char in target):
        return default
    try:
        # Browsers treat "\" like "/", so "/\evil.example" is scheme-relative.
        parts = urlsplit(target.replace("\\", "/"))
    except ValueError:
        # e.g. a malformed IPv6 host such as "//[::1".
        return default
    if parts.scheme not in ("", "http", "https"):
        return default
    if parts.scheme and not parts.netloc:
        # e.g. "https:evil.example", which browsers may resolve to another host.
        return default
    if parts.netloc and parts.netloc != request.url.netloc:
        return default
    return target


def _set_auth_cookies(
    response: Response, access_token: str, refresh_token: str
) -> None:
    """Store the access and refresh tokens as HTTP-only, SameSite=strict cookies."""
    for cookie_name, token in (
        ("access_token", access_token),
        ("refresh_token", refresh_token),
    ):
        response.set_cookie(
            cookie_name,
            value=token,
            httponly=True,
            # NOTE(review): secure=False lets these cookies travel over plain
            # HTTP; confirm this is intended and not a development default.
            secure=False,
            samesite="strict",
        )


def _clear_auth_cookies(response: Response) -> None:
    """Delete both token cookies, using the same attributes they are set with."""
    for cookie_name in ("refresh_token", "access_token"):
        response.delete_cookie(
            cookie_name, httponly=True, secure=False, samesite="strict"
        )


def create_router(dependencies: FrontendDependencies) -> APIRouter:
    """Create the router serving the login, token-refresh and logout views.

    Args:
        dependencies: Frontend dependency container. This router reads its
            ``templates`` and ``settings`` attributes and uses its
            ``get_login_status``, ``get_async_session``, ``get_admin_backend``
            and ``get_refresh_claims`` dependency callables.

    Returns:
        An ``APIRouter`` with ``GET /login``, ``POST /login``, ``GET /refresh``
        and ``GET /logout`` routes.
    """
    app = APIRouter()
    templates = dependencies.templates

    def render_login_page(
        request: Request, login_error: LoginError | None
    ) -> Response:
        """Render ``login.html`` with the optional error and the OIDC state."""
        oidc_login = OidcLogin()
        if dependencies.settings.oidc:
            oidc_login.enabled = True
            oidc_login.provider_name = dependencies.settings.oidc.name
        return templates.TemplateResponse(
            request,
            "login.html",
            {
                "page_title": "Login",
                "page_description": "Login page.",
                "login_error": login_error,
                # Always supplied so every render of the template sees the same
                # context keys (the POST error branch used to omit this).
                "oidc": oidc_login,
            },
        )

    @app.get("/login")
    async def get_login(
        request: Request,
        login_status: Annotated[bool, Depends(dependencies.get_login_status)],
        error_title: str | None = None,
        error_message: str | None = None,
    ):
        """Show the login page, or redirect to the dashboard if logged in.

        The failed-login redirect built in ``login_user`` passes
        ``error_title`` and ``error_message`` as query parameters. An error is
        shown only when both are present.
        """
        if login_status:
            return RedirectResponse("/dashboard")
        login_error: LoginError | None = None
        if error_title and error_message:
            LOG.debug("Got an error here: %s %s", error_title, error_message)
            login_error = LoginError(title=error_title, message=error_message)
        else:
            LOG.debug("Got no errors")
        return render_login_page(request, login_error)

    @app.post("/login")
    async def login_user(
        request: Request,
        response: Response,
        session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
        admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
        form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
        next_url: Annotated[str, Query(alias="next")] = "/dashboard",
        error_title: str | None = None,
        error_message: str | None = None,
    ):
        """Authenticate the submitted credentials and set the token cookies.

        If both ``error_title`` and ``error_message`` are given, the login page
        is re-rendered with that error and no authentication is attempted.

        On success, responds with a 302 redirect to ``next`` (when it passes
        ``_safe_redirect_target``, otherwise ``/dashboard``), sets the access
        and refresh token cookies, and writes a ``LOGIN`` audit entry.

        Raises:
            RedirectException: on bad credentials, back to ``/login`` with an
                error message, after a ``DENY`` audit entry is written.
        """
        if error_title and error_message:
            return render_login_page(
                request, LoginError(title=error_title, message=error_message)
            )

        user = await authenticate_user_async(
            session, form_data.username, form_data.password
        )
        if not user:
            await audit_login_failure(admin, form_data.username, request)
            raise RedirectException(
                to=URL("/login").include_query_params(
                    error_title="Login Error",
                    error_message="Invalid username or password",
                )
            )

        token_data: dict[str, str] = {"sub": user.username}
        access_token = create_access_token(dependencies.settings, data=token_data)
        refresh_token = create_refresh_token(dependencies.settings, data=token_data)

        # `next` is attacker-controllable: only follow it when it stays on-site.
        redirect = RedirectResponse(
            url=_safe_redirect_target(next_url, request),
            status_code=status.HTTP_302_FOUND,
        )
        _set_auth_cookies(redirect, access_token, refresh_token)

        origin = request.client.host if request.client else "UNKNOWN"
        await admin.write_audit_message(
            operation=Operation.LOGIN,
            message="Logged in to admin frontend",
            origin=origin,
            username=form_data.username,
        )
        return redirect

    @app.get("/refresh")
    async def get_refresh_token(
        request: Request,
        response: Response,
        refresh_claims: Annotated[
            IdentityClaims, Depends(dependencies.get_refresh_claims)
        ],
        next_url: Annotated[str, Query(alias="next")],
    ):
        """Issue a new access token, and a new refresh token with it.

        Both tokens are created for the subject and provider of the validated
        refresh claims. The response is a 302 redirect to ``next`` (when it
        passes ``_safe_redirect_target``, otherwise ``/dashboard``).
        """
        token_data: dict[str, str] = {"sub": refresh_claims.sub}
        access_token = create_access_token(
            dependencies.settings, data=token_data, provider=refresh_claims.provider
        )
        refresh_token = create_refresh_token(
            dependencies.settings, data=token_data, provider=refresh_claims.provider
        )
        redirect = RedirectResponse(
            url=_safe_redirect_target(next_url, request),
            status_code=status.HTTP_302_FOUND,
        )
        _set_auth_cookies(redirect, access_token, refresh_token)
        return redirect

    @app.get("/logout")
    async def logout(
        response: Response,
    ):
        """Log out by deleting both token cookies and redirecting to ``/login``."""
        redirect = RedirectResponse(url="/login", status_code=status.HTTP_302_FOUND)
        _clear_auth_cookies(redirect)
        return redirect

    return app
|