Implement oidc login
This commit is contained in:
@ -54,6 +54,7 @@ def run_migrations_offline() -> None:
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
render_as_batch=True,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
@ -74,7 +75,9 @@ def run_migrations_online() -> None:
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata, render_as_batch=True
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
@ -1,47 +0,0 @@
|
||||
"""Create initial migration
|
||||
|
||||
Revision ID: 2a5a599271aa
|
||||
Revises:
|
||||
Create Date: 2025-05-18 22:19:03.739902
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2a5a599271aa'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('password_db',
|
||||
sa.Column('id', sa.INTEGER(), nullable=False),
|
||||
sa.Column('encrypted_password', sa.String(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('user',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('username', sa.String(), nullable=False),
|
||||
sa.Column('hashed_password', sa.String(), nullable=False),
|
||||
sa.Column('disabled', sa.BOOLEAN(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('user')
|
||||
op.drop_table('password_db')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,41 @@
|
||||
"""Make passwords non-optional
|
||||
|
||||
Revision ID: 6c148590471f
|
||||
Revises: 73d5569a8a26
|
||||
Create Date: 2025-05-30 10:15:03.665371
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6c148590471f"
|
||||
down_revision: Union[str, None] = "73d5569a8a26"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"hashed_password", existing_type=sa.VARCHAR(), nullable=True
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("user", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"hashed_password", existing_type=sa.VARCHAR(), nullable=False
|
||||
)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,82 @@
|
||||
"""Create initial migration
|
||||
|
||||
Revision ID: 73d5569a8a26
|
||||
Revises:
|
||||
Create Date: 2025-05-30 10:02:05.130137
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "73d5569a8a26"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"password_db",
|
||||
sa.Column("id", sa.INTEGER(), nullable=False),
|
||||
sa.Column("encrypted_password", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("(CURRENT_TIMESTAMP)"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("(CURRENT_TIMESTAMP)"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"user",
|
||||
sa.Column("id", sa.Uuid(), nullable=False),
|
||||
sa.Column("email", sa.String(), nullable=False),
|
||||
sa.Column("full_name", sa.String(), nullable=True),
|
||||
sa.Column("disabled", sa.BOOLEAN(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("(CURRENT_TIMESTAMP)"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("username", sa.String(), nullable=True),
|
||||
sa.Column("hashed_password", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("(CURRENT_TIMESTAMP)"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("oidc_sub", sa.String(), nullable=True),
|
||||
sa.Column("oidc_issuer", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"provider", sa.Enum("LOCAL", "OIDC", name="authprovider"), nullable=False
|
||||
),
|
||||
sa.Column("last_login", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("email", name="uq_user_email"),
|
||||
sa.UniqueConstraint("oidc_sub", name="uq_user_oidc_sub"),
|
||||
sa.UniqueConstraint("username", name="uq_user_username"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("user")
|
||||
op.drop_table("password_db")
|
||||
# ### end Alembic commands ###
|
||||
@ -8,13 +8,17 @@ authors = [
|
||||
]
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"alembic>=1.15.2",
|
||||
"authlib>=1.6.0",
|
||||
"bcrypt>=4.3.0",
|
||||
"click>=8.1.8",
|
||||
"cryptography>=44.0.2",
|
||||
"fastapi[standard]>=0.115.12",
|
||||
"httpx>=0.28.1",
|
||||
"itsdangerous>=2.2.0",
|
||||
"jinja2>=3.1.6",
|
||||
"jinja2-fragments>=1.9.0",
|
||||
"joserfc>=1.1.0",
|
||||
"pydantic>=2.10.6",
|
||||
"pyjwt>=2.10.1",
|
||||
"pykeepass>=4.1.1.post1",
|
||||
|
||||
@ -12,6 +12,7 @@ from sshecret_admin.core.dependencies import AdminDependencies
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
"""Create auth router."""
|
||||
app = APIRouter()
|
||||
@ -35,5 +36,4 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
|
||||
return app
|
||||
|
||||
@ -25,7 +25,7 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
|
||||
@app.get("/clients/")
|
||||
async def get_clients(
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)]
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> list[Client]:
|
||||
"""Get clients."""
|
||||
clients = await admin.get_clients()
|
||||
|
||||
@ -23,7 +23,7 @@ def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
|
||||
@app.get("/secrets/")
|
||||
async def get_secret_names(
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)]
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> list[Secret]:
|
||||
"""Get Secret Names."""
|
||||
return await admin.get_secrets()
|
||||
|
||||
@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
from sshecret_admin.core.dependencies import BaseDependencies, AdminDependencies
|
||||
from sshecret_admin.auth import PasswordDB, User, decode_token
|
||||
from sshecret_admin.auth.constants import LOCAL_ISSUER
|
||||
|
||||
from .endpoints import auth, clients, secrets
|
||||
|
||||
@ -41,9 +42,17 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
if not token_data:
|
||||
raise credentials_exception
|
||||
|
||||
user = session.scalars(
|
||||
select(User).where(User.username == token_data.username)
|
||||
).first()
|
||||
if token_data.provider == LOCAL_ISSUER:
|
||||
user = session.scalars(
|
||||
select(User).where(User.username == token_data.sub)
|
||||
).first()
|
||||
else:
|
||||
user = session.scalars(
|
||||
select(User)
|
||||
.where(User.oidc_issuer == token_data.provider)
|
||||
.where(User.oidc_sub == token_data.sub)
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
raise credentials_exception
|
||||
return user
|
||||
@ -57,10 +66,12 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
return current_user
|
||||
|
||||
async def get_admin_backend(
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)]
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
):
|
||||
"""Get admin backend API."""
|
||||
password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||
password_db = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
).first()
|
||||
if not password_db:
|
||||
raise HTTPException(
|
||||
500, detail="Error: The password manager has not yet been set up."
|
||||
|
||||
@ -9,10 +9,12 @@ from .authentication import (
|
||||
decode_token,
|
||||
verify_password,
|
||||
)
|
||||
from .models import User, Token, PasswordDB
|
||||
from .models import User, Token, PasswordDB, IdentityClaims, LocalUserInfo
|
||||
|
||||
|
||||
__all__ = [
|
||||
"IdentityClaims",
|
||||
"LocalUserInfo",
|
||||
"PasswordDB",
|
||||
"Token",
|
||||
"User",
|
||||
|
||||
@ -5,21 +5,26 @@ from datetime import datetime, timezone, timedelta
|
||||
from typing import cast, Any
|
||||
|
||||
import bcrypt
|
||||
import jwt
|
||||
from joserfc import jwt
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from joserfc.jwk import OctKey
|
||||
from joserfc.errors import JoseError
|
||||
|
||||
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
from .models import User, TokenData
|
||||
from .models import AuthProvider, LocalUserInfo, User, IdentityClaims
|
||||
from .exceptions import AuthenticationFailedError
|
||||
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
# I know refresh tokens are supposed to be long-lived, but 6 hours for a
|
||||
# sensitive application, seems reasonable.
|
||||
REFRESH_TOKEN_EXPIRE_HOURS = 6
|
||||
from .constants import (
|
||||
JWT_ALGORITHM,
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
REFRESH_TOKEN_EXPIRE_HOURS,
|
||||
LOCAL_ISSUER,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -28,12 +33,14 @@ def create_token(
|
||||
settings: AdminServerSettings,
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta,
|
||||
provider: str,
|
||||
) -> str:
|
||||
"""Create access token."""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=JWT_ALGORITHM)
|
||||
to_encode.update({"exp": expire, "iss": provider})
|
||||
key = OctKey.import_key(settings.secret_key)
|
||||
encoded_jwt = jwt.encode({"alg": JWT_ALGORITHM}, to_encode, key)
|
||||
return str(encoded_jwt)
|
||||
|
||||
|
||||
@ -41,22 +48,24 @@ def create_access_token(
|
||||
settings: AdminServerSettings,
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta | None = None,
|
||||
provider: str = LOCAL_ISSUER,
|
||||
) -> str:
|
||||
"""Create access token."""
|
||||
if not expires_delta:
|
||||
expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
return create_token(settings, data, expires_delta)
|
||||
return create_token(settings, data, expires_delta, provider)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
settings: AdminServerSettings,
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta | None = None,
|
||||
provider: str = LOCAL_ISSUER,
|
||||
) -> str:
|
||||
"""Create access token."""
|
||||
if not expires_delta:
|
||||
expires_delta = timedelta(hours=REFRESH_TOKEN_EXPIRE_HOURS)
|
||||
return create_token(settings, data, expires_delta)
|
||||
return create_token(settings, data, expires_delta, provider)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
@ -73,9 +82,13 @@ def check_password(plain_password: str, hashed_password: str) -> None:
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
|
||||
async def authenticate_user_async(session: AsyncSession, username: str, password: str) -> User | None:
|
||||
async def authenticate_user_async(
|
||||
session: AsyncSession, username: str, password: str
|
||||
) -> User | None:
|
||||
"""Authenticate user async."""
|
||||
user = (await session.scalars(select(User).where(User.username == username))).first()
|
||||
user = (
|
||||
await session.scalars(select(User).where(User.username == username))
|
||||
).first()
|
||||
if not user:
|
||||
return None
|
||||
if not verify_password(password, user.hashed_password):
|
||||
@ -83,6 +96,44 @@ async def authenticate_user_async(session: AsyncSession, username: str, password
|
||||
return user
|
||||
|
||||
|
||||
async def handle_oidc_claim(session: AsyncSession, claim: IdentityClaims) -> User:
|
||||
"""Handle OIDC claim.
|
||||
|
||||
Either return an existing user, or create a new one.
|
||||
"""
|
||||
LOG.debug("Looking up OIDC token claim %r", claim)
|
||||
if claim.provider == LOCAL_ISSUER:
|
||||
raise ValueError("IdentityClaims do not originate from OIDC.")
|
||||
query = (
|
||||
select(User)
|
||||
.where(User.oidc_sub == claim.sub)
|
||||
.where(User.oidc_issuer == claim.provider)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
if user := result.scalar_one_or_none():
|
||||
LOG.debug("Found existing user %s", user.id)
|
||||
return user
|
||||
|
||||
LOG.debug("User not found in local database. Creating a new user")
|
||||
user = User(
|
||||
username=claim.username,
|
||||
email=claim.email,
|
||||
disabled=False,
|
||||
oidc_sub=claim.sub,
|
||||
oidc_issuer=claim.provider,
|
||||
provider=AuthProvider.OIDC,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
query = (
|
||||
select(User)
|
||||
.where(User.oidc_sub == claim.sub)
|
||||
.where(User.oidc_issuer == claim.provider)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
def authenticate_user(session: Session, username: str, password: str) -> User | None:
|
||||
"""Authenticate user."""
|
||||
user = session.scalars(select(User).where(User.username == username)).first()
|
||||
@ -93,22 +144,48 @@ def authenticate_user(session: Session, username: str, password: str) -> User |
|
||||
return user
|
||||
|
||||
|
||||
def decode_token(settings: AdminServerSettings, token: str) -> TokenData | None:
|
||||
def decode_token(settings: AdminServerSettings, token: str) -> IdentityClaims | None:
|
||||
"""Decode token."""
|
||||
key = OctKey.import_key(settings.secret_key)
|
||||
try:
|
||||
decoded = jwt.decode(token, key)
|
||||
claims_requests = jwt.JWTClaimsRegistry(
|
||||
exp={"essential": True},
|
||||
sub={"essential": True},
|
||||
)
|
||||
|
||||
claims_requests.validate(decoded.claims)
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
|
||||
username = cast("str | None", payload.get("sub"))
|
||||
if not username:
|
||||
sub = cast("str | None", payload.claims.get("sub"))
|
||||
if not sub:
|
||||
return None
|
||||
|
||||
token_data = TokenData(username=username)
|
||||
return token_data
|
||||
except jwt.InvalidTokenError as e:
|
||||
issuer = payload.claims.get("iss") or LOCAL_ISSUER
|
||||
|
||||
identity_claims = IdentityClaims(sub=sub, provider=issuer)
|
||||
if issuer == LOCAL_ISSUER:
|
||||
identity_claims.username = sub
|
||||
|
||||
return identity_claims
|
||||
|
||||
except JoseError as e:
|
||||
LOG.debug("Could not decode token: %s", e, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash password."""
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_password = bcrypt.hashpw(password.encode(), salt)
|
||||
return hashed_password.decode()
|
||||
|
||||
|
||||
def generate_user_info(user: User) -> LocalUserInfo:
|
||||
"""Generate user info object from a user entry."""
|
||||
is_local = user.provider == AuthProvider.LOCAL
|
||||
if user.username:
|
||||
LOG.info("User has a username: %s", user.username)
|
||||
return LocalUserInfo(id=user.id, display_name=user.username, local=is_local)
|
||||
assert user.email is not None
|
||||
LOG.info("User has no username")
|
||||
return LocalUserInfo(id=user.id, display_name=user.email, local=is_local)
|
||||
|
||||
@ -0,0 +1,8 @@
|
||||
"""Constants."""
|
||||
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
# I know refresh tokens are supposed to be long-lived, but 6 hours for a
|
||||
# sensitive application, seems reasonable.
|
||||
REFRESH_TOKEN_EXPIRE_HOURS = 6
|
||||
LOCAL_ISSUER = "urn:sshecret:admin:auth"
|
||||
@ -1,4 +1,5 @@
|
||||
"""Authentication related exceptions."""
|
||||
|
||||
from typing import override
|
||||
|
||||
from .models import LoginError
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Models for authentication."""
|
||||
|
||||
import enum
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
import sqlalchemy as sa
|
||||
@ -15,6 +16,13 @@ ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
REFRESH_TOKEN_EXPIRE_HOURS = 6
|
||||
|
||||
|
||||
class AuthProvider(enum.Enum):
|
||||
"""Auth providers."""
|
||||
|
||||
LOCAL = "local"
|
||||
OIDC = "oidc"
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
@ -23,17 +31,43 @@ class User(Base):
|
||||
"""Users."""
|
||||
|
||||
__tablename__: str = "user"
|
||||
__table_args__: tuple[sa.UniqueConstraint, ...] = (
|
||||
sa.UniqueConstraint("username", name="uq_user_username"),
|
||||
sa.UniqueConstraint("email", name="uq_user_email"),
|
||||
sa.UniqueConstraint("oidc_sub", name="uq_user_oidc_sub"),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
sa.Uuid(as_uuid=True), primary_key=True, default=uuid.uuid4
|
||||
)
|
||||
username: Mapped[str] = mapped_column(sa.String)
|
||||
hashed_password: Mapped[str] = mapped_column(sa.String)
|
||||
|
||||
email: Mapped[str] = mapped_column(sa.String, nullable=False)
|
||||
full_name: Mapped[str] = mapped_column(sa.String, nullable=True)
|
||||
disabled: Mapped[bool] = mapped_column(sa.BOOLEAN, default=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False
|
||||
)
|
||||
|
||||
username: Mapped[str] = mapped_column(sa.String, nullable=True)
|
||||
hashed_password: Mapped[str] = mapped_column(sa.String, nullable=True)
|
||||
|
||||
updated_at: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
)
|
||||
|
||||
oidc_sub: Mapped[str] = mapped_column(sa.String, nullable=True)
|
||||
oidc_issuer: Mapped[str] = mapped_column(sa.String, nullable=True)
|
||||
|
||||
provider: Mapped[AuthProvider] = mapped_column(
|
||||
sa.Enum(AuthProvider), nullable=False
|
||||
)
|
||||
|
||||
last_login: Mapped[datetime | None] = mapped_column(
|
||||
sa.DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class PasswordDB(Base):
|
||||
"""Password database."""
|
||||
@ -54,6 +88,15 @@ class PasswordDB(Base):
|
||||
)
|
||||
|
||||
|
||||
class IdentityClaims(BaseModel):
|
||||
"""Normalized identity claim model."""
|
||||
|
||||
sub: str
|
||||
email: str | None = None
|
||||
username: str | None = None
|
||||
provider: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
"""Token data."""
|
||||
|
||||
@ -74,6 +117,14 @@ class LoginError(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class LocalUserInfo(BaseModel):
|
||||
"""Model used to present a user in the web ui."""
|
||||
|
||||
id: uuid.UUID
|
||||
display_name: str
|
||||
local: bool
|
||||
|
||||
|
||||
def init_db(engine: sa.Engine) -> None:
|
||||
"""Create database."""
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
78
packages/sshecret-admin/src/sshecret_admin/auth/oidc.py
Normal file
78
packages/sshecret-admin/src/sshecret_admin/auth/oidc.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""OIDC Handler class."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
from typing import cast
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from fastapi import Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sshecret_admin.auth.exceptions import AuthenticationFailedError
|
||||
from sshecret_admin.auth.models import IdentityClaims
|
||||
from sshecret_admin.core.settings import OidcSettings
|
||||
from starlette.datastructures import URL
|
||||
|
||||
|
||||
class OIDCUserInfo(BaseModel):
|
||||
sub: str
|
||||
email: str | None
|
||||
preferred_username: str | None = None
|
||||
name: str | None = None
|
||||
picture: str | None = None
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdminOidc:
|
||||
"""Admin OIDC handler."""
|
||||
|
||||
def __init__(self, settings: OidcSettings) -> None:
|
||||
"""Initialize OIDC handler class."""
|
||||
self.settings: OidcSettings = settings
|
||||
self.provider_name: str = settings.name
|
||||
self.oauth: OAuth = OAuth()
|
||||
self.oauth.register(
|
||||
name=settings.name,
|
||||
server_metadata_url=settings.config_url,
|
||||
client_id=settings.client_id,
|
||||
client_secret=settings.client_secret,
|
||||
client_kwargs={"scope": "openid email profile"},
|
||||
)
|
||||
|
||||
@property
|
||||
def client(self) -> StarletteOAuth2App:
|
||||
"""Get client."""
|
||||
app = cast(
|
||||
StarletteOAuth2App | None, self.oauth.create_client(self.provider_name)
|
||||
)
|
||||
if app is None:
|
||||
raise RuntimeError("Unexpected error when creating Oauth2 client.")
|
||||
return app
|
||||
|
||||
async def start_auth(self, request: Request, redirect_url: URL) -> RedirectResponse:
|
||||
"""Start authentication flow."""
|
||||
response = cast(
|
||||
Awaitable[RedirectResponse],
|
||||
self.client.authorize_redirect(request, redirect_url),
|
||||
)
|
||||
return await response
|
||||
|
||||
async def handle_auth_callback(self, request: Request) -> IdentityClaims:
|
||||
"""Handle auth callback."""
|
||||
try:
|
||||
token = await self.client.authorize_access_token(request)
|
||||
except OAuthError as error:
|
||||
LOG.error("Error from OIDC: %s", error, exc_info=True)
|
||||
raise AuthenticationFailedError(str(error))
|
||||
LOG.info("Token: %r", token)
|
||||
claims = await self.client.parse_id_token(token, None)
|
||||
user_info = OIDCUserInfo.model_validate(claims)
|
||||
return IdentityClaims(
|
||||
sub=user_info.sub,
|
||||
email=user_info.email,
|
||||
provider=self.provider_name,
|
||||
username=user_info.preferred_username,
|
||||
)
|
||||
@ -14,6 +14,8 @@ from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
|
||||
from sshecret_admin import api, frontend
|
||||
from sshecret_admin.auth.models import PasswordDB, init_db
|
||||
from sshecret_admin.core.db import setup_database
|
||||
@ -28,9 +30,7 @@ LOG = logging.getLogger(__name__)
|
||||
# dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
||||
def setup_frontend(
|
||||
app: FastAPI, dependencies: BaseDependencies
|
||||
) -> None:
|
||||
def setup_frontend(app: FastAPI, dependencies: BaseDependencies) -> None:
|
||||
"""Setup frontend."""
|
||||
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
|
||||
static_path = script_path.parent / "static"
|
||||
@ -51,15 +51,21 @@ def create_admin_app(
|
||||
settings=settings, regenerate=False
|
||||
)
|
||||
with Session(engine) as session:
|
||||
existing_password = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||
existing_password = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
).first()
|
||||
|
||||
if not encr_master_password:
|
||||
if existing_password:
|
||||
LOG.info("Master password already defined.")
|
||||
return
|
||||
# Looks like we have to regenerate it
|
||||
LOG.warning("Master password was set, but not saved to the database. Regenerating it.")
|
||||
encr_master_password = setup_master_password(settings=settings, regenerate=True)
|
||||
LOG.warning(
|
||||
"Master password was set, but not saved to the database. Regenerating it."
|
||||
)
|
||||
encr_master_password = setup_master_password(
|
||||
settings=settings, regenerate=True
|
||||
)
|
||||
|
||||
assert encr_master_password is not None
|
||||
|
||||
@ -76,6 +82,7 @@ def create_admin_app(
|
||||
yield
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(
|
||||
@ -95,7 +102,6 @@ def create_admin_app(
|
||||
return response
|
||||
return RedirectResponse(url=str(exc.to))
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def get_health() -> JSONResponse:
|
||||
"""Provide simple health check."""
|
||||
@ -105,7 +111,6 @@ def create_admin_app(
|
||||
|
||||
dependencies = BaseDependencies(settings, get_db_session)
|
||||
|
||||
|
||||
app.include_router(api.create_api_router(dependencies))
|
||||
if with_frontend:
|
||||
setup_frontend(app, dependencies)
|
||||
|
||||
@ -12,7 +12,7 @@ from pydantic import ValidationError
|
||||
from sqlalchemy import select, create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
from sshecret_admin.auth.authentication import hash_password
|
||||
from sshecret_admin.auth.models import PasswordDB, User, init_db
|
||||
from sshecret_admin.auth.models import AuthProvider, PasswordDB, User, init_db
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
|
||||
@ -28,10 +28,15 @@ LOG.addHandler(handler)
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def create_user(session: Session, username: str, password: str) -> None:
|
||||
def create_user(session: Session, username: str, email: str, password: str) -> None:
|
||||
"""Create a user."""
|
||||
hashed_password = hash_password(password)
|
||||
user = User(username=username, hashed_password=hashed_password)
|
||||
user = User(
|
||||
username=username,
|
||||
email=email,
|
||||
hashed_password=hashed_password,
|
||||
provider=AuthProvider.LOCAL,
|
||||
)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
@ -58,15 +63,18 @@ def cli(ctx: click.Context, debug: bool) -> None:
|
||||
|
||||
@cli.command("adduser")
|
||||
@click.argument("username")
|
||||
@click.argument("email")
|
||||
@click.password_option()
|
||||
@click.pass_context
|
||||
def cli_create_user(ctx: click.Context, username: str, password: str) -> None:
|
||||
def cli_create_user(
|
||||
ctx: click.Context, username: str, email: str, password: str
|
||||
) -> None:
|
||||
"""Create user."""
|
||||
settings = cast(AdminServerSettings, ctx.obj)
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
create_user(session, username, password)
|
||||
create_user(session, username, email, password)
|
||||
|
||||
click.echo("User created.")
|
||||
|
||||
@ -143,7 +151,9 @@ def cli_repl(ctx: click.Context) -> None:
|
||||
engine = create_engine(settings.admin_db)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
password_db = session.scalars(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||
password_db = session.scalars(
|
||||
select(PasswordDB).where(PasswordDB.id == 1)
|
||||
).first()
|
||||
|
||||
if not password_db:
|
||||
raise click.ClickException(
|
||||
|
||||
@ -8,7 +8,13 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy import create_engine, Engine
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncConnection, create_async_engine, AsyncEngine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncConnection,
|
||||
create_async_engine,
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
)
|
||||
|
||||
|
||||
def setup_database(
|
||||
@ -16,7 +22,7 @@ def setup_database(
|
||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||
"""Setup database."""
|
||||
|
||||
engine = create_engine(db_url, echo=False, future=True)
|
||||
engine = create_engine(db_url, echo=True, future=True)
|
||||
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
"""Get DB Session."""
|
||||
@ -29,7 +35,11 @@ def setup_database(
|
||||
class DatabaseSessionManager:
|
||||
def __init__(self, host: URL | str, **engine_kwargs: str):
|
||||
self._engine: AsyncEngine | None = create_async_engine(host, **engine_kwargs)
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = async_sessionmaker(autocommit=False, bind=self._engine, expire_on_commit=False)
|
||||
self._sessionmaker: async_sessionmaker[AsyncSession] | None = (
|
||||
async_sessionmaker(
|
||||
autocommit=False, bind=self._engine, expire_on_commit=False
|
||||
)
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
if self._engine is None:
|
||||
|
||||
@ -4,7 +4,6 @@ from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.services import AdminBackend
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Main server app."""
|
||||
|
||||
import sys
|
||||
import click
|
||||
from pydantic import ValidationError
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""SSH Server settings."""
|
||||
|
||||
from pathlib import Path
|
||||
from pydantic import AnyHttpUrl, Field
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from sqlalchemy import URL
|
||||
|
||||
@ -11,11 +11,23 @@ DEFAULT_LISTEN_PORT = 8822
|
||||
DEFAULT_DATABASE = "sshecret_admin.db"
|
||||
|
||||
|
||||
class OidcSettings(BaseModel):
|
||||
"""OIDC settings."""
|
||||
|
||||
name: str
|
||||
config_url: str
|
||||
client_id: str
|
||||
client_secret: str
|
||||
|
||||
|
||||
class AdminServerSettings(BaseSettings):
|
||||
"""Server Settings."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".admin.env", env_prefix="sshecret_admin_", secrets_dir="/var/run"
|
||||
env_file=".admin.env",
|
||||
env_prefix="sshecret_admin_",
|
||||
secrets_dir="/var/run",
|
||||
env_nested_delimiter="__",
|
||||
)
|
||||
|
||||
backend_url: AnyHttpUrl = Field(alias="sshecret_backend_url")
|
||||
@ -26,6 +38,7 @@ class AdminServerSettings(BaseSettings):
|
||||
database: str = Field(default=DEFAULT_DATABASE)
|
||||
debug: bool = False
|
||||
password_manager_directory: Path | None = None
|
||||
oidc: OidcSettings | None = None
|
||||
|
||||
@property
|
||||
def admin_db(self) -> URL:
|
||||
|
||||
@ -11,11 +11,14 @@ from fastapi import Request
|
||||
|
||||
from sshecret_admin.core.dependencies import AdminDep, BaseDependencies
|
||||
|
||||
from sshecret_admin.auth.models import User
|
||||
from sshecret_admin.auth.models import IdentityClaims, LocalUserInfo, User
|
||||
|
||||
UserTokenDep = Callable[[Request, Session], Awaitable[User]]
|
||||
UserLoginDep = Callable[[Request, Session], Awaitable[bool]]
|
||||
LoginStatusDep = Callable[[Request], Awaitable[bool]]
|
||||
AsyncSessionDep = Callable[[], AsyncGenerator[AsyncSession, None]]
|
||||
UserInfoDep = Callable[[Request, AsyncSession], Awaitable[LocalUserInfo]]
|
||||
RefreshTokenDep = Callable[[Request], IdentityClaims]
|
||||
LoginGuardDep = Callable[[Request], Awaitable[None]]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -24,10 +27,11 @@ class FrontendDependencies(BaseDependencies):
|
||||
|
||||
get_admin_backend: AdminDep
|
||||
templates: Jinja2Blocks
|
||||
get_user_from_access_token: UserTokenDep
|
||||
get_user_from_refresh_token: UserTokenDep
|
||||
get_login_status: UserLoginDep
|
||||
get_refresh_claims: RefreshTokenDep
|
||||
get_login_status: LoginStatusDep
|
||||
get_user_info: UserInfoDep
|
||||
get_async_session: AsyncSessionDep
|
||||
require_login: LoginGuardDep
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@ -35,10 +39,11 @@ class FrontendDependencies(BaseDependencies):
|
||||
deps: BaseDependencies,
|
||||
get_admin_backend: AdminDep,
|
||||
templates: Jinja2Blocks,
|
||||
get_user_from_access_token: UserTokenDep,
|
||||
get_user_from_refresh_token: UserTokenDep,
|
||||
get_login_status: UserLoginDep,
|
||||
get_async_session: AsyncSessionDep
|
||||
get_refresh_claims: RefreshTokenDep,
|
||||
get_login_status: LoginStatusDep,
|
||||
get_user_info: UserInfoDep,
|
||||
get_async_session: AsyncSessionDep,
|
||||
require_login: LoginGuardDep,
|
||||
) -> Self:
|
||||
"""Create from base dependencies."""
|
||||
return cls(
|
||||
@ -46,8 +51,9 @@ class FrontendDependencies(BaseDependencies):
|
||||
get_db_session=deps.get_db_session,
|
||||
get_admin_backend=get_admin_backend,
|
||||
templates=templates,
|
||||
get_user_from_access_token=get_user_from_access_token,
|
||||
get_user_from_refresh_token=get_user_from_refresh_token,
|
||||
get_refresh_claims=get_refresh_claims,
|
||||
get_login_status=get_login_status,
|
||||
get_user_info=get_user_info,
|
||||
get_async_session=get_async_session,
|
||||
require_login=require_login,
|
||||
)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Frontend exceptions."""
|
||||
|
||||
from starlette.datastructures import URL
|
||||
|
||||
|
||||
|
||||
@ -12,18 +12,23 @@ from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from sshecret_admin.auth.authentication import generate_user_info
|
||||
from sshecret_admin.auth.models import AuthProvider, IdentityClaims, LocalUserInfo
|
||||
from starlette.datastructures import URL
|
||||
|
||||
|
||||
from sshecret_admin.auth import PasswordDB, User, decode_token
|
||||
from sshecret_admin.auth.constants import LOCAL_ISSUER
|
||||
|
||||
from sshecret_admin.core.dependencies import BaseDependencies
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
from sshecret_admin.core.db import DatabaseSessionManager
|
||||
|
||||
from .dependencies import FrontendDependencies
|
||||
from .exceptions import RedirectException
|
||||
from .views import audit, auth, clients, index, secrets
|
||||
from .views import audit, auth, clients, index, secrets, oidc_auth
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -45,7 +50,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
templates = Jinja2Blocks(directory=template_path)
|
||||
|
||||
async def get_admin_backend(
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)]
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
):
|
||||
"""Get admin backend API."""
|
||||
password_db = session.scalars(
|
||||
@ -58,66 +63,50 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
||||
yield admin
|
||||
|
||||
async def get_user_from_token(
|
||||
token: str,
|
||||
session: Session,
|
||||
) -> User | None:
|
||||
"""Get user from a token."""
|
||||
token_data = decode_token(dependencies.settings, token)
|
||||
if not token_data:
|
||||
return None
|
||||
user = session.scalars(
|
||||
select(User).where(User.username == token_data.username)
|
||||
).first()
|
||||
if not user or user.disabled:
|
||||
return None
|
||||
return user
|
||||
|
||||
async def get_user_from_refresh_token(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
) -> User:
|
||||
"""Get user from refresh token."""
|
||||
next = URL("/login").include_query_params(next=request.url.path)
|
||||
credentials_error = RedirectException(to=next)
|
||||
token = request.cookies.get("refresh_token")
|
||||
if not token:
|
||||
raise credentials_error
|
||||
|
||||
user = await get_user_from_token(token, session)
|
||||
if not user:
|
||||
raise credentials_error
|
||||
return user
|
||||
|
||||
async def get_user_from_access_token(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
) -> User:
|
||||
"""Get user from access token."""
|
||||
def get_identity_claims(request: Request) -> IdentityClaims:
|
||||
"""Get identity claim from session."""
|
||||
token = request.cookies.get("access_token")
|
||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||
credentials_error = RedirectException(to=next)
|
||||
if not token:
|
||||
raise credentials_error
|
||||
|
||||
user = await get_user_from_token(token, session)
|
||||
if not user:
|
||||
claims = decode_token(dependencies.settings, token)
|
||||
if not claims:
|
||||
raise credentials_error
|
||||
return user
|
||||
return claims
|
||||
|
||||
async def get_login_status(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
) -> bool:
|
||||
def refresh_identity_claims(request: Request) -> IdentityClaims:
|
||||
"""Get identity claim from session for refreshing the token."""
|
||||
token = request.cookies.get("refresh_token")
|
||||
next = URL("/login").include_query_params(next=request.url.path)
|
||||
credentials_error = RedirectException(to=next)
|
||||
if not token:
|
||||
raise credentials_error
|
||||
claims = decode_token(dependencies.settings, token)
|
||||
if not claims:
|
||||
raise credentials_error
|
||||
return claims
|
||||
|
||||
async def get_login_status(request: Request) -> bool:
|
||||
"""Get login status."""
|
||||
token = request.cookies.get("access_token")
|
||||
if not token:
|
||||
return False
|
||||
|
||||
user = await get_user_from_token(token, session)
|
||||
if not user:
|
||||
return False
|
||||
return True
|
||||
claims = decode_token(dependencies.settings, token)
|
||||
return claims is not None
|
||||
|
||||
async def require_login(request: Request) -> None:
|
||||
"""Enforce login requirement."""
|
||||
token = request.cookies.get("access_token")
|
||||
LOG.info("User has no cookie")
|
||||
if not token:
|
||||
url = URL("/login").include_query_params(next=request.url.path)
|
||||
raise RedirectException(to=url)
|
||||
is_logged_in = await get_login_status(request)
|
||||
if not is_logged_in:
|
||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||
raise RedirectException(to=next)
|
||||
|
||||
async def get_async_session():
|
||||
"""Get async session."""
|
||||
@ -125,14 +114,43 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
async with sessionmanager.session() as session:
|
||||
yield session
|
||||
|
||||
async def get_user_info(
|
||||
request: Request, session: Annotated[AsyncSession, Depends(get_async_session)]
|
||||
) -> LocalUserInfo:
|
||||
"""Get User information."""
|
||||
claims = get_identity_claims(request)
|
||||
if claims.provider == LOCAL_ISSUER:
|
||||
LOG.info("Local user, finding username %s", claims.sub)
|
||||
query = (
|
||||
select(User)
|
||||
.where(User.username == claims.sub)
|
||||
.where(User.provider == AuthProvider.LOCAL)
|
||||
)
|
||||
else:
|
||||
query = (
|
||||
select(User)
|
||||
.where(User.oidc_issuer == claims.provider)
|
||||
.where(User.oidc_sub == claims.sub)
|
||||
)
|
||||
|
||||
result = await session.scalars(query)
|
||||
if user := result.first():
|
||||
if user.disabled:
|
||||
raise RedirectException(to=URL("/logout"))
|
||||
return generate_user_info(user)
|
||||
|
||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||
raise RedirectException(to=next)
|
||||
|
||||
view_dependencies = FrontendDependencies.create(
|
||||
dependencies,
|
||||
get_admin_backend,
|
||||
templates,
|
||||
get_user_from_access_token,
|
||||
get_user_from_refresh_token,
|
||||
refresh_identity_claims,
|
||||
get_login_status,
|
||||
get_user_info,
|
||||
get_async_session,
|
||||
require_login,
|
||||
)
|
||||
|
||||
app.include_router(audit.create_router(view_dependencies))
|
||||
@ -140,5 +158,7 @@ def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
app.include_router(clients.create_router(view_dependencies))
|
||||
app.include_router(index.create_router(view_dependencies))
|
||||
app.include_router(secrets.create_router(view_dependencies))
|
||||
if dependencies.settings.oidc:
|
||||
app.include_router(oidc_auth.create_router(view_dependencies))
|
||||
|
||||
return app
|
||||
|
||||
@ -0,0 +1,38 @@
|
||||
<div
|
||||
id="drawer-create-client-default"
|
||||
class="fixed top-0 right-0 z-40 w-full h-screen max-w-xs p-4 overflow-y-auto transition-transform translate-x-full bg-white dark:bg-gray-800"
|
||||
tabindex="-1"
|
||||
aria-labelledby="drawer-label"
|
||||
aria-hidden="true"
|
||||
>
|
||||
<h5
|
||||
id="drawer-label"
|
||||
class="inline-flex items-center mb-6 text-sm font-semibold text-gray-500 uppercase dark:text-gray-400"
|
||||
>
|
||||
New Client
|
||||
</h5>
|
||||
<button
|
||||
type="button"
|
||||
data-drawer-dismiss="drawer-create-client-default"
|
||||
aria-controls="drawer-create-client-default"
|
||||
class="text-gray-400 bg-transparent hover:bg-gray-200 hover:text-gray-900 rounded-lg text-sm p-1.5 absolute top-2.5 right-2.5 inline-flex items-center dark:hover:bg-gray-600 dark:hover:text-white"
|
||||
>
|
||||
<svg
|
||||
aria-hidden="true"
|
||||
class="w-5 h-5"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M4.293 4.293a1 1 0 011.414 0L10 8.586l4.293-4.293a1 1 0 111.414 1.414L11.414 10l4.293 4.293a1 1 0 01-1.414 1.414L10 11.414l-4.293 4.293a1 1 0 01-1.414-1.414L8.586 10 4.293 5.707a1 1 0 010-1.414z"
|
||||
clip-rule="evenodd"
|
||||
></path>
|
||||
</svg>
|
||||
<span class="sr-only">Close menu</span>
|
||||
</button>
|
||||
<form hx-post="/clients/" hx-target="none">
|
||||
{% include '/clients/drawer_client_create_inner.html.j2' %}
|
||||
</form>
|
||||
</div>
|
||||
@ -0,0 +1,38 @@
|
||||
<div
|
||||
id="drawer-create-secret-default"
|
||||
class="fixed top-0 right-0 z-40 w-full h-screen max-w-xs p-4 overflow-y-auto transition-transform translate-x-full bg-white dark:bg-gray-800"
|
||||
tabindex="-1"
|
||||
aria-labelledby="drawer-label"
|
||||
aria-hidden="true"
|
||||
>
|
||||
<h5
|
||||
id="drawer-label"
|
||||
class="inline-flex items-center mb-6 text-sm font-semibold text-gray-500 uppercase dark:text-gray-400"
|
||||
>
|
||||
New Secret
|
||||
</h5>
|
||||
<button
|
||||
type="button"
|
||||
data-drawer-dismiss="drawer-create-secret-default"
|
||||
aria-controls="drawer-create-secret-default"
|
||||
class="text-gray-400 bg-transparent hover:bg-gray-200 hover:text-gray-900 rounded-lg text-sm p-1.5 absolute top-2.5 right-2.5 inline-flex items-center dark:hover:bg-gray-600 dark:hover:text-white"
|
||||
>
|
||||
<svg
|
||||
aria-hidden="true"
|
||||
class="w-5 h-5"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M4.293 4.293a1 1 0 011.414 0L10 8.586l4.293-4.293a1 1 0 111.414 1.414L11.414 10l4.293 4.293a1 1 0 01-1.414 1.414L10 11.414l-4.293 4.293a1 1 0 01-1.414-1.414L8.586 10 4.293 5.707a1 1 0 010-1.414z"
|
||||
clip-rule="evenodd"
|
||||
></path>
|
||||
</svg>
|
||||
<span class="sr-only">Close menu</span>
|
||||
</button>
|
||||
<form hx-post="/secrets/" hx-target="none">
|
||||
{% include '/secrets/drawer_secret_create_inner.html.j2' %}
|
||||
</form>
|
||||
</div>
|
||||
@ -64,7 +64,31 @@
|
||||
Sign In
|
||||
</button>
|
||||
</form>
|
||||
{% if oidc.enabled %}
|
||||
<div class="w-full items-center text-center my-4 flex">
|
||||
<div
|
||||
class="w-full h-[0.125rem] box-border bg-gray-200 dark:bg-gray-700"
|
||||
></div>
|
||||
<div
|
||||
class="px-4 text-lg text-sm font-medium text-gray-500 dark:text-gray-400"
|
||||
>
|
||||
Or
|
||||
</div>
|
||||
<div
|
||||
class="w-full h-[0.125rem] box-border bg-gray-200 dark:bg-gray-700"
|
||||
></div>
|
||||
</div>
|
||||
<div class="w-full text-center my-4">
|
||||
<a href="/oidc/login">
|
||||
<button
|
||||
class="w-full bg-white hover:bg-gray-100 text-gray-900 border border-gray-300 transition-colors font-medium py-2.5 rounded-lg dark:bg-gray-800 dark:text-gray-400 dark:border-gray-600 dark:hover:text-white dark:hover:bg-gray-700"
|
||||
>
|
||||
Sign in with {{ oidc.provider_name }}
|
||||
</button>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
|
||||
{% endblock %}
|
||||
{% endblock %}
|
||||
</div>
|
||||
|
||||
@ -9,7 +9,7 @@ from pydantic import BaseModel
|
||||
|
||||
from sshecret.backend import AuditFilter, Operation
|
||||
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.auth import LocalUserInfo
|
||||
from sshecret_admin.services import AdminBackend
|
||||
|
||||
from ..dependencies import FrontendDependencies
|
||||
@ -18,7 +18,6 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PagingInfo(BaseModel):
|
||||
|
||||
page: int
|
||||
limit: int
|
||||
total: int
|
||||
@ -48,7 +47,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
|
||||
async def resolve_audit_entries(
|
||||
request: Request,
|
||||
current_user: User,
|
||||
current_user: LocalUserInfo,
|
||||
admin: AdminBackend,
|
||||
page: int,
|
||||
filters: AuditFilter,
|
||||
@ -82,7 +81,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
{
|
||||
"page_title": "Audit",
|
||||
"entries": audit_log.results,
|
||||
"user": current_user.username,
|
||||
"user": current_user.display_name,
|
||||
"page_info": page_info,
|
||||
"operations": operations,
|
||||
},
|
||||
@ -91,7 +90,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
@app.get("/audit/")
|
||||
async def get_audit_entries(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||
current_user: Annotated[LocalUserInfo, Depends(dependencies.get_user_info)],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
filters: Annotated[AuditFilter, Depends()],
|
||||
) -> Response:
|
||||
@ -101,7 +100,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
@app.get("/audit/page/{page}")
|
||||
async def get_audit_entries_page(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||
current_user: Annotated[LocalUserInfo, Depends(dependencies.get_user_info)],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
filters: Annotated[AuditFilter, Depends()],
|
||||
page: int,
|
||||
|
||||
@ -13,7 +13,7 @@ from sshecret_admin.services import AdminBackend
|
||||
from starlette.datastructures import URL
|
||||
|
||||
from sshecret_admin.auth import (
|
||||
User,
|
||||
IdentityClaims,
|
||||
authenticate_user_async,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
@ -34,7 +34,16 @@ class LoginError(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
async def audit_login_failure(admin: AdminBackend, username: str, request: Request) -> None:
|
||||
class OidcLogin(BaseModel):
|
||||
"""Small container to hold OIDC info for the login box."""
|
||||
|
||||
enabled: bool = False
|
||||
provider_name: str | None = None
|
||||
|
||||
|
||||
async def audit_login_failure(
|
||||
admin: AdminBackend, username: str, request: Request
|
||||
) -> None:
|
||||
"""Write login failure to audit log."""
|
||||
origin: str | None = None
|
||||
if request.client:
|
||||
@ -65,7 +74,16 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
return RedirectResponse("/dashboard")
|
||||
login_error: LoginError | None = None
|
||||
if error_title and error_message:
|
||||
LOG.info("Got an error here: %s %s", error_title, error_message)
|
||||
login_error = LoginError(title=error_title, message=error_message)
|
||||
else:
|
||||
LOG.info("Got no errors")
|
||||
|
||||
oidc_login = OidcLogin()
|
||||
if dependencies.settings.oidc:
|
||||
oidc_login.enabled = True
|
||||
oidc_login.provider_name = dependencies.settings.oidc.name
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"login.html",
|
||||
@ -73,6 +91,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
"page_title": "Login",
|
||||
"page_description": "Login page.",
|
||||
"login_error": login_error,
|
||||
"oidc": oidc_login,
|
||||
},
|
||||
)
|
||||
|
||||
@ -100,7 +119,9 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
},
|
||||
)
|
||||
|
||||
user = await authenticate_user_async(session, form_data.username, form_data.password)
|
||||
user = await authenticate_user_async(
|
||||
session, form_data.username, form_data.password
|
||||
)
|
||||
login_failed = RedirectException(
|
||||
to=URL("/login").include_query_params(
|
||||
error_title="Login Error", error_message="Invalid username or password"
|
||||
@ -143,16 +164,22 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
@app.get("/refresh")
|
||||
async def get_refresh_token(
|
||||
response: Response,
|
||||
user: Annotated[User, Depends(dependencies.get_user_from_refresh_token)],
|
||||
refresh_claims: Annotated[
|
||||
IdentityClaims, Depends(dependencies.get_refresh_claims)
|
||||
],
|
||||
next: Annotated[str, Query()],
|
||||
):
|
||||
"""Refresh tokens.
|
||||
|
||||
We might as well refresh the long-lived one here.
|
||||
"""
|
||||
token_data: dict[str, str] = {"sub": user.username}
|
||||
access_token = create_access_token(dependencies.settings, data=token_data)
|
||||
refresh_token = create_refresh_token(dependencies.settings, data=token_data)
|
||||
token_data: dict[str, str] = {"sub": refresh_claims.sub}
|
||||
access_token = create_access_token(
|
||||
dependencies.settings, data=token_data, provider=refresh_claims.provider
|
||||
)
|
||||
refresh_token = create_refresh_token(
|
||||
dependencies.settings, data=token_data, provider=refresh_claims.provider
|
||||
)
|
||||
response = RedirectResponse(url=next, status_code=status.HTTP_302_FOUND)
|
||||
response.set_cookie(
|
||||
"access_token",
|
||||
@ -176,8 +203,12 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
):
|
||||
"""Log out user."""
|
||||
response = RedirectResponse(url="/login", status_code=status.HTTP_302_FOUND)
|
||||
response.delete_cookie("refresh_token", httponly=True, secure=False, samesite="strict")
|
||||
response.delete_cookie("access_token", httponly=True, secure=False, samesite="strict")
|
||||
response.delete_cookie(
|
||||
"refresh_token", httponly=True, secure=False, samesite="strict"
|
||||
)
|
||||
response.delete_cookie(
|
||||
"access_token", httponly=True, secure=False, samesite="strict"
|
||||
)
|
||||
return response
|
||||
|
||||
return app
|
||||
|
||||
@ -11,7 +11,7 @@ from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
|
||||
from sshecret.backend import ClientFilter
|
||||
from sshecret.backend.models import FilterType
|
||||
from sshecret.crypto import validate_public_key
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.auth import LocalUserInfo
|
||||
from sshecret_admin.services import AdminBackend
|
||||
|
||||
from ..dependencies import FrontendDependencies
|
||||
@ -20,7 +20,6 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClientUpdate(BaseModel):
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
description: str
|
||||
@ -29,7 +28,6 @@ class ClientUpdate(BaseModel):
|
||||
|
||||
|
||||
class ClientCreate(BaseModel):
|
||||
|
||||
name: str
|
||||
public_key: str
|
||||
description: str | None
|
||||
@ -39,13 +37,14 @@ class ClientCreate(BaseModel):
|
||||
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
"""Create clients router."""
|
||||
|
||||
app = APIRouter()
|
||||
app = APIRouter(dependencies=[Depends(dependencies.require_login)])
|
||||
|
||||
templates = dependencies.templates
|
||||
|
||||
@app.get("/clients")
|
||||
async def get_clients(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||
current_user: Annotated[LocalUserInfo, Depends(dependencies.get_user_info)],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> Response:
|
||||
"""Get clients."""
|
||||
@ -57,16 +56,13 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
{
|
||||
"page_title": "Clients",
|
||||
"clients": clients,
|
||||
"user": current_user.username,
|
||||
"user": current_user.display_name,
|
||||
},
|
||||
)
|
||||
|
||||
@app.post("/clients/query")
|
||||
async def query_clients(
|
||||
request: Request,
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
query: Annotated[str, Form()],
|
||||
) -> Response:
|
||||
@ -88,9 +84,6 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
async def update_client(
|
||||
request: Request,
|
||||
id: str,
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
client: Annotated[ClientUpdate, Form()],
|
||||
):
|
||||
@ -135,9 +128,6 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
async def delete_client(
|
||||
request: Request,
|
||||
id: str,
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> Response:
|
||||
"""Delete a client."""
|
||||
@ -156,9 +146,6 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
@app.post("/clients/")
|
||||
async def create_client(
|
||||
request: Request,
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
client: Annotated[ClientCreate, Form()],
|
||||
) -> Response:
|
||||
@ -183,9 +170,6 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
@app.post("/clients/validate/source")
|
||||
async def validate_client_source(
|
||||
request: Request,
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
sources: Annotated[str, Form()],
|
||||
) -> Response:
|
||||
"""Validate source."""
|
||||
@ -217,9 +201,6 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
@app.post("/clients/validate/public_key")
|
||||
async def validate_client_public_key(
|
||||
request: Request,
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
public_key: Annotated[str, Form()],
|
||||
) -> Response:
|
||||
"""Validate source."""
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.auth import LocalUserInfo
|
||||
from sshecret_admin.services import AdminBackend
|
||||
|
||||
from ..dependencies import FrontendDependencies
|
||||
@ -51,25 +51,27 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
@app.get("/dashboard")
|
||||
async def get_dashboard(
|
||||
request: Request,
|
||||
current_user: Annotated[LocalUserInfo, Depends(dependencies.get_user_info)],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||
):
|
||||
"""Dashboard for mocking up the dashboard."""
|
||||
stats = await get_stats(admin)
|
||||
last_login_events = await admin.get_audit_log_detailed(limit=5, operation="login")
|
||||
last_login_events = await admin.get_audit_log_detailed(
|
||||
limit=5, operation="login"
|
||||
)
|
||||
last_audit_events = await admin.get_audit_log_detailed(limit=10)
|
||||
|
||||
LOG.info("CurrentUser: %r", current_user)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"dashboard.html",
|
||||
{
|
||||
"page_title": "sshecret",
|
||||
"user": current_user.username,
|
||||
"user": current_user.display_name,
|
||||
"stats": stats,
|
||||
"last_login_events": last_login_events,
|
||||
"last_audit_events": last_audit_events,
|
||||
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -0,0 +1,142 @@
|
||||
"""Optional OIDC auth module."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sshecret_admin.auth import create_access_token, create_refresh_token
|
||||
from sshecret_admin.auth.authentication import generate_user_info, handle_oidc_claim
|
||||
from sshecret_admin.auth.exceptions import AuthenticationFailedError
|
||||
from sshecret_admin.auth.oidc import AdminOidc
|
||||
from sshecret_admin.frontend.exceptions import RedirectException
|
||||
from sshecret_admin.services import AdminBackend
|
||||
from starlette.datastructures import URL
|
||||
|
||||
from sshecret.backend.models import Operation
|
||||
|
||||
from ..dependencies import FrontendDependencies
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def audit_login_failure(
|
||||
admin: AdminBackend,
|
||||
error_message: str,
|
||||
request: Request,
|
||||
) -> None:
|
||||
"""Write login failure to audit log."""
|
||||
origin: str | None = None
|
||||
if request.client:
|
||||
origin = request.client.host
|
||||
await admin.write_audit_message(
|
||||
operation=Operation.DENY,
|
||||
message="Login failed",
|
||||
origin=origin or "UNKNOWN",
|
||||
provider_error_message=error_message,
|
||||
)
|
||||
|
||||
|
||||
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
"""Create auth router."""
|
||||
|
||||
app = APIRouter()
|
||||
|
||||
def get_oidc_client() -> AdminOidc:
|
||||
"""Get OIDC client dependency."""
|
||||
if not dependencies.settings.oidc:
|
||||
raise RuntimeError("OIDC authentication not configured.")
|
||||
oidc = AdminOidc(dependencies.settings.oidc)
|
||||
return oidc
|
||||
|
||||
@app.get("/oidc/login")
|
||||
async def oidc_login(
|
||||
request: Request, oidc: Annotated[AdminOidc, Depends(get_oidc_client)]
|
||||
) -> RedirectResponse:
|
||||
"""Redirect to oidc login."""
|
||||
redirect_url = request.url_for("oidc_auth")
|
||||
return await oidc.start_auth(request, redirect_url)
|
||||
|
||||
@app.get("/oidc/auth")
|
||||
async def oidc_auth(
|
||||
request: Request,
|
||||
session: Annotated[AsyncSession, Depends(dependencies.get_async_session)],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
oidc: Annotated[AdminOidc, Depends(get_oidc_client)],
|
||||
):
|
||||
"""Handle OIDC auth callback."""
|
||||
try:
|
||||
claims = await oidc.handle_auth_callback(request)
|
||||
except AuthenticationFailedError as error:
|
||||
raise RedirectException(
|
||||
to=URL("/login").include_query_params(
|
||||
error_title="Login error from external provider",
|
||||
error_message=str(error),
|
||||
)
|
||||
)
|
||||
except ValidationError as error:
|
||||
LOG.error("Validation error: %s", error, exc_info=True)
|
||||
raise RedirectException(
|
||||
to=URL("/login").include_query_params(
|
||||
error_title="Error parsing claim",
|
||||
error_message="One or more required parameters were not included in the claim.",
|
||||
)
|
||||
)
|
||||
|
||||
# We now have a IdentityClaims object.
|
||||
# We need to check if this matches an existing user, or we need to create a new one.
|
||||
|
||||
user = await handle_oidc_claim(session, claims)
|
||||
user.last_login = datetime.now()
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
# Set cookies
|
||||
token_data: dict[str, str] = {"sub": claims.sub}
|
||||
access_token = create_access_token(
|
||||
dependencies.settings, data=token_data, provider=claims.provider
|
||||
)
|
||||
refresh_token = create_refresh_token(
|
||||
dependencies.settings, data=token_data, provider=claims.provider
|
||||
)
|
||||
user_info = generate_user_info(user)
|
||||
response = HTMLResponse("""
|
||||
<html>
|
||||
<body>
|
||||
<p>Login successful. Redirecting...</p>
|
||||
<script>
|
||||
setTimeout(() => { window.location.href = "/dashboard"; }, 500);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
""")
|
||||
response.set_cookie(
|
||||
"access_token",
|
||||
value=access_token,
|
||||
httponly=True,
|
||||
secure=False,
|
||||
samesite="strict",
|
||||
)
|
||||
response.set_cookie(
|
||||
"refresh_token",
|
||||
value=refresh_token,
|
||||
httponly=True,
|
||||
secure=False,
|
||||
samesite="strict",
|
||||
)
|
||||
origin = "UNKNOWN"
|
||||
if request.client:
|
||||
origin = request.client.host
|
||||
await admin.write_audit_message(
|
||||
operation=Operation.LOGIN,
|
||||
message="Logged in to admin frontend",
|
||||
origin=origin,
|
||||
username=user_info.display_name,
|
||||
oidc=claims.provider,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
return app
|
||||
@ -8,7 +8,7 @@ from typing import Annotated, Any
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
from pydantic import BaseModel, BeforeValidator, Field
|
||||
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.auth import LocalUserInfo
|
||||
from sshecret_admin.services import AdminBackend
|
||||
|
||||
from ..dependencies import FrontendDependencies
|
||||
@ -51,13 +51,13 @@ class CreateSecret(BaseModel):
|
||||
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
"""Create secrets router."""
|
||||
|
||||
app = APIRouter()
|
||||
app = APIRouter(dependencies=[Depends(dependencies.require_login)])
|
||||
templates = dependencies.templates
|
||||
|
||||
@app.get("/secrets/")
|
||||
async def get_secrets(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||
current_user: Annotated[LocalUserInfo, Depends(dependencies.get_user_info)],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
):
|
||||
"""Get secrets index page."""
|
||||
@ -69,7 +69,7 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
{
|
||||
"page_title": "Secrets",
|
||||
"secrets": secrets,
|
||||
"user": current_user.username,
|
||||
"user": current_user.display_name,
|
||||
"clients": clients,
|
||||
},
|
||||
)
|
||||
@ -77,9 +77,6 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
@app.post("/secrets/")
|
||||
async def add_secret(
|
||||
request: Request,
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
secret: Annotated[CreateSecret, Form()],
|
||||
):
|
||||
@ -108,9 +105,6 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
request: Request,
|
||||
name: str,
|
||||
id: str,
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
):
|
||||
"""Remove a client's access to a secret."""
|
||||
@ -132,9 +126,6 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
request: Request,
|
||||
name: str,
|
||||
client: Annotated[str, Form()],
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
):
|
||||
"""Add a secret to a client."""
|
||||
@ -157,9 +148,6 @@ def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
async def delete_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
):
|
||||
"""Delete a secret."""
|
||||
|
||||
@ -9,7 +9,6 @@ from contextlib import contextmanager
|
||||
|
||||
from sshecret.backend import (
|
||||
AuditLog,
|
||||
AuditFilter,
|
||||
AuditListResult,
|
||||
Client,
|
||||
ClientFilter,
|
||||
|
||||
@ -44,7 +44,9 @@ def decrypt_master_password(
|
||||
if not keyfile.exists():
|
||||
raise RuntimeError("Error: Private key has not been generated yet.")
|
||||
|
||||
private_key = load_private_key(str(keyfile.absolute()), password=settings.secret_key)
|
||||
private_key = load_private_key(
|
||||
str(keyfile.absolute()), password=settings.secret_key
|
||||
)
|
||||
return decode_string(encrypted, private_key)
|
||||
|
||||
|
||||
@ -69,16 +71,16 @@ def _initial_key_setup(
|
||||
return True
|
||||
|
||||
|
||||
def _generate_master_password(
|
||||
settings: AdminServerSettings, keyfile: Path
|
||||
) -> str:
|
||||
def _generate_master_password(settings: AdminServerSettings, keyfile: Path) -> str:
|
||||
"""Generate master password for password database.
|
||||
|
||||
Returns the encrypted string, base64 encoded.
|
||||
"""
|
||||
if not keyfile.exists():
|
||||
raise RuntimeError("Error: Private key has not been generated yet.")
|
||||
private_key = load_private_key(str(keyfile.absolute()), password=settings.secret_key)
|
||||
private_key = load_private_key(
|
||||
str(keyfile.absolute()), password=settings.secret_key
|
||||
)
|
||||
public_key = private_key.public_key()
|
||||
master_password = _generate_password()
|
||||
return encrypt_string(master_password, public_key)
|
||||
|
||||
@ -75,7 +75,7 @@ class SecretUpdate(BaseModel):
|
||||
|
||||
value: str | AutoGenerateOpts = Field(
|
||||
description="Secret as string value or auto-generated with optional length",
|
||||
examples=["MySecretString", {"auto_generate": True, "length": 32}]
|
||||
examples=["MySecretString", {"auto_generate": True, "length": 32}],
|
||||
)
|
||||
|
||||
def get_secret(self) -> str:
|
||||
@ -85,7 +85,7 @@ class SecretUpdate(BaseModel):
|
||||
"""
|
||||
if isinstance(self.value, str):
|
||||
return self.value
|
||||
secret = secrets.token_urlsafe(32)[:self.value.length]
|
||||
secret = secrets.token_urlsafe(32)[: self.value.length]
|
||||
return secret
|
||||
|
||||
|
||||
@ -93,7 +93,9 @@ class SecretCreate(SecretUpdate):
|
||||
"""Model to create a secret."""
|
||||
|
||||
name: str
|
||||
clients: list[str] | None = Field(default=None, description="Assign the secret to a list of clients.")
|
||||
clients: list[str] | None = Field(
|
||||
default=None, description="Assign the secret to a list of clients."
|
||||
)
|
||||
|
||||
model_config: ConfigDict = ConfigDict(
|
||||
json_schema_extra={
|
||||
@ -101,12 +103,12 @@ class SecretCreate(SecretUpdate):
|
||||
{
|
||||
"name": "MySecret",
|
||||
"clients": ["client-1", "client-2"],
|
||||
"value": { "auto_generate": True, "length": 32 }
|
||||
"value": {"auto_generate": True, "length": 32},
|
||||
},
|
||||
{
|
||||
"name": "MySecret",
|
||||
"value": "mysecretstring",
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
@ -37,6 +37,7 @@
|
||||
--color-teal-300: oklch(85.5% 0.138 181.071);
|
||||
--color-teal-500: oklch(70.4% 0.14 182.503);
|
||||
--color-teal-600: oklch(60% 0.118 184.704);
|
||||
--color-teal-700: oklch(51.1% 0.096 186.391);
|
||||
--color-teal-900: oklch(38.6% 0.063 188.416);
|
||||
--color-blue-200: oklch(88.2% 0.059 254.128);
|
||||
--color-blue-300: oklch(80.9% 0.105 251.813);
|
||||
@ -44,6 +45,7 @@
|
||||
--color-blue-600: oklch(54.6% 0.245 262.881);
|
||||
--color-blue-700: oklch(48.8% 0.243 264.376);
|
||||
--color-blue-800: oklch(42.4% 0.199 265.638);
|
||||
--color-indigo-200: oklch(87% 0.065 274.039);
|
||||
--color-indigo-500: oklch(58.5% 0.233 277.117);
|
||||
--color-indigo-600: oklch(51.1% 0.262 276.966);
|
||||
--color-indigo-700: oklch(45.7% 0.24 277.023);
|
||||
@ -55,12 +57,6 @@
|
||||
--color-pink-200: oklch(89.9% 0.061 343.231);
|
||||
--color-pink-500: oklch(65.6% 0.241 354.308);
|
||||
--color-rose-500: oklch(64.5% 0.246 16.439);
|
||||
--color-slate-50: oklch(98.4% 0.003 247.858);
|
||||
--color-slate-200: oklch(92.9% 0.013 255.508);
|
||||
--color-slate-400: oklch(70.4% 0.04 256.788);
|
||||
--color-slate-500: oklch(55.4% 0.046 257.417);
|
||||
--color-slate-600: oklch(44.6% 0.043 257.281);
|
||||
--color-slate-800: oklch(27.9% 0.041 260.031);
|
||||
--color-gray-50: oklch(98.5% 0.002 247.839);
|
||||
--color-gray-100: oklch(96.7% 0.003 264.542);
|
||||
--color-gray-200: oklch(92.8% 0.006 264.531);
|
||||
@ -417,6 +413,9 @@
|
||||
.m-361 {
|
||||
margin: calc(var(--spacing) * 361);
|
||||
}
|
||||
.mx-2 {
|
||||
margin-inline: calc(var(--spacing) * 2);
|
||||
}
|
||||
.mx-3 {
|
||||
margin-inline: calc(var(--spacing) * 3);
|
||||
}
|
||||
@ -444,6 +443,12 @@
|
||||
.my-10 {
|
||||
margin-block: calc(var(--spacing) * 10);
|
||||
}
|
||||
.my-\[0\.5rem\] {
|
||||
margin-block: 0.5rem;
|
||||
}
|
||||
.my-\[1rem\] {
|
||||
margin-block: 1rem;
|
||||
}
|
||||
.my-auto {
|
||||
margin-block: auto;
|
||||
}
|
||||
@ -585,6 +590,12 @@
|
||||
.ml-auto {
|
||||
margin-left: auto;
|
||||
}
|
||||
.box-border {
|
||||
box-sizing: border-box;
|
||||
}
|
||||
.box-content {
|
||||
box-sizing: content-box;
|
||||
}
|
||||
.block {
|
||||
display: block;
|
||||
}
|
||||
@ -663,6 +674,9 @@
|
||||
.h-32 {
|
||||
height: calc(var(--spacing) * 32);
|
||||
}
|
||||
.h-\[0\.125rem\] {
|
||||
height: 0.125rem;
|
||||
}
|
||||
.h-\[12px\] {
|
||||
height: 12px;
|
||||
}
|
||||
@ -759,24 +773,18 @@
|
||||
.w-\[12px\] {
|
||||
width: 12px;
|
||||
}
|
||||
.w-\[200px\] {
|
||||
width: 200px;
|
||||
}
|
||||
.w-\[400px\] {
|
||||
width: 400px;
|
||||
}
|
||||
.w-auto {
|
||||
width: auto;
|
||||
}
|
||||
.w-full {
|
||||
width: 100%;
|
||||
}
|
||||
.w-max {
|
||||
width: max-content;
|
||||
}
|
||||
.max-w-2xl {
|
||||
max-width: var(--container-2xl);
|
||||
}
|
||||
.max-w-\[20rem\] {
|
||||
max-width: 20rem;
|
||||
}
|
||||
.max-w-\[140px\] {
|
||||
max-width: 140px;
|
||||
}
|
||||
@ -786,6 +794,9 @@
|
||||
.max-w-lg {
|
||||
max-width: var(--container-lg);
|
||||
}
|
||||
.max-w-max {
|
||||
max-width: max-content;
|
||||
}
|
||||
.max-w-md {
|
||||
max-width: var(--container-md);
|
||||
}
|
||||
@ -810,9 +821,6 @@
|
||||
.min-w-9 {
|
||||
min-width: calc(var(--spacing) * 9);
|
||||
}
|
||||
.min-w-\[12rem\] {
|
||||
min-width: 12rem;
|
||||
}
|
||||
.min-w-\[460px\] {
|
||||
min-width: 460px;
|
||||
}
|
||||
@ -1288,6 +1296,9 @@
|
||||
.bg-gray-200 {
|
||||
background-color: var(--color-gray-200);
|
||||
}
|
||||
.bg-gray-700 {
|
||||
background-color: var(--color-gray-700);
|
||||
}
|
||||
.bg-gray-800 {
|
||||
background-color: var(--color-gray-800);
|
||||
}
|
||||
@ -1309,6 +1320,9 @@
|
||||
.bg-green-400 {
|
||||
background-color: var(--color-green-400);
|
||||
}
|
||||
.bg-indigo-200 {
|
||||
background-color: var(--color-indigo-200);
|
||||
}
|
||||
.bg-indigo-600 {
|
||||
background-color: var(--color-indigo-600);
|
||||
}
|
||||
@ -1375,6 +1389,9 @@
|
||||
.bg-teal-100 {
|
||||
background-color: var(--color-teal-100);
|
||||
}
|
||||
.bg-teal-700 {
|
||||
background-color: var(--color-teal-700);
|
||||
}
|
||||
.bg-transparent {
|
||||
background-color: transparent;
|
||||
}
|
||||
@ -1438,6 +1455,9 @@
|
||||
.px-6 {
|
||||
padding-inline: calc(var(--spacing) * 6);
|
||||
}
|
||||
.px-\[1\.125rem\] {
|
||||
padding-inline: 1.125rem;
|
||||
}
|
||||
.py-0\.5 {
|
||||
padding-block: calc(var(--spacing) * 0.5);
|
||||
}
|
||||
@ -1673,18 +1693,9 @@
|
||||
--tw-tracking: var(--tracking-wider);
|
||||
letter-spacing: var(--tracking-wider);
|
||||
}
|
||||
.text-wrap {
|
||||
text-wrap: wrap;
|
||||
}
|
||||
.break-words {
|
||||
overflow-wrap: break-word;
|
||||
}
|
||||
.wrap-normal {
|
||||
overflow-wrap: normal;
|
||||
}
|
||||
.whitespace-normal {
|
||||
white-space: normal;
|
||||
}
|
||||
.whitespace-nowrap {
|
||||
white-space: nowrap;
|
||||
}
|
||||
@ -2461,11 +2472,6 @@
|
||||
translate: var(--tw-translate-x) var(--tw-translate-y);
|
||||
}
|
||||
}
|
||||
.sm\:flex-row {
|
||||
@media (width >= 40rem) {
|
||||
flex-direction: row;
|
||||
}
|
||||
}
|
||||
.sm\:justify-between {
|
||||
@media (width >= 40rem) {
|
||||
justify-content: space-between;
|
||||
@ -2481,15 +2487,6 @@
|
||||
justify-content: flex-end;
|
||||
}
|
||||
}
|
||||
.sm\:space-y-0 {
|
||||
@media (width >= 40rem) {
|
||||
:where(& > :not(:last-child)) {
|
||||
--tw-space-y-reverse: 0;
|
||||
margin-block-start: calc(calc(var(--spacing) * 0) * var(--tw-space-y-reverse));
|
||||
margin-block-end: calc(calc(var(--spacing) * 0) * calc(1 - var(--tw-space-y-reverse)));
|
||||
}
|
||||
}
|
||||
}
|
||||
.sm\:space-x-3 {
|
||||
@media (width >= 40rem) {
|
||||
:where(& > :not(:last-child)) {
|
||||
@ -2648,11 +2645,6 @@
|
||||
margin-top: calc(var(--spacing) * 0);
|
||||
}
|
||||
}
|
||||
.md\:mt-6 {
|
||||
@media (width >= 48rem) {
|
||||
margin-top: calc(var(--spacing) * 6);
|
||||
}
|
||||
}
|
||||
.md\:mr-0 {
|
||||
@media (width >= 48rem) {
|
||||
margin-right: calc(var(--spacing) * 0);
|
||||
@ -2839,12 +2831,6 @@
|
||||
line-height: var(--tw-leading, var(--text-lg--line-height));
|
||||
}
|
||||
}
|
||||
.md\:text-sm {
|
||||
@media (width >= 48rem) {
|
||||
font-size: var(--text-sm);
|
||||
line-height: var(--tw-leading, var(--text-sm--line-height));
|
||||
}
|
||||
}
|
||||
.md\:text-xs {
|
||||
@media (width >= 48rem) {
|
||||
font-size: var(--text-xs);
|
||||
@ -3327,11 +3313,6 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
.dark\:bg-green-900 {
|
||||
&:where(.dark, .dark *) {
|
||||
background-color: var(--color-green-900);
|
||||
}
|
||||
}
|
||||
.dark\:bg-orange-400 {
|
||||
&:where(.dark, .dark *) {
|
||||
background-color: var(--color-orange-400);
|
||||
@ -3650,13 +3631,6 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
.dark\:focus\:ring-blue-600 {
|
||||
&:where(.dark, .dark *) {
|
||||
&:focus {
|
||||
--tw-ring-color: var(--color-blue-600);
|
||||
}
|
||||
}
|
||||
}
|
||||
.dark\:focus\:ring-gray-600 {
|
||||
&:where(.dark, .dark *) {
|
||||
&:focus {
|
||||
@ -3713,13 +3687,6 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
.dark\:focus\:ring-offset-gray-800 {
|
||||
&:where(.dark, .dark *) {
|
||||
&:focus {
|
||||
--tw-ring-offset-color: var(--color-gray-800);
|
||||
}
|
||||
}
|
||||
}
|
||||
.md\:dark\:hover\:bg-transparent {
|
||||
@media (width >= 48rem) {
|
||||
&:where(.dark, .dark *) {
|
||||
|
||||
Reference in New Issue
Block a user