Compare commits
10 Commits
4f970a3f71
...
7c65d5bb93
| Author | SHA1 | Date | |
|---|---|---|---|
| 7c65d5bb93 | |||
| e22f9b8f10 | |||
| e766b06d5c | |||
| e28276634a | |||
| 3f5d9ea545 | |||
| 6bebbee4fa | |||
| 9bf92d6743 | |||
| 9ccd2f1d4d | |||
| d866553ac1 | |||
| 0a427b6a91 |
7
.dockerignore
Normal file
7
.dockerignore
Normal file
@ -0,0 +1,7 @@
|
||||
.venv
|
||||
.git
|
||||
.github
|
||||
**/.pytest_cache
|
||||
**/__pycache__
|
||||
.ruff_cache
|
||||
**/.testing
|
||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
122
README.md
122
README.md
@ -1,58 +1,112 @@
|
||||
# Sshecret - Openssh based secrets management
|
||||
# Sshecret - Simple Secret management using SSH and RSA Keys
|
||||
|
||||
## Motivation
|
||||
Managing secrets within a container environment or homelab can be quite complex.
|
||||
|
||||
There are many approaches to managing secrets for services, but a lot of these
|
||||
either assume you have one of the industry-standard systems like hashicorp vault to manage them centrally.
|
||||
As with container orchestration in general, many of the available tools are targetted large enterprise installations and add significant complexity to both the management of secrets themselves, and to the consumers.
|
||||
|
||||
For enthusiasts or homelabbers this becomes overkill quickly, and end up
|
||||
consuming a lot more time and energy than what feels justified.
|
||||
For enthusiasts or homelabbers solutions like Hashicorp Vault become overkill
|
||||
quickly, and end up consuming a lot more time and energy than what feels
|
||||
justified.
|
||||
|
||||
This system has been created to provide a centralized solution that works well-enough.
|
||||
|
||||
One clear goal was to have all the complexity on the server-side, and be able to construct a minimal client.
|
||||
Sshecret provides a simple, centralized solution for secret storage, and requires only tools that are commonly pre-installed on most linux system.
|
||||
|
||||
## Components
|
||||
# Concept
|
||||
The system uses RSA keys, as generated by openssh, to encrypt secrets for each client.
|
||||
|
||||
This system has been designed with modularity and extensibility in mind. It has the following building blocks:
|
||||
It uses RSA keys as it is possible to do encryption using only the public key.
|
||||
|
||||
- Password database
|
||||
- Password input handler
|
||||
- Encryption and key management
|
||||
By using a custom SSH server, the consuming servers can fetch a version of a secret encrypted specifically for them.
|
||||
|
||||
As the secret is encrypted with RSA keys, it can be decrypted using the openssl command which is commonly available.
|
||||
|
||||
This means that while the backend interface can be complex, the to access and decrypt a secret, you can use ssh, and a simple bash script.
|
||||
|
||||
# Components
|
||||
|
||||
There are three components to Sshecret:
|
||||
|
||||
- Password database and admin interface
|
||||
- Client secret storage backend
|
||||
- Custom ssh server
|
||||
- Ssh server for clients to access
|
||||
|
||||
### Password database
|
||||
Currently a single password database is implemented: Keepass.
|
||||
The three systems should be deployed separately for security, with the backend system as the only central component.
|
||||
|
||||
Sshecret can create a database, and store your secrets in it.
|
||||
The ssh server that the clients connect to, only has access to encrypted versions of the secrets.
|
||||
|
||||
It only uses a master password for protection, so you are responsible for
|
||||
securing the password database file. In theory, the password database file can
|
||||
be disconnected after encrypting the passwords for the clients, and these two
|
||||
components may be disconnected.
|
||||
If it or the backend should be compromised, it wouldn't be possible to extract any clear-text secrets, since only encrypted values are stored, and each encrypted with RSA public key encryption.
|
||||
|
||||
### Password input handler
|
||||
Passwords can be randomly generated, they can be read from stdin, or from environment variables.
|
||||
## Backend
|
||||
The backend system stores the definition of each client, including their public key, and the IP addresses/networks they are allowed to connect from.
|
||||
It also stores a version of each secret that the client has access to, encrypted with the public key. It has no knowledge of the unencrypted secrets.
|
||||
|
||||
Other methods can be implemented in the future.
|
||||
Both the admin interface and the ssh server access this system over a REST API using a token-based authentication method.
|
||||
|
||||
### Client secret storage backend
|
||||
So far only a simple JSON file based backend has been implemented. It stores one file per client.
|
||||
The interface is flexible, and can be extended to databases or anything else really.
|
||||
## Password database and admin interface
|
||||
The sshecret password database is based on keepass, and the admin interface is available as a simple web interface as well as a REST API.
|
||||
|
||||
### Custom SSH server
|
||||
A custom SSH based on paramiko is included. This is how the clients receive the encrypted password.
|
||||
The client must send a single command over the SSH session equal to the name of the secret.
|
||||
This component is primarily responsible for storing the secrets in a keepass database and populating the backend with client-specific encrypted versions.
|
||||
|
||||
If permitted to access the secret, it will returned encrypted with the client RSA public key of the client, encoded as base64.
|
||||
The admin interface must be secured. It currently supports local user accounts, but will be expanded with OIDC support in the near future.
|
||||
|
||||
## Custom SSH server
|
||||
To make it easy to fetch secrets sshecret includes a SSH server.
|
||||
|
||||
To fetch a secret, simply ssh to it using the client name as username and the
|
||||
registered RSA key as authorization.
|
||||
|
||||
Send the command `get_secret` followed by the name of the secret. The ssh server
|
||||
will check the client's public key against permitted keys, and check if the
|
||||
client is allowed to connect and if the secret is available to it.
|
||||
|
||||
The server will answer with a base64 encoded version of the secret.
|
||||
|
||||
See `examples/sshecret-client.bash` for an example of a bash script that can be
|
||||
used to fetch and decrypt a secret.
|
||||
|
||||
Out of the box, only the `get_secret` command is available, however an optional
|
||||
command `register` can be enabled to allow registration of a client using the
|
||||
SSH interface to make it easy to onboard new clients automatically.
|
||||
|
||||
# Configuration
|
||||
|
||||
Each subsystem is set up using environment variables.
|
||||
|
||||
## Backend
|
||||
The location and type of database may be configured.
|
||||
For now, only sqlite is officially supported.
|
||||
|
||||
The value shown below is the default value. In a docker setup, you may want to
|
||||
create a volume and configure this to a path within the volume to be able to
|
||||
perform backups.
|
||||
|
||||
SSHECRET_BACKEND_DATABASE=/path/to/sshecret.db
|
||||
|
||||
|
||||
While the backend can be placed behind reverse proxy and served with HTTPS, it's
|
||||
probably better to have keep it inside an internal container network.
|
||||
|
||||
## Admin
|
||||
The backend must be generated on the backend before setting up the admin.
|
||||
|
||||
SSHECRET_BACKEND_URL=http://backend:8022
|
||||
SSHECRET_ADMIN_BACKEND_TOKEN: mySuperSecretBackendToken
|
||||
|
||||
|
||||
## Deployment
|
||||
|
||||
The system can be deplyed using docker or any other container runtime.
|
||||
|
||||
See the examples in the `docker/` folder.
|
||||
|
||||
This allows the client to decrypt and get the clear text value easily.
|
||||
|
||||
# FAQ
|
||||
## Why not use Age?
|
||||
I like age a lot, and it's ability to use more ssh key types is certainly a winner feature.
|
||||
However, one goal here is to be able to construct a client with minimal dependencies, and that speaks in favor of the current solution.
|
||||
|
||||
I like age a lot, and it's ability to use more ssh key types is certainly a
|
||||
winner feature. However, a clear goal of this project is to be able to construct
|
||||
a client with minimal dependencies.
|
||||
|
||||
Using just RSA keys, you can construct a client using only the following tools:
|
||||
- base64
|
||||
@ -60,3 +114,5 @@ Using just RSA keys, you can construct a client using only the following tools:
|
||||
- ssh
|
||||
|
||||
This means that you can create a client using just a shell script.
|
||||
|
||||
If age were to be used, the age tool would have to be installed on each client sytem.
|
||||
|
||||
30
docker/Dockerfile.admin
Normal file
30
docker/Dockerfile.admin
Normal file
@ -0,0 +1,30 @@
|
||||
# this Dockerfile should be built from the repo root
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
|
||||
|
||||
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy
|
||||
ENV UV_PYTHON_DOWNLOADS=0
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY . /build
|
||||
|
||||
RUN uv build --package sshecret
|
||||
RUN uv build --package sshecret-admin
|
||||
|
||||
|
||||
FROM python:3.13-slim-bookworm
|
||||
|
||||
COPY --from=builder --chown=app:app /build/dist /opt/sshecret
|
||||
|
||||
RUN pip install /opt/sshecret/sshecret-*.whl
|
||||
RUN pip install /opt/sshecret/sshecret_admin-*.whl
|
||||
|
||||
EXPOSE 8822
|
||||
|
||||
VOLUME /opt/sshecret-admin
|
||||
|
||||
WORKDIR /opt/sshecret-admin
|
||||
|
||||
ENTRYPOINT [ "sshecret-admin" ]
|
||||
|
||||
CMD ["run", "--host", "0.0.0.0"]
|
||||
30
docker/Dockerfile.backend
Normal file
30
docker/Dockerfile.backend
Normal file
@ -0,0 +1,30 @@
|
||||
# this Dockerfile should be built from the repo root
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
|
||||
|
||||
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy
|
||||
ENV UV_PYTHON_DOWNLOADS=0
|
||||
WORKDIR /build
|
||||
|
||||
COPY . /build
|
||||
|
||||
RUN uv build --package sshecret
|
||||
RUN uv build --package sshecret-backend
|
||||
|
||||
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
|
||||
|
||||
COPY --from=builder --chown=app:app /build/dist /opt/sshecret
|
||||
|
||||
RUN uv pip install --system /opt/sshecret/sshecret-*.whl
|
||||
RUN uv pip install --system /opt/sshecret/sshecret_backend-*.whl
|
||||
|
||||
COPY packages/sshecret-backend /opt/sshecret-backend
|
||||
COPY docker/backend.entrypoint.sh /entrypoint.sh
|
||||
|
||||
WORKDIR /opt/sshecret-backend
|
||||
|
||||
VOLUME /opt/sshecret-backend-db
|
||||
|
||||
EXPOSE 8022
|
||||
|
||||
CMD ["/entrypoint.sh"]
|
||||
26
docker/Dockerfile.sshd
Normal file
26
docker/Dockerfile.sshd
Normal file
@ -0,0 +1,26 @@
|
||||
# this Dockerfile should be built from the repo root
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
|
||||
|
||||
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy
|
||||
ENV UV_PYTHON_DOWNLOADS=0
|
||||
WORKDIR /build
|
||||
|
||||
COPY . /build
|
||||
|
||||
RUN uv build --package sshecret
|
||||
RUN uv build --package sshecret-sshd
|
||||
|
||||
FROM python:3.13-slim-bookworm
|
||||
|
||||
COPY --from=builder --chown=app:app /build/dist /opt/sshecret
|
||||
|
||||
RUN pip install /opt/sshecret/sshecret-*.whl
|
||||
RUN pip install /opt/sshecret/sshecret_sshd-*.whl
|
||||
|
||||
WORKDIR /opt/sshecret-sshd
|
||||
|
||||
VOLUME /opt/sshecret-sshd
|
||||
|
||||
EXPOSE 2222
|
||||
|
||||
CMD ["sshecret-sshd", "run", "--host", "0.0.0.0"]
|
||||
15
docker/backend.entrypoint.sh
Executable file
15
docker/backend.entrypoint.sh
Executable file
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
fail() {
|
||||
printf '%s\n' "$1" >&2 ## Send message to stderr.
|
||||
exit "${2-1}" ## Return a code specified by $2, or 1 by default.
|
||||
}
|
||||
|
||||
[[ -d migrations ]] || fail "Error: Must be run from the backend directory."
|
||||
[[ -d /opt/sshecret-backend-db ]] || mkdir /opt/sshecret-backend-db
|
||||
|
||||
export SSHECRET_BACKEND_DATABASE="/opt/sshecret-backend-db/sshecret.db"
|
||||
|
||||
alembic upgrade head
|
||||
|
||||
sshecret-backend run --host 0.0.0.0
|
||||
19
docker/docker-compose.yml
Normal file
19
docker/docker-compose.yml
Normal file
@ -0,0 +1,19 @@
|
||||
---
|
||||
|
||||
services:
|
||||
backend:
|
||||
image: sshecret-backend
|
||||
container_name: sshecret_backend
|
||||
build:
|
||||
context: ../
|
||||
dockerfile: dockerfile.backend
|
||||
networks:
|
||||
- common
|
||||
volumes:
|
||||
- backend_data
|
||||
|
||||
volumes:
|
||||
backend_data:
|
||||
|
||||
networks:
|
||||
common:
|
||||
@ -19,10 +19,14 @@ dependencies = [
|
||||
"pyjwt>=2.10.1",
|
||||
"pykeepass>=4.1.1.post1",
|
||||
"sqlmodel>=0.0.24",
|
||||
"sshecret",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
sshecret = { workspace = true }
|
||||
|
||||
[project.scripts]
|
||||
sshecret-admin = "sshecret_admin.cli:cli"
|
||||
sshecret-admin = "sshecret_admin.core.cli:cli"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
@ -31,4 +35,5 @@ build-backend = "hatchling.build"
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytailwindcss>=0.2.0",
|
||||
"types-pyjwt>=1.7.1",
|
||||
]
|
||||
|
||||
@ -1,284 +0,0 @@
|
||||
"""Admin API."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import timedelta
|
||||
from typing import Annotated
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from sqlmodel import Session, select
|
||||
from sshecret.backend import Client, SshecretBackend
|
||||
from sshecret.backend.models import Secret
|
||||
|
||||
from .admin_backend import AdminBackend
|
||||
from .auth_models import (
|
||||
PasswordDB,
|
||||
Token,
|
||||
TokenData,
|
||||
User,
|
||||
create_access_token,
|
||||
verify_password,
|
||||
)
|
||||
from .settings import AdminServerSettings
|
||||
from .types import DBSessionDep
|
||||
from .view_models import (
|
||||
ClientCreate,
|
||||
SecretCreate,
|
||||
SecretUpdate,
|
||||
SecretView,
|
||||
UpdateKeyModel,
|
||||
UpdateKeyResponse,
|
||||
UpdatePoliciesRequest,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
API_VERSION = "v1"
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
|
||||
|
||||
def authenticate_user(session: Session, username: str, password: str) -> User | None:
|
||||
"""Authenticate user."""
|
||||
user = session.exec(select(User).where(User.username == username)).first()
|
||||
if not user:
|
||||
return None
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
async def map_secrets_to_clients(
|
||||
backend: SshecretBackend,
|
||||
) -> defaultdict[str, list[str]]:
|
||||
"""Map secrets to clients."""
|
||||
clients = await backend.get_clients()
|
||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||
for client in clients:
|
||||
for secret in client.secrets:
|
||||
client_secret_map[secret].append(client.name)
|
||||
return client_secret_map
|
||||
|
||||
|
||||
def get_admin_api(
|
||||
get_db_session: DBSessionDep, settings: AdminServerSettings
|
||||
) -> APIRouter:
|
||||
"""Get Admin API."""
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
async def get_admin_backend(session: Annotated[Session, Depends(get_db_session)]):
|
||||
"""Get admin backend API."""
|
||||
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||
if not password_db:
|
||||
raise HTTPException(
|
||||
500, detail="Error: The password manager has not yet been set up."
|
||||
)
|
||||
admin = AdminBackend(settings, password_db.encrypted_password)
|
||||
yield admin
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
) -> User:
|
||||
"""Get current user from token."""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
|
||||
username = payload.get("sub")
|
||||
if not username:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except jwt.InvalidTokenError:
|
||||
raise credentials_exception
|
||||
|
||||
user = session.exec(
|
||||
select(User).where(User.username == token_data.username)
|
||||
).first()
|
||||
if not user:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> User:
|
||||
"""Get current active user."""
|
||||
if current_user.disabled:
|
||||
raise HTTPException(status_code=400, detail="Inactive or disabled user")
|
||||
return current_user
|
||||
|
||||
app = APIRouter(
|
||||
prefix=f"/api/{API_VERSION}", dependencies=[Depends(get_current_active_user)]
|
||||
)
|
||||
|
||||
@app.post("/token")
|
||||
async def login_for_access_token(
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> Token:
|
||||
"""Login user and generate token."""
|
||||
user = authenticate_user(session, form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
settings,
|
||||
data={"sub": user.username},
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
@app.get("/clients/")
|
||||
async def get_clients(
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)]
|
||||
) -> list[Client]:
|
||||
"""Get clients."""
|
||||
clients = await admin.get_clients()
|
||||
return clients
|
||||
|
||||
@app.post("/clients/")
|
||||
async def create_client(
|
||||
new_client: ClientCreate,
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
) -> Client:
|
||||
"""Create a new client."""
|
||||
sources: list[str] | None = None
|
||||
if new_client.sources:
|
||||
sources = [str(source) for source in new_client.sources]
|
||||
client = await admin.create_client(
|
||||
new_client.name, new_client.public_key, sources
|
||||
)
|
||||
return client
|
||||
|
||||
@app.delete("/clients/{name}")
|
||||
async def delete_client(
|
||||
name: str, admin: Annotated[AdminBackend, Depends(get_admin_backend)]
|
||||
) -> None:
|
||||
"""Delete a client."""
|
||||
await admin.delete_client(name)
|
||||
|
||||
@app.delete("/clients/{name}/secrets/{secret_name}")
|
||||
async def delete_secret_from_client(
|
||||
name: str,
|
||||
secret_name: str,
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
) -> None:
|
||||
"""Delete a secret from a client."""
|
||||
client = await admin.get_client(name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
|
||||
)
|
||||
|
||||
if secret_name not in client.secrets:
|
||||
LOG.debug("Client does not have requested secret. No action to perform.")
|
||||
return None
|
||||
|
||||
await admin.delete_client_secret(name, secret_name)
|
||||
|
||||
@app.put("/clients/{name}/policies")
|
||||
async def update_client_policies(
|
||||
name: str,
|
||||
updated: UpdatePoliciesRequest,
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
) -> Client:
|
||||
"""Update the client access policies."""
|
||||
client = await admin.get_client(name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
|
||||
)
|
||||
|
||||
LOG.debug("Old policies: %r. New: %r", client.policies, updated.sources)
|
||||
|
||||
addresses: list[str] = [str(source) for source in updated.sources]
|
||||
await admin.update_client_sources(name, addresses)
|
||||
client = await admin.get_client(name)
|
||||
|
||||
assert client is not None, "Critical: The client disappeared after update!"
|
||||
|
||||
return client
|
||||
|
||||
@app.get("/secrets/")
|
||||
async def get_secret_names(
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)]
|
||||
) -> list[Secret]:
|
||||
"""Get Secret Names."""
|
||||
return await admin.get_secrets()
|
||||
|
||||
@app.post("/secrets/")
|
||||
async def add_secret(
|
||||
secret: SecretCreate, admin: Annotated[AdminBackend, Depends(get_admin_backend)]
|
||||
) -> None:
|
||||
"""Create a secret."""
|
||||
await admin.add_secret(secret.name, secret.get_secret(), secret.clients)
|
||||
|
||||
@app.get("/secrets/{name}")
|
||||
async def get_secret(
|
||||
name: str, admin: Annotated[AdminBackend, Depends(get_admin_backend)]
|
||||
) -> SecretView:
|
||||
"""Get a secret."""
|
||||
secret_view = await admin.get_secret(name)
|
||||
|
||||
if not secret_view:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found."
|
||||
)
|
||||
return secret_view
|
||||
|
||||
@app.put("/secrets/{name}")
|
||||
async def update_secret(
|
||||
name: str,
|
||||
value: SecretUpdate,
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
) -> None:
|
||||
new_value = value.get_secret()
|
||||
await admin.update_secret(name, new_value)
|
||||
|
||||
@app.delete("/secrets/{name}")
|
||||
async def delete_secret(
|
||||
name: str,
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
) -> None:
|
||||
"""Delete secret."""
|
||||
await admin.delete_secret(name)
|
||||
|
||||
@app.put("/clients/{name}/public-key")
|
||||
async def update_client_public_key(
|
||||
name: str,
|
||||
updated: UpdateKeyModel,
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
) -> UpdateKeyResponse:
|
||||
"""Update client public key.
|
||||
|
||||
Updating the public key will invalidate the current secrets, so these well
|
||||
be resolved first, and re-encrypted using the new key.
|
||||
"""
|
||||
# Let's first ensure that the key is actually updated.
|
||||
updated_secrets = await admin.update_client_public_key(name, updated.public_key)
|
||||
return UpdateKeyResponse(
|
||||
public_key=updated.public_key, updated_secrets=updated_secrets
|
||||
)
|
||||
|
||||
@app.put("/clients/{name}/secrets/{secret_name}")
|
||||
async def add_secret_to_client(
|
||||
name: str,
|
||||
secret_name: str,
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
) -> None:
|
||||
"""Add secret to a client."""
|
||||
await admin.create_client_secret(name, secret_name)
|
||||
|
||||
return app
|
||||
@ -0,0 +1,5 @@
|
||||
"""Admin REST API."""
|
||||
|
||||
from .router import create_router as create_api_router
|
||||
|
||||
__all__ = ["create_api_router"]
|
||||
@ -0,0 +1 @@
|
||||
"""API Endpoints."""
|
||||
@ -0,0 +1,39 @@
|
||||
"""Authentication related endpoints factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlmodel import Session
|
||||
|
||||
from sshecret_admin.auth import Token, authenticate_user, create_access_token
|
||||
from sshecret_admin.core.dependencies import AdminDependencies
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
"""Create auth router."""
|
||||
app = APIRouter()
|
||||
|
||||
@app.post("/token")
|
||||
async def login_for_access_token(
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> Token:
|
||||
"""Login user and generate token."""
|
||||
user = authenticate_user(session, form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token = create_access_token(
|
||||
dependencies.settings,
|
||||
data={"sub": user.username},
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
|
||||
return app
|
||||
@ -0,0 +1,124 @@
|
||||
"""Client-related endpoints factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from sshecret.backend import Client
|
||||
from sshecret_admin.core.dependencies import AdminDependencies
|
||||
from sshecret_admin.services import AdminBackend
|
||||
from sshecret_admin.services.models import (
|
||||
ClientCreate,
|
||||
UpdateKeyModel,
|
||||
UpdateKeyResponse,
|
||||
UpdatePoliciesRequest,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
"""Create clients router."""
|
||||
app = APIRouter()
|
||||
|
||||
@app.get("/clients/")
|
||||
async def get_clients(
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)]
|
||||
) -> list[Client]:
|
||||
"""Get clients."""
|
||||
clients = await admin.get_clients()
|
||||
return clients
|
||||
|
||||
@app.post("/clients/")
|
||||
async def create_client(
|
||||
new_client: ClientCreate,
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> Client:
|
||||
"""Create a new client."""
|
||||
sources: list[str] | None = None
|
||||
if new_client.sources:
|
||||
sources = [str(source) for source in new_client.sources]
|
||||
client = await admin.create_client(
|
||||
new_client.name, new_client.public_key, sources=sources
|
||||
)
|
||||
return client
|
||||
|
||||
@app.delete("/clients/{name}")
|
||||
async def delete_client(
|
||||
name: str,
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> None:
|
||||
"""Delete a client."""
|
||||
await admin.delete_client(name)
|
||||
|
||||
@app.delete("/clients/{name}/secrets/{secret_name}")
|
||||
async def delete_secret_from_client(
|
||||
name: str,
|
||||
secret_name: str,
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> None:
|
||||
"""Delete a secret from a client."""
|
||||
client = await admin.get_client(name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
|
||||
)
|
||||
|
||||
if secret_name not in client.secrets:
|
||||
LOG.debug("Client does not have requested secret. No action to perform.")
|
||||
return None
|
||||
|
||||
await admin.delete_client_secret(name, secret_name)
|
||||
|
||||
@app.put("/clients/{name}/policies")
|
||||
async def update_client_policies(
|
||||
name: str,
|
||||
updated: UpdatePoliciesRequest,
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> Client:
|
||||
"""Update the client access policies."""
|
||||
client = await admin.get_client(name)
|
||||
if not client:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
|
||||
)
|
||||
|
||||
LOG.debug("Old policies: %r. New: %r", client.policies, updated.sources)
|
||||
|
||||
addresses: list[str] = [str(source) for source in updated.sources]
|
||||
await admin.update_client_sources(name, addresses)
|
||||
client = await admin.get_client(name)
|
||||
|
||||
assert client is not None, "Critical: The client disappeared after update!"
|
||||
|
||||
return client
|
||||
|
||||
@app.put("/clients/{name}/public-key")
|
||||
async def update_client_public_key(
|
||||
name: str,
|
||||
updated: UpdateKeyModel,
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> UpdateKeyResponse:
|
||||
"""Update client public key.
|
||||
|
||||
Updating the public key will invalidate the current secrets, so these well
|
||||
be resolved first, and re-encrypted using the new key.
|
||||
"""
|
||||
# Let's first ensure that the key is actually updated.
|
||||
updated_secrets = await admin.update_client_public_key(name, updated.public_key)
|
||||
return UpdateKeyResponse(
|
||||
public_key=updated.public_key, updated_secrets=updated_secrets
|
||||
)
|
||||
|
||||
@app.put("/clients/{name}/secrets/{secret_name}")
|
||||
async def add_secret_to_client(
|
||||
name: str,
|
||||
secret_name: str,
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> None:
|
||||
"""Add secret to a client."""
|
||||
await admin.create_client_secret(name, secret_name)
|
||||
|
||||
return app
|
||||
@ -0,0 +1,70 @@
|
||||
"""Secrets related endpoints factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from sshecret.backend.models import Secret
|
||||
from sshecret_admin.core.dependencies import AdminDependencies
|
||||
from sshecret_admin.services import AdminBackend
|
||||
from sshecret_admin.services.models import (
|
||||
SecretCreate,
|
||||
SecretUpdate,
|
||||
SecretView,
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||
"""Create secrets router."""
|
||||
app = APIRouter()
|
||||
|
||||
@app.get("/secrets/")
|
||||
async def get_secret_names(
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)]
|
||||
) -> list[Secret]:
|
||||
"""Get Secret Names."""
|
||||
return await admin.get_secrets()
|
||||
|
||||
@app.post("/secrets/")
|
||||
async def add_secret(
|
||||
secret: SecretCreate,
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> None:
|
||||
"""Create a secret."""
|
||||
await admin.add_secret(secret.name, secret.get_secret(), secret.clients)
|
||||
|
||||
@app.get("/secrets/{name}")
|
||||
async def get_secret(
|
||||
name: str,
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> SecretView:
|
||||
"""Get a secret."""
|
||||
secret_view = await admin.get_secret(name)
|
||||
|
||||
if not secret_view:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found."
|
||||
)
|
||||
return secret_view
|
||||
|
||||
@app.put("/secrets/{name}")
|
||||
async def update_secret(
|
||||
name: str,
|
||||
value: SecretUpdate,
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> None:
|
||||
new_value = value.get_secret()
|
||||
await admin.update_secret(name, new_value)
|
||||
|
||||
@app.delete("/secrets/{name}")
|
||||
async def delete_secret(
|
||||
name: str,
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> None:
|
||||
"""Delete secret."""
|
||||
await admin.delete_secret(name)
|
||||
|
||||
return app
|
||||
78
packages/sshecret-admin/src/sshecret_admin/api/router.py
Normal file
78
packages/sshecret-admin/src/sshecret_admin/api/router.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""Main API Router."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
from sshecret_admin.core.dependencies import BaseDependencies, AdminDependencies
|
||||
from sshecret_admin.auth import PasswordDB, User, decode_token
|
||||
|
||||
from .endpoints import auth, clients, secrets
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
API_VERSION = "v1"
|
||||
|
||||
|
||||
def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
"""Create clients router."""
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
) -> User:
|
||||
"""Get current user from token."""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
token_data = decode_token(dependencies.settings, token)
|
||||
if not token_data:
|
||||
raise credentials_exception
|
||||
|
||||
user = session.exec(
|
||||
select(User).where(User.username == token_data.username)
|
||||
).first()
|
||||
if not user:
|
||||
raise credentials_exception
|
||||
return user
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
) -> User:
|
||||
"""Get current active user."""
|
||||
if current_user.disabled:
|
||||
raise HTTPException(status_code=400, detail="Inactive or disabled user")
|
||||
return current_user
|
||||
|
||||
async def get_admin_backend(session: Annotated[Session, Depends(dependencies.get_db_session)]):
|
||||
"""Get admin backend API."""
|
||||
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||
if not password_db:
|
||||
raise HTTPException(
|
||||
500, detail="Error: The password manager has not yet been set up."
|
||||
)
|
||||
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
||||
yield admin
|
||||
|
||||
app = APIRouter(
|
||||
prefix=f"/api/{API_VERSION}", dependencies=[Depends(get_current_active_user)]
|
||||
)
|
||||
|
||||
endpoint_deps = AdminDependencies.create(dependencies, get_admin_backend)
|
||||
|
||||
app.include_router(auth.create_router(endpoint_deps))
|
||||
app.include_router(clients.create_router(endpoint_deps))
|
||||
app.include_router(secrets.create_router(endpoint_deps))
|
||||
|
||||
return app
|
||||
24
packages/sshecret-admin/src/sshecret_admin/auth/__init__.py
Normal file
24
packages/sshecret-admin/src/sshecret_admin/auth/__init__.py
Normal file
@ -0,0 +1,24 @@
|
||||
"""Authentication related module."""
|
||||
|
||||
from .authentication import (
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
check_password,
|
||||
decode_token,
|
||||
verify_password,
|
||||
)
|
||||
from .models import User, Token, PasswordDB
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PasswordDB",
|
||||
"Token",
|
||||
"User",
|
||||
"authenticate_user",
|
||||
"check_password",
|
||||
"create_access_token",
|
||||
"create_refresh_token",
|
||||
"decode_token",
|
||||
"verify_password",
|
||||
]
|
||||
@ -0,0 +1,95 @@
|
||||
"""Authentication utilities."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import cast, Any
|
||||
|
||||
import bcrypt
|
||||
import jwt
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from .models import User, TokenData
|
||||
from .exceptions import AuthenticationFailedError
|
||||
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
# I know refresh tokens are supposed to be long-lived, but 6 hours for a
|
||||
# sensitive application, seems reasonable.
|
||||
REFRESH_TOKEN_EXPIRE_HOURS = 6
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_token(
|
||||
settings: AdminServerSettings,
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta,
|
||||
) -> str:
|
||||
"""Create access token."""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=JWT_ALGORITHM)
|
||||
return str(encoded_jwt)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
settings: AdminServerSettings,
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta | None = None,
|
||||
) -> str:
|
||||
"""Create access token."""
|
||||
if not expires_delta:
|
||||
expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
return create_token(settings, data, expires_delta)
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
settings: AdminServerSettings,
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta | None = None,
|
||||
) -> str:
|
||||
"""Create access token."""
|
||||
if not expires_delta:
|
||||
expires_delta = timedelta(hours=REFRESH_TOKEN_EXPIRE_HOURS)
|
||||
return create_token(settings, data, expires_delta)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify password against stored hash."""
|
||||
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
|
||||
|
||||
|
||||
def check_password(plain_password: str, hashed_password: str) -> None:
|
||||
"""Check password.
|
||||
|
||||
If password doesn't match, throw AuthenticationFailedError.
|
||||
"""
|
||||
if not verify_password(plain_password, hashed_password):
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
|
||||
def authenticate_user(session: Session, username: str, password: str) -> User | None:
|
||||
"""Authenticate user."""
|
||||
user = session.exec(select(User).where(User.username == username)).first()
|
||||
if not user:
|
||||
return None
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return None
|
||||
return user
|
||||
|
||||
|
||||
def decode_token(settings: AdminServerSettings, token: str) -> TokenData | None:
|
||||
"""Decode token."""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
|
||||
username = cast("str | None", payload.get("sub"))
|
||||
if not username:
|
||||
return None
|
||||
|
||||
token_data = TokenData(username=username)
|
||||
return token_data
|
||||
except jwt.InvalidTokenError as e:
|
||||
LOG.debug("Could not decode token: %s", e, exc_info=True)
|
||||
return None
|
||||
@ -0,0 +1,30 @@
|
||||
"""Authentication related exceptions."""
|
||||
from typing import override
|
||||
|
||||
from .models import LoginError
|
||||
|
||||
|
||||
class AuthenticationFailedError(Exception):
|
||||
"""Authentication failed."""
|
||||
|
||||
@override
|
||||
def __init__(self, message: str | None = None) -> None:
|
||||
"""Initialize exception class."""
|
||||
if not message:
|
||||
message = "Invalid user or password."
|
||||
super().__init__(message)
|
||||
self.login_error: LoginError = LoginError(
|
||||
title="Authentication Failed", message=message
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationNeededError(Exception):
|
||||
"""Authentication needed error."""
|
||||
|
||||
@override
|
||||
def __init__(self, message: str | None = None) -> None:
|
||||
"""Initialize exception class."""
|
||||
if not message:
|
||||
message = "You need to be logged in to continue."
|
||||
super().__init__(message)
|
||||
self.login_error: LoginError = LoginError(title="Unauthorized", message=message)
|
||||
71
packages/sshecret-admin/src/sshecret_admin/auth/models.py
Normal file
71
packages/sshecret-admin/src/sshecret_admin/auth/models.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""Models for authentication."""
|
||||
|
||||
from datetime import datetime
|
||||
import sqlalchemy as sa
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
# I know refresh tokens are supposed to be long-lived, but 6 hours for a
|
||||
# sensitive application, seems reasonable.
|
||||
REFRESH_TOKEN_EXPIRE_HOURS = 6
|
||||
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
"""Users."""
|
||||
|
||||
username: str = Field(unique=True, primary_key=True)
|
||||
hashed_password: str
|
||||
disabled: bool = Field(default=False)
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class PasswordDB(SQLModel, table=True):
|
||||
"""Password database."""
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
encrypted_password: str
|
||||
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
updated_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||
)
|
||||
|
||||
|
||||
def init_db(engine: sa.Engine) -> None:
|
||||
"""Create database."""
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
|
||||
class TokenData(SQLModel):
|
||||
"""Token data."""
|
||||
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class Token(SQLModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class LoginError(SQLModel):
|
||||
"""Login Error model."""
|
||||
# TODO: Remove this.
|
||||
|
||||
title: str
|
||||
message: str
|
||||
|
||||
@ -1,125 +0,0 @@
|
||||
"""Models for authentication."""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import bcrypt
|
||||
import sqlalchemy as sa
|
||||
from typing import Any, override
|
||||
import jwt
|
||||
from sqlmodel import SQLModel, Field
|
||||
from sshecret_admin.settings import AdminServerSettings
|
||||
|
||||
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
"""Users."""
|
||||
|
||||
username: str = Field(unique=True, primary_key=True)
|
||||
hashed_password: str
|
||||
disabled: bool = Field(default=False)
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
class PasswordDB(SQLModel, table=True):
|
||||
"""Password database."""
|
||||
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
encrypted_password: str
|
||||
|
||||
created_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"server_default": sa.func.now()},
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
updated_at: datetime | None = Field(
|
||||
default=None,
|
||||
sa_type=sa.DateTime(timezone=True),
|
||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||
)
|
||||
|
||||
|
||||
def init_db(engine: sa.Engine) -> None:
|
||||
"""Create database."""
|
||||
SQLModel.metadata.create_all(engine)
|
||||
|
||||
|
||||
class TokenData(SQLModel):
|
||||
"""Token data."""
|
||||
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class Token(SQLModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
def create_access_token(
|
||||
settings: AdminServerSettings,
|
||||
data: dict[str, Any],
|
||||
expires_delta: timedelta | None = None,
|
||||
) -> str:
|
||||
"""Create access token."""
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify password against stored hash."""
|
||||
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
|
||||
|
||||
|
||||
def check_password(plain_password: str, hashed_password: str) -> None:
|
||||
"""Check password.
|
||||
|
||||
If password doesn't match, throw AuthenticationFailedError.
|
||||
"""
|
||||
if not verify_password(plain_password, hashed_password):
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
|
||||
class LoginError(SQLModel):
|
||||
"""Login Error model."""
|
||||
|
||||
title: str
|
||||
message: str
|
||||
|
||||
|
||||
class AuthenticationFailedError(Exception):
|
||||
"""Authentication failed."""
|
||||
|
||||
@override
|
||||
def __init__(self, message: str | None = None) -> None:
|
||||
"""Initialize exception class."""
|
||||
if not message:
|
||||
message = "Invalid user or password."
|
||||
super().__init__(message)
|
||||
self.login_error: LoginError = LoginError(
|
||||
title="Authentication Failed", message=message
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationNeededError(Exception):
|
||||
"""Authentication needed error."""
|
||||
|
||||
@override
|
||||
def __init__(self, message: str | None = None) -> None:
|
||||
"""Initialize exception class."""
|
||||
if not message:
|
||||
message = "You need to be logged in to continue."
|
||||
super().__init__(message)
|
||||
self.login_error: LoginError = LoginError(title="Unauthorized", message=message)
|
||||
@ -5,24 +5,22 @@
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi import FastAPI, Request, Response, status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from sshecret_admin import api, frontend
|
||||
from sshecret_admin.auth.models import PasswordDB, init_db
|
||||
from sshecret_admin.core.db import setup_database
|
||||
from sshecret_admin.frontend.exceptions import RedirectException
|
||||
from sshecret_admin.services.master_password import setup_master_password
|
||||
|
||||
from .admin_api import get_admin_api
|
||||
from .auth_models import init_db, PasswordDB, AuthenticationFailedError, AuthenticationNeededError
|
||||
from .db import setup_database
|
||||
from .master_password import setup_master_password
|
||||
from .dependencies import BaseDependencies
|
||||
from .settings import AdminServerSettings
|
||||
from .frontend import create_frontend
|
||||
from .types import DBSessionDep
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -30,15 +28,14 @@ LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_frontend(
|
||||
app: FastAPI, settings: AdminServerSettings, get_db_session: DBSessionDep
|
||||
app: FastAPI, dependencies: BaseDependencies
|
||||
) -> None:
|
||||
"""Setup frontend."""
|
||||
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
|
||||
static_path = script_path / "static"
|
||||
static_path = script_path.parent / "static"
|
||||
|
||||
app.mount("/static", StaticFiles(directory=static_path), name="static")
|
||||
frontend = create_frontend(settings, get_db_session)
|
||||
app.include_router(frontend)
|
||||
app.include_router(frontend.create_frontend_router(dependencies))
|
||||
|
||||
|
||||
def create_admin_app(
|
||||
@ -88,19 +85,15 @@ def create_admin_app(
|
||||
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
||||
)
|
||||
|
||||
@app.exception_handler(AuthenticationNeededError)
|
||||
async def authentication_needed_handler(
|
||||
request: Request, exc: AuthenticationNeededError,
|
||||
):
|
||||
qs = f"error_title={exc.login_error.title}&error_message={exc.login_error.message}"
|
||||
return RedirectResponse(f"/?{qs}")
|
||||
@app.exception_handler(RedirectException)
|
||||
async def redirect_handler(request: Request, exc: RedirectException) -> Response:
|
||||
"""Handle redirect exceptions."""
|
||||
if "hx-request" in request.headers:
|
||||
response = Response()
|
||||
response.headers["HX-Redirect"] = str(exc.to)
|
||||
return response
|
||||
return RedirectResponse(url=str(exc.to))
|
||||
|
||||
@app.exception_handler(AuthenticationFailedError)
|
||||
async def authentication_failed_handler(
|
||||
request: Request, exc: AuthenticationNeededError,
|
||||
):
|
||||
qs = f"error_title={exc.login_error.title}&error_message={exc.login_error.message}"
|
||||
return RedirectResponse(f"/?{qs}")
|
||||
|
||||
@app.get("/health")
|
||||
async def get_health() -> JSONResponse:
|
||||
@ -109,10 +102,11 @@ def create_admin_app(
|
||||
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
|
||||
)
|
||||
|
||||
admin_api = get_admin_api(get_db_session, settings)
|
||||
dependencies = BaseDependencies(settings, get_db_session)
|
||||
|
||||
app.include_router(admin_api)
|
||||
|
||||
app.include_router(api.create_api_router(dependencies))
|
||||
if with_frontend:
|
||||
setup_frontend(app, settings, get_db_session)
|
||||
setup_frontend(app, dependencies)
|
||||
|
||||
return app
|
||||
@ -7,29 +7,30 @@ import logging
|
||||
from typing import Any, cast
|
||||
import bcrypt
|
||||
import click
|
||||
from sshecret_admin.admin_backend import AdminBackend
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
import uvicorn
|
||||
from pydantic import ValidationError
|
||||
from sqlmodel import Session, create_engine, select
|
||||
from .auth_models import init_db, User, PasswordDB
|
||||
from .settings import AdminServerSettings
|
||||
from sshecret_admin.auth.models import init_db, User, PasswordDB
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s [%(processName)s: %(process)d] [%(threadName)s: %(thread)d] [%(levelname)s] %(name)s: %(message)s")
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s [%(processName)s: %(process)d] [%(threadName)s: %(thread)d] [%(levelname)s] %(name)s: %(message)s"
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
LOG = logging.getLogger()
|
||||
LOG.addHandler(handler)
|
||||
LOG.setLevel(logging.INFO)
|
||||
|
||||
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash password."""
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_password = bcrypt.hashpw(password.encode(), salt)
|
||||
return hashed_password.decode()
|
||||
|
||||
|
||||
def create_user(session: Session, username: str, password: str) -> None:
|
||||
"""Create a user."""
|
||||
hashed_password = hash_password(password)
|
||||
@ -48,7 +49,9 @@ def cli(ctx: click.Context, debug: bool) -> None:
|
||||
try:
|
||||
settings = AdminServerSettings() # pyright: ignore[reportCallIssue]
|
||||
except ValidationError as e:
|
||||
raise click.ClickException("Error: One or more required environment options are missing.") from e
|
||||
raise click.ClickException(
|
||||
"Error: One or more required environment options are missing."
|
||||
) from e
|
||||
ctx.obj = settings
|
||||
|
||||
|
||||
@ -66,6 +69,7 @@ def cli_create_user(ctx: click.Context, username: str, password: str) -> None:
|
||||
|
||||
click.echo("User created.")
|
||||
|
||||
|
||||
@cli.command("passwd")
|
||||
@click.argument("username")
|
||||
@click.password_option()
|
||||
@ -85,6 +89,7 @@ def cli_change_user_passwd(ctx: click.Context, username: str, password: str) ->
|
||||
session.commit()
|
||||
click.echo("Password updated.")
|
||||
|
||||
|
||||
@cli.command("deluser")
|
||||
@click.argument("username")
|
||||
@click.confirmation_option()
|
||||
@ -112,7 +117,9 @@ def cli_delete_user(ctx: click.Context, username: str) -> None:
|
||||
@click.option("--workers", type=click.INT)
|
||||
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
||||
"""Run the server."""
|
||||
uvicorn.run("sshecret_admin.main:app", host=host, port=port, reload=dev, workers=workers)
|
||||
uvicorn.run(
|
||||
"sshecret_admin.core.main:app", host=host, port=port, reload=dev, workers=workers
|
||||
)
|
||||
|
||||
|
||||
@cli.command("repl")
|
||||
@ -126,7 +133,9 @@ def cli_repl(ctx: click.Context) -> None:
|
||||
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||
|
||||
if not password_db:
|
||||
raise click.ClickException("Error: Password database has not yet been setup. Start the server to finish setup.")
|
||||
raise click.ClickException(
|
||||
"Error: Password database has not yet been setup. Start the server to finish setup."
|
||||
)
|
||||
|
||||
def run(func: Awaitable[Any]) -> Any:
|
||||
"""Run an async function."""
|
||||
@ -0,0 +1,37 @@
|
||||
"""Common type definitions."""
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Self
|
||||
|
||||
from sqlmodel import Session
|
||||
from sshecret_admin.services import AdminBackend
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
|
||||
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||
|
||||
AdminDep = Callable[[Session], AsyncGenerator[AdminBackend, None]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseDependencies:
|
||||
"""Base level dependencies."""
|
||||
|
||||
settings: AdminServerSettings
|
||||
get_db_session: DBSessionDep
|
||||
|
||||
@dataclass
|
||||
class AdminDependencies(BaseDependencies):
|
||||
"""Dependency class with admin."""
|
||||
|
||||
get_admin_backend: AdminDep
|
||||
|
||||
@classmethod
|
||||
def create(cls, deps: BaseDependencies, get_admin_backend: AdminDep) -> Self:
|
||||
"""Create from base dependencies."""
|
||||
return cls(
|
||||
settings=deps.settings,
|
||||
get_db_session=deps.get_db_session,
|
||||
get_admin_backend=get_admin_backend,
|
||||
)
|
||||
@ -1,6 +1,5 @@
|
||||
"""Main server app."""
|
||||
import sys
|
||||
import uvicorn
|
||||
import click
|
||||
from pydantic import ValidationError
|
||||
|
||||
@ -2,11 +2,12 @@
|
||||
|
||||
from pydantic import AnyHttpUrl, Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from sqlalchemy import URL
|
||||
|
||||
|
||||
DEFAULT_LISTEN_PORT = 8822
|
||||
|
||||
DEFAULT_DATABASE = "sqlite:///ssh_admin.db"
|
||||
DEFAULT_DATABASE = "ssh_admin.db"
|
||||
|
||||
|
||||
class AdminServerSettings(BaseSettings):
|
||||
@ -21,5 +22,12 @@ class AdminServerSettings(BaseSettings):
|
||||
listen_address: str = Field(default="")
|
||||
secret_key: str
|
||||
port: int = DEFAULT_LISTEN_PORT
|
||||
admin_db: str = Field(default=DEFAULT_DATABASE)
|
||||
|
||||
database: str = Field(default=DEFAULT_DATABASE)
|
||||
#admin_db: str = Field(default=DEFAULT_DATABASE)
|
||||
debug: bool = False
|
||||
|
||||
@property
|
||||
def admin_db(self) -> URL:
|
||||
"""Construct database url."""
|
||||
return URL.create(drivername="sqlite", database=self.database)
|
||||
@ -1,240 +0,0 @@
|
||||
"""Frontend methods."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
import logging
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||
from sqlmodel import Session, select
|
||||
from sshecret_admin.settings import AdminServerSettings
|
||||
from sshecret.backend import SshecretBackend
|
||||
from .admin_backend import AdminBackend
|
||||
from .auth_models import (
|
||||
JWT_ALGORITHM,
|
||||
AuthenticationFailedError,
|
||||
AuthenticationNeededError,
|
||||
LoginError,
|
||||
PasswordDB,
|
||||
User,
|
||||
TokenData,
|
||||
create_access_token,
|
||||
verify_password,
|
||||
)
|
||||
from .types import DBSessionDep
|
||||
from .views import create_audit_view, create_client_view, create_secrets_view
|
||||
|
||||
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 45
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def login_error(templates: Jinja2Blocks, request: Request):
|
||||
"""Return a login error."""
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"login.html",
|
||||
{
|
||||
"page_title": "Login",
|
||||
"page_description": "Login Page",
|
||||
"error": "Invalid Login.",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def create_frontend(
|
||||
settings: AdminServerSettings, get_db_session: DBSessionDep
|
||||
) -> APIRouter:
|
||||
"""Create frontend."""
|
||||
app = APIRouter(include_in_schema=False)
|
||||
|
||||
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
template_path = script_path / "templates"
|
||||
|
||||
templates = Jinja2Blocks(directory=template_path)
|
||||
|
||||
# @app.exception_handler(AuthenticationFailedError)
|
||||
# async def handle_authentication_failed(request: Request, exc: AuthenticationFailedError):
|
||||
# """Handle authentication failed error."""
|
||||
# return templates.TemplateResponse(request, "login.html")
|
||||
|
||||
async def get_backend():
|
||||
"""Get backend client."""
|
||||
backend_client = SshecretBackend(
|
||||
str(settings.backend_url), settings.backend_token
|
||||
)
|
||||
yield backend_client
|
||||
|
||||
async def get_admin_backend(session: Annotated[Session, Depends(get_db_session)]):
|
||||
"""Get admin backend API."""
|
||||
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||
if not password_db:
|
||||
raise HTTPException(
|
||||
500, detail="Error: The password manager has not yet been set up."
|
||||
)
|
||||
admin = AdminBackend(settings, password_db.encrypted_password)
|
||||
yield admin
|
||||
|
||||
async def get_login_status(
|
||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
||||
) -> bool:
|
||||
"""Get login status."""
|
||||
token = request.cookies.get("access_token")
|
||||
if not token:
|
||||
return False
|
||||
try:
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
|
||||
username = payload.get("sub")
|
||||
if not username:
|
||||
return False
|
||||
except jwt.InvalidTokenError:
|
||||
return False
|
||||
token_data = TokenData(username=username)
|
||||
user = session.exec(
|
||||
select(User).where(User.username == token_data.username)
|
||||
).first()
|
||||
if not user:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_current_user_from_token(
|
||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
||||
) -> User:
|
||||
credentials_exception = AuthenticationNeededError()
|
||||
"""Get current user from token."""
|
||||
token = request.cookies.get("access_token")
|
||||
if not token:
|
||||
raise credentials_exception
|
||||
try:
|
||||
payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
|
||||
username = payload.get("sub")
|
||||
if not username:
|
||||
raise credentials_exception
|
||||
except jwt.InvalidTokenError:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
user = session.exec(
|
||||
select(User).where(User.username == token_data.username)
|
||||
).first()
|
||||
if not user:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
|
||||
@app.get("/")
|
||||
async def get_index(
|
||||
request: Request,
|
||||
login_status: Annotated[bool, Depends(get_login_status)],
|
||||
error_title: str | None = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
"""Get index."""
|
||||
if login_status:
|
||||
return RedirectResponse("/dashboard")
|
||||
login_error: LoginError | None = None
|
||||
if error_title and error_message:
|
||||
login_error = LoginError(title=error_title, message=error_message)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"login.html",
|
||||
{
|
||||
"page_title": "Login",
|
||||
"page_description": "Login page.",
|
||||
"login_error": login_error,
|
||||
},
|
||||
)
|
||||
|
||||
@app.post("/")
|
||||
async def post_index(
|
||||
request: Request,
|
||||
error_title: str | None = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
"""Get index."""
|
||||
login_error: LoginError | None = None
|
||||
if error_title and error_message:
|
||||
login_error = LoginError(title=error_title, message=error_message)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"login.html",
|
||||
{
|
||||
"page_title": "Login",
|
||||
"page_description": "Login page.",
|
||||
"login_error": login_error,
|
||||
},
|
||||
)
|
||||
|
||||
@app.post("/login")
|
||||
async def login_user(
|
||||
response: Response,
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
username: Annotated[str, Form()],
|
||||
password: Annotated[str, Form()],
|
||||
):
|
||||
"""Log in user."""
|
||||
user = session.exec(select(User).where(User.username == username)).first()
|
||||
auth_error = AuthenticationFailedError()
|
||||
if not user:
|
||||
raise auth_error
|
||||
|
||||
if not verify_password(password, user.hashed_password):
|
||||
raise auth_error
|
||||
|
||||
token_data = {"sub": user.username}
|
||||
expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
token = create_access_token(settings, token_data, expires_delta=expires)
|
||||
response = RedirectResponse(url="/dashboard", status_code=status.HTTP_302_FOUND)
|
||||
response.set_cookie(
|
||||
key="access_token", value=token, httponly=True, secure=False, samesite="lax"
|
||||
)
|
||||
return response
|
||||
|
||||
@app.get("/success")
|
||||
async def success_page(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
):
|
||||
"""Display a success page."""
|
||||
return templates.TemplateResponse(
|
||||
request, "success.html", {"page_title": "Success!", "user": current_user}
|
||||
)
|
||||
|
||||
@app.get("/dashboard")
|
||||
async def get_dashboard(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
):
|
||||
"""Dashboard for mocking up the dashboard."""
|
||||
# secrets = await admin.get_secrets()
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"dashboard.html",
|
||||
{
|
||||
"page_title": "sshecret",
|
||||
"user": current_user.username,
|
||||
},
|
||||
)
|
||||
|
||||
# Stop adding routes here.
|
||||
|
||||
app.include_router(
|
||||
create_client_view(templates, get_current_user_from_token, get_admin_backend)
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
create_secrets_view(templates, get_current_user_from_token, get_admin_backend)
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
create_audit_view(templates, get_current_user_from_token, get_admin_backend)
|
||||
)
|
||||
|
||||
return app
|
||||
@ -0,0 +1,5 @@
|
||||
"""Frontend app."""
|
||||
|
||||
from .router import create_router as create_frontend_router
|
||||
|
||||
__all__ = ["create_frontend_router"]
|
||||
@ -0,0 +1,7 @@
|
||||
"""Custom oauth2 class."""
|
||||
|
||||
from fastapi.security import OAuth2
|
||||
|
||||
|
||||
class Oauth2TokenInCookies(OAuth2):
|
||||
"""TODO: Create this."""
|
||||
@ -0,0 +1,48 @@
|
||||
"""Frontend dependencies."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from collections.abc import Callable, Awaitable
|
||||
from typing import Self
|
||||
|
||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||
from fastapi import Request
|
||||
from sqlmodel import Session
|
||||
|
||||
from sshecret_admin.core.dependencies import AdminDep, BaseDependencies
|
||||
|
||||
from sshecret_admin.auth.models import User
|
||||
|
||||
UserTokenDep = Callable[[Request, Session], Awaitable[User]]
|
||||
UserLoginDep = Callable[[Request, Session], Awaitable[bool]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrontendDependencies(BaseDependencies):
|
||||
"""Frontend dependencies."""
|
||||
|
||||
get_admin_backend: AdminDep
|
||||
templates: Jinja2Blocks
|
||||
get_user_from_access_token: UserTokenDep
|
||||
get_user_from_refresh_token: UserTokenDep
|
||||
get_login_status: UserLoginDep
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
deps: BaseDependencies,
|
||||
get_admin_backend: AdminDep,
|
||||
templates: Jinja2Blocks,
|
||||
get_user_from_access_token: UserTokenDep,
|
||||
get_user_from_refresh_token: UserTokenDep,
|
||||
get_login_status: UserLoginDep,
|
||||
) -> Self:
|
||||
"""Create from base dependencies."""
|
||||
return cls(
|
||||
settings=deps.settings,
|
||||
get_db_session=deps.get_db_session,
|
||||
get_admin_backend=get_admin_backend,
|
||||
templates=templates,
|
||||
get_user_from_access_token=get_user_from_access_token,
|
||||
get_user_from_refresh_token=get_user_from_refresh_token,
|
||||
get_login_status=get_login_status,
|
||||
)
|
||||
@ -0,0 +1,13 @@
|
||||
"""Frontend exceptions."""
|
||||
from starlette.datastructures import URL
|
||||
|
||||
|
||||
class RedirectException(Exception):
|
||||
"""Exception that initiates a redirect flow."""
|
||||
|
||||
def __init__(self, to: str | URL) -> None: # pyright: ignore[reportMissingSuperCall]
|
||||
"""Raise exception that redirects."""
|
||||
if isinstance(to, str):
|
||||
to = URL(to)
|
||||
|
||||
self.to: URL = to
|
||||
133
packages/sshecret-admin/src/sshecret_admin/frontend/router.py
Normal file
133
packages/sshecret-admin/src/sshecret_admin/frontend/router.py
Normal file
@ -0,0 +1,133 @@
|
||||
"""Frontend router."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from starlette.datastructures import URL
|
||||
|
||||
|
||||
from sshecret_admin.auth import PasswordDB, User, decode_token
|
||||
from sshecret_admin.core.dependencies import BaseDependencies
|
||||
from sshecret_admin.services.admin_backend import AdminBackend
|
||||
|
||||
from .dependencies import FrontendDependencies
|
||||
from .exceptions import RedirectException
|
||||
from .views import audit, auth, clients, index, secrets
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
access_token = "access_token"
|
||||
refresh_token = "refresh_token"
|
||||
|
||||
|
||||
def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||
"""Create frontend router."""
|
||||
|
||||
app = APIRouter(include_in_schema=False)
|
||||
|
||||
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
template_path = script_path / "templates"
|
||||
|
||||
templates = Jinja2Blocks(directory=template_path)
|
||||
|
||||
async def get_admin_backend(
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)]
|
||||
):
|
||||
"""Get admin backend API."""
|
||||
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||
if not password_db:
|
||||
raise HTTPException(
|
||||
500, detail="Error: The password manager has not yet been set up."
|
||||
)
|
||||
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
||||
yield admin
|
||||
|
||||
async def get_user_from_token(
|
||||
token: str,
|
||||
session: Session,
|
||||
) -> User | None:
|
||||
"""Get user from a token."""
|
||||
token_data = decode_token(dependencies.settings, token)
|
||||
if not token_data:
|
||||
return None
|
||||
user = session.exec(
|
||||
select(User).where(User.username == token_data.username)
|
||||
).first()
|
||||
if not user or user.disabled:
|
||||
return None
|
||||
return user
|
||||
|
||||
async def get_user_from_refresh_token(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
) -> User:
|
||||
"""Get user from refresh token."""
|
||||
next = URL("/login").include_query_params(next=request.url.path)
|
||||
credentials_error = RedirectException(to=next)
|
||||
token = request.cookies.get("refresh_token")
|
||||
if not token:
|
||||
raise credentials_error
|
||||
|
||||
user = await get_user_from_token(token, session)
|
||||
if not user:
|
||||
raise credentials_error
|
||||
return user
|
||||
|
||||
async def get_user_from_access_token(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
) -> User:
|
||||
"""Get user from access token."""
|
||||
token = request.cookies.get("access_token")
|
||||
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||
credentials_error = RedirectException(to=next)
|
||||
if not token:
|
||||
raise credentials_error
|
||||
|
||||
user = await get_user_from_token(token, session)
|
||||
if not user:
|
||||
raise credentials_error
|
||||
return user
|
||||
|
||||
async def get_login_status(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
) -> bool:
|
||||
"""Get login status."""
|
||||
token = request.cookies.get("access_token")
|
||||
if not token:
|
||||
return False
|
||||
|
||||
user = await get_user_from_token(token, session)
|
||||
if not user:
|
||||
return False
|
||||
return True
|
||||
|
||||
view_dependencies = FrontendDependencies.create(
|
||||
dependencies,
|
||||
get_admin_backend,
|
||||
templates,
|
||||
get_user_from_access_token,
|
||||
get_user_from_refresh_token,
|
||||
get_login_status,
|
||||
)
|
||||
|
||||
app.include_router(audit.create_router(view_dependencies))
|
||||
app.include_router(auth.create_router(view_dependencies))
|
||||
app.include_router(clients.create_router(view_dependencies))
|
||||
app.include_router(index.create_router(view_dependencies))
|
||||
app.include_router(secrets.create_router(view_dependencies))
|
||||
|
||||
return app
|
||||
@ -0,0 +1,89 @@
|
||||
{% extends "/dashboard/_base.html" %} {% block content %}
|
||||
|
||||
<div
|
||||
class="p-4 bg-white block sm:flex items-center justify-between border-b border-gray-200 lg:mt-1.5 dark:bg-gray-800 dark:border-gray-700"
|
||||
>
|
||||
<h1 class="mb-4 text-4xl font-extrabold leading-none tracking-tight text-gray-900 md:text-5xl lg:text-6xl dark:text-white">Welcome to Sshecret</h1>
|
||||
|
||||
</div>
|
||||
|
||||
<div class="grid w-full grid-cols-1 gap-4 mt-4 xl:grid-cols-2 2xl:grid-cols-3">
|
||||
<div class="items-center justify-between p-4 bg-white border border-gray-200 rounded-lg shadow-sm sm:flex dark:border-gray-700 sm:p-6 dark:bg-gray-800">
|
||||
<div class="w-full">
|
||||
<h3 class="text-base font-normal text-gray-500 dark:text-gray-400">Clients</h3>
|
||||
<span class="text-2xl font-bold leading-none text-gray-900 sm:text-3xl dark:text-white">{{ stats.clients }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- reference -->
|
||||
|
||||
<div class="grid w-full grid-cols-1 gap-4 mt-4 xl:grid-cols-2 2xl:grid-cols-3">
|
||||
<div class="items-center justify-between p-4 bg-white border border-gray-200 rounded-lg shadow-sm sm:flex dark:border-gray-700 sm:p-6 dark:bg-gray-800">
|
||||
<div class="w-full">
|
||||
<h3 class="text-base font-normal text-gray-500 dark:text-gray-400">New products</h3>
|
||||
<span class="text-2xl font-bold leading-none text-gray-900 sm:text-3xl dark:text-white">2,340</span>
|
||||
<p class="flex items-center text-base font-normal text-gray-500 dark:text-gray-400">
|
||||
<span class="flex items-center mr-1.5 text-sm text-green-500 dark:text-green-400">
|
||||
<svg class="w-4 h-4" fill="currentColor" viewBox="0 0 20 20" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
|
||||
<path clip-rule="evenodd" fill-rule="evenodd" d="M10 17a.75.75 0 01-.75-.75V5.612L5.29 9.77a.75.75 0 01-1.08-1.04l5.25-5.5a.75.75 0 011.08 0l5.25 5.5a.75.75 0 11-1.08 1.04l-3.96-4.158V16.25A.75.75 0 0110 17z"></path>
|
||||
</svg>
|
||||
12.5%
|
||||
</span>
|
||||
Since last month
|
||||
</p>
|
||||
</div>
|
||||
<div class="w-full" id="new-products-chart"></div>
|
||||
</div>
|
||||
<div class="items-center justify-between p-4 bg-white border border-gray-200 rounded-lg shadow-sm sm:flex dark:border-gray-700 sm:p-6 dark:bg-gray-800">
|
||||
<div class="w-full">
|
||||
<h3 class="text-base font-normal text-gray-500 dark:text-gray-400">Users</h3>
|
||||
<span class="text-2xl font-bold leading-none text-gray-900 sm:text-3xl dark:text-white">2,340</span>
|
||||
<p class="flex items-center text-base font-normal text-gray-500 dark:text-gray-400">
|
||||
<span class="flex items-center mr-1.5 text-sm text-green-500 dark:text-green-400">
|
||||
<svg class="w-4 h-4" fill="currentColor" viewBox="0 0 20 20" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
|
||||
<path clip-rule="evenodd" fill-rule="evenodd" d="M10 17a.75.75 0 01-.75-.75V5.612L5.29 9.77a.75.75 0 01-1.08-1.04l5.25-5.5a.75.75 0 011.08 0l5.25 5.5a.75.75 0 11-1.08 1.04l-3.96-4.158V16.25A.75.75 0 0110 17z"></path>
|
||||
</svg>
|
||||
3,4%
|
||||
</span>
|
||||
Since last month
|
||||
</p>
|
||||
</div>
|
||||
<div class="w-full" id="week-signups-chart"></div>
|
||||
</div>
|
||||
<div class="p-4 bg-white border border-gray-200 rounded-lg shadow-sm dark:border-gray-700 sm:p-6 dark:bg-gray-800">
|
||||
<div class="w-full">
|
||||
<h3 class="mb-2 text-base font-normal text-gray-500 dark:text-gray-400">Audience by age</h3>
|
||||
<div class="flex items-center mb-2">
|
||||
<div class="w-16 text-sm font-medium dark:text-white">50+</div>
|
||||
<div class="w-full bg-gray-200 rounded-full h-2.5 dark:bg-gray-700">
|
||||
<div class="bg-primary-600 h-2.5 rounded-full dark:bg-primary-500" style="width: 18%"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center mb-2">
|
||||
<div class="w-16 text-sm font-medium dark:text-white">40+</div>
|
||||
<div class="w-full bg-gray-200 rounded-full h-2.5 dark:bg-gray-700">
|
||||
<div class="bg-primary-600 h-2.5 rounded-full dark:bg-primary-500" style="width: 15%"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center mb-2">
|
||||
<div class="w-16 text-sm font-medium dark:text-white">30+</div>
|
||||
<div class="w-full bg-gray-200 rounded-full h-2.5 dark:bg-gray-700">
|
||||
<div class="bg-primary-600 h-2.5 rounded-full dark:bg-primary-500" style="width: 60%"></div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center mb-2">
|
||||
<div class="w-16 text-sm font-medium dark:text-white">20+</div>
|
||||
<div class="w-full bg-gray-200 rounded-full h-2.5 dark:bg-gray-700">
|
||||
<div class="bg-primary-600 h-2.5 rounded-full dark:bg-primary-500" style="width: 30%"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div id="traffic-channels-chart" class="w-full"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
|
||||
{% endblock %}
|
||||
@ -0,0 +1 @@
|
||||
"""Frontend views."""
|
||||
@ -1,19 +1,20 @@
|
||||
"""Audit view."""
|
||||
# pyright: reportUnusedFunction=false
|
||||
"""Audit view factory."""
|
||||
|
||||
import math
|
||||
# pyright: reportUnusedFunction=false
|
||||
import logging
|
||||
import math
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, Request, Response
|
||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sshecret_admin.admin_backend import AdminBackend
|
||||
from sshecret_admin.types import UserTokenDep, AdminDep
|
||||
from sshecret_admin.auth_models import User
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.services import AdminBackend
|
||||
|
||||
from ..dependencies import FrontendDependencies
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PagingInfo(BaseModel):
|
||||
|
||||
page: int
|
||||
@ -36,20 +37,15 @@ class PagingInfo(BaseModel):
|
||||
"""Return total pages."""
|
||||
return math.ceil(self.total / self.limit)
|
||||
|
||||
def create_audit_view(
|
||||
templates: Jinja2Blocks,
|
||||
get_current_user_from_token: UserTokenDep,
|
||||
get_admin_backend: AdminDep,
|
||||
) -> APIRouter:
|
||||
"""Create client view."""
|
||||
|
||||
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
"""Create clients router."""
|
||||
|
||||
app = APIRouter()
|
||||
templates = dependencies.templates
|
||||
|
||||
async def resolve_audit_entries(
|
||||
request: Request,
|
||||
current_user: User,
|
||||
admin: AdminBackend,
|
||||
page: int
|
||||
request: Request, current_user: User, admin: AdminBackend, page: int
|
||||
) -> Response:
|
||||
"""Resolve audit entries."""
|
||||
LOG.info("Page: %r", page)
|
||||
@ -61,7 +57,9 @@ def create_audit_view(
|
||||
|
||||
entries = await admin.get_audit_log(offset=offset, limit=per_page)
|
||||
LOG.info("Entries: %r", entries)
|
||||
page_info = PagingInfo(page=page, limit=per_page, total=total_messages, offset=offset)
|
||||
page_info = PagingInfo(
|
||||
page=page, limit=per_page, total=total_messages, offset=offset
|
||||
)
|
||||
if request.headers.get("HX-Request"):
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
@ -69,8 +67,7 @@ def create_audit_view(
|
||||
{
|
||||
"entries": entries,
|
||||
"page_info": page_info,
|
||||
}
|
||||
|
||||
},
|
||||
)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
@ -80,34 +77,27 @@ def create_audit_view(
|
||||
"entries": entries,
|
||||
"user": current_user.username,
|
||||
"page_info": page_info,
|
||||
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@app.get("/audit/")
|
||||
async def get_audit_entries(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
):
|
||||
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> Response:
|
||||
"""Get audit entries."""
|
||||
return await resolve_audit_entries(request, current_user, admin, 1)
|
||||
|
||||
@app.get("/audit/page/{page}")
|
||||
async def get_audit_entries_page(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
page: int,
|
||||
):
|
||||
) -> Response:
|
||||
"""Get audit entries."""
|
||||
LOG.info("Get audit entries page: %r", page)
|
||||
return await resolve_audit_entries(request, current_user, admin, page)
|
||||
|
||||
|
||||
|
||||
# --------------#
|
||||
# END OF ROUTES #
|
||||
# --------------#
|
||||
return app
|
||||
@ -0,0 +1,143 @@
|
||||
"""Authentication related views factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
import logging
|
||||
from pydantic import BaseModel
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, Query, Request, Response, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlmodel import Session
|
||||
from starlette.datastructures import URL
|
||||
|
||||
from sshecret_admin.auth import (
|
||||
User,
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
)
|
||||
|
||||
from ..dependencies import FrontendDependencies
|
||||
from ..exceptions import RedirectException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginError(BaseModel):
|
||||
"""Login error."""
|
||||
|
||||
title: str
|
||||
message: str
|
||||
|
||||
|
||||
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
"""Create auth router."""
|
||||
|
||||
app = APIRouter()
|
||||
templates = dependencies.templates
|
||||
|
||||
@app.get("/login")
|
||||
async def get_login(
|
||||
request: Request,
|
||||
login_status: Annotated[bool, Depends(dependencies.get_login_status)],
|
||||
error_title: str | None = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
"""Get index."""
|
||||
if login_status:
|
||||
return RedirectResponse("/dashboard")
|
||||
login_error: LoginError | None = None
|
||||
if error_title and error_message:
|
||||
login_error = LoginError(title=error_title, message=error_message)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"login.html",
|
||||
{
|
||||
"page_title": "Login",
|
||||
"page_description": "Login page.",
|
||||
"login_error": login_error,
|
||||
},
|
||||
)
|
||||
|
||||
@app.post("/login")
|
||||
async def login_user(
|
||||
request: Request,
|
||||
response: Response,
|
||||
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
next: Annotated[str, Query()] = "/dashboard",
|
||||
error_title: str | None = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
"""Log in user."""
|
||||
if error_title and error_message:
|
||||
login_error = LoginError(title=error_title, message=error_message)
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"login.html",
|
||||
{
|
||||
"page_title": "Login",
|
||||
"page_description": "Login page.",
|
||||
"login_error": login_error,
|
||||
},
|
||||
)
|
||||
|
||||
user = authenticate_user(session, form_data.username, form_data.password)
|
||||
login_failed = RedirectException(
|
||||
to=URL("/login").include_query_params(
|
||||
error_title="Login Error", error_message="Invalid username or password"
|
||||
)
|
||||
)
|
||||
if not user:
|
||||
raise login_failed
|
||||
token_data: dict[str, str] = {"sub": user.username}
|
||||
access_token = create_access_token(dependencies.settings, data=token_data)
|
||||
refresh_token = create_refresh_token(dependencies.settings, data=token_data)
|
||||
response = RedirectResponse(url=next, status_code=status.HTTP_302_FOUND)
|
||||
response.set_cookie(
|
||||
"access_token",
|
||||
value=access_token,
|
||||
httponly=True,
|
||||
secure=False,
|
||||
samesite="strict",
|
||||
)
|
||||
response.set_cookie(
|
||||
"refresh_token",
|
||||
value=refresh_token,
|
||||
httponly=True,
|
||||
secure=False,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
|
||||
@app.get("/refresh")
|
||||
async def get_refresh_token(
|
||||
response: Response,
|
||||
user: Annotated[User, Depends(dependencies.get_user_from_refresh_token)],
|
||||
next: Annotated[str, Query()],
|
||||
):
|
||||
"""Refresh tokens.
|
||||
|
||||
We might as well refresh the long-lived one here.
|
||||
"""
|
||||
token_data: dict[str, str] = {"sub": user.username}
|
||||
access_token = create_access_token(dependencies.settings, data=token_data)
|
||||
refresh_token = create_refresh_token(dependencies.settings, data=token_data)
|
||||
response = RedirectResponse(url=next, status_code=status.HTTP_302_FOUND)
|
||||
response.set_cookie(
|
||||
"access_token",
|
||||
value=access_token,
|
||||
httponly=True,
|
||||
secure=False,
|
||||
samesite="strict",
|
||||
)
|
||||
response.set_cookie(
|
||||
"refresh_token",
|
||||
value=refresh_token,
|
||||
httponly=True,
|
||||
secure=False,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
|
||||
return app
|
||||
@ -1,21 +1,20 @@
|
||||
"""Client views."""
|
||||
"""clients view factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, Request, Form
|
||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request, Response
|
||||
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
|
||||
from sshecret_admin.admin_backend import AdminBackend
|
||||
|
||||
from sshecret.backend import ClientFilter
|
||||
from sshecret.backend.models import FilterType
|
||||
from sshecret.crypto import validate_public_key
|
||||
from sshecret_admin.types import UserTokenDep, AdminDep
|
||||
from sshecret_admin.auth_models import User
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.services import AdminBackend
|
||||
|
||||
from ..dependencies import FrontendDependencies
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -37,21 +36,18 @@ class ClientCreate(BaseModel):
|
||||
sources: str | None
|
||||
|
||||
|
||||
def create_client_view(
|
||||
templates: Jinja2Blocks,
|
||||
get_current_user_from_token: UserTokenDep,
|
||||
get_admin_backend: AdminDep,
|
||||
) -> APIRouter:
|
||||
"""Create client view."""
|
||||
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
"""Create clients router."""
|
||||
|
||||
app = APIRouter()
|
||||
templates = dependencies.templates
|
||||
|
||||
@app.get("/clients")
|
||||
async def get_clients(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
):
|
||||
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> Response:
|
||||
"""Get clients."""
|
||||
clients = await admin.get_clients()
|
||||
LOG.info("Clients %r", clients)
|
||||
@ -68,10 +64,12 @@ def create_client_view(
|
||||
@app.post("/clients/query")
|
||||
async def query_clients(
|
||||
request: Request,
|
||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
query: Annotated[str, Form()],
|
||||
):
|
||||
) -> Response:
|
||||
"""Query for a client."""
|
||||
query_filter: ClientFilter | None = None
|
||||
if query:
|
||||
@ -90,8 +88,10 @@ def create_client_view(
|
||||
async def update_client(
|
||||
request: Request,
|
||||
id: str,
|
||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
client: Annotated[ClientUpdate, Form()],
|
||||
):
|
||||
"""Update a client."""
|
||||
@ -135,9 +135,11 @@ def create_client_view(
|
||||
async def delete_client(
|
||||
request: Request,
|
||||
id: str,
|
||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
):
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
) -> Response:
|
||||
"""Delete a client."""
|
||||
await admin.delete_client(id)
|
||||
clients = await admin.get_clients()
|
||||
@ -154,10 +156,12 @@ def create_client_view(
|
||||
@app.post("/clients/")
|
||||
async def create_client(
|
||||
request: Request,
|
||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
client: Annotated[ClientCreate, Form()],
|
||||
):
|
||||
) -> Response:
|
||||
"""Create client."""
|
||||
sources: list[str] | None = None
|
||||
if client.sources:
|
||||
@ -179,9 +183,11 @@ def create_client_view(
|
||||
@app.post("/clients/validate/source")
|
||||
async def validate_client_source(
|
||||
request: Request,
|
||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
sources: Annotated[str, Form()],
|
||||
):
|
||||
) -> Response:
|
||||
"""Validate source."""
|
||||
source_str = sources.split(",")
|
||||
for source in source_str:
|
||||
@ -211,9 +217,11 @@ def create_client_view(
|
||||
@app.post("/clients/validate/public_key")
|
||||
async def validate_client_public_key(
|
||||
request: Request,
|
||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
public_key: Annotated[str, Form()],
|
||||
):
|
||||
) -> Response:
|
||||
"""Validate source."""
|
||||
if validate_public_key(public_key.rstrip()):
|
||||
return templates.TemplateResponse(
|
||||
@ -0,0 +1,70 @@
|
||||
"""Front page view factory."""
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
import logging
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
from pydantic import BaseModel
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.services import AdminBackend
|
||||
|
||||
from ..dependencies import FrontendDependencies
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
START_PAGE = "/dashboard"
|
||||
LOGIN_PAGE = "/login"
|
||||
|
||||
|
||||
class StatsView(BaseModel):
|
||||
"""Stats for the frontend."""
|
||||
|
||||
clients: int = 0
|
||||
secrets: int = 0
|
||||
audit_events: int = 0
|
||||
|
||||
|
||||
async def get_stats(admin: AdminBackend) -> StatsView:
|
||||
"""Get stats for the frontpage."""
|
||||
clients = await admin.get_clients()
|
||||
secrets = await admin.get_secrets()
|
||||
audit = await admin.get_audit_log_count()
|
||||
return StatsView(clients=len(clients), secrets=len(secrets), audit_events=audit)
|
||||
|
||||
|
||||
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
"""Create auth router."""
|
||||
|
||||
app = APIRouter()
|
||||
templates = dependencies.templates
|
||||
|
||||
@app.get("/")
|
||||
def get_index(logged_in: Annotated[bool, Depends(dependencies.get_login_status)]):
|
||||
"""Get the index."""
|
||||
next = LOGIN_PAGE
|
||||
if logged_in:
|
||||
next = START_PAGE
|
||||
|
||||
return RedirectResponse(url=next)
|
||||
|
||||
@app.get("/dashboard")
|
||||
async def get_dashboard(
|
||||
request: Request,
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||
):
|
||||
"""Dashboard for mocking up the dashboard."""
|
||||
stats = await get_stats(admin)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request,
|
||||
"dashboard.html",
|
||||
{
|
||||
"page_title": "sshecret",
|
||||
"user": current_user.username,
|
||||
"stats": stats,
|
||||
},
|
||||
)
|
||||
|
||||
return app
|
||||
@ -1,24 +1,25 @@
|
||||
"""Secrets view."""
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
# pyright: reportUnusedFunction=false
|
||||
import logging
|
||||
import secrets as pysecrets
|
||||
from typing import Annotated, Any
|
||||
from fastapi import APIRouter, Depends, Request, Form
|
||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
from pydantic import BaseModel, BeforeValidator, Field
|
||||
from sshecret_admin.admin_backend import AdminBackend
|
||||
from sshecret_admin.types import UserTokenDep, AdminDep
|
||||
from sshecret_admin.auth_models import User
|
||||
|
||||
from sshecret_admin.auth import User
|
||||
from sshecret_admin.services import AdminBackend
|
||||
|
||||
from ..dependencies import FrontendDependencies
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def split_clients(clients: Any) -> Any:
|
||||
def split_clients(clients: Any) -> Any: # pyright: ignore[reportAny]
|
||||
"""Split clients."""
|
||||
if isinstance(clients, list):
|
||||
return clients
|
||||
return clients # pyright: ignore[reportUnknownVariableType]
|
||||
if not isinstance(clients, str):
|
||||
raise ValueError("Invalid type for clients.")
|
||||
if not clients:
|
||||
@ -26,7 +27,7 @@ def split_clients(clients: Any) -> Any:
|
||||
return [client.rstrip() for client in clients.split(",")]
|
||||
|
||||
|
||||
def handle_select_bool(value: Any) -> Any:
|
||||
def handle_select_bool(value: Any) -> Any: # pyright: ignore[reportAny]
|
||||
"""Handle boolean from select."""
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
@ -47,20 +48,17 @@ class CreateSecret(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
def create_secrets_view(
|
||||
templates: Jinja2Blocks,
|
||||
get_current_user_from_token: UserTokenDep,
|
||||
get_admin_backend: AdminDep,
|
||||
) -> APIRouter:
|
||||
"""Create secrets view."""
|
||||
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||
"""Create secrets router."""
|
||||
|
||||
app = APIRouter()
|
||||
templates = dependencies.templates
|
||||
|
||||
@app.get("/secrets/")
|
||||
async def get_secrets(
|
||||
request: Request,
|
||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
):
|
||||
"""Get secrets index page."""
|
||||
secrets = await admin.get_detailed_secrets()
|
||||
@ -79,8 +77,10 @@ def create_secrets_view(
|
||||
@app.post("/secrets/")
|
||||
async def add_secret(
|
||||
request: Request,
|
||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
secret: Annotated[CreateSecret, Form()],
|
||||
):
|
||||
"""Add secret."""
|
||||
@ -108,8 +108,10 @@ def create_secrets_view(
|
||||
request: Request,
|
||||
name: str,
|
||||
id: str,
|
||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
):
|
||||
"""Remove a client's access to a secret."""
|
||||
await admin.delete_client_secret(id, name)
|
||||
@ -130,8 +132,10 @@ def create_secrets_view(
|
||||
request: Request,
|
||||
name: str,
|
||||
client: Annotated[str, Form()],
|
||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
):
|
||||
"""Add a secret to a client."""
|
||||
await admin.create_client_secret(client, name)
|
||||
@ -153,8 +157,10 @@ def create_secrets_view(
|
||||
async def delete_secret(
|
||||
request: Request,
|
||||
name: str,
|
||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
||||
_current_user: Annotated[
|
||||
User, Depends(dependencies.get_user_from_access_token)
|
||||
],
|
||||
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||
):
|
||||
"""Delete a secret."""
|
||||
await admin.delete_secret(name)
|
||||
@ -172,7 +178,4 @@ def create_secrets_view(
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# --------------#
|
||||
# END OF ROUTES #
|
||||
# --------------#
|
||||
return app
|
||||
@ -0,0 +1,8 @@
|
||||
"""Services module.
|
||||
|
||||
This module contains business logic.
|
||||
"""
|
||||
|
||||
from .admin_backend import AdminBackend
|
||||
|
||||
__all__ = ["AdminBackend"]
|
||||
@ -7,13 +7,14 @@ import logging
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
|
||||
from sshecret.backend import AuditLog, Client, ClientFilter, Secret, SshecretBackend
|
||||
from sshecret.backend import AuditLog, Client, ClientFilter, Secret, SshecretBackend, Operation, SubSystem
|
||||
from sshecret.backend.models import DetailedSecrets
|
||||
from sshecret.backend.api import AuditAPI
|
||||
from sshecret.crypto import encrypt_string, load_public_key
|
||||
|
||||
from .keepass import PasswordContext, load_password_manager
|
||||
from .settings import AdminServerSettings
|
||||
from .view_models import SecretView
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
from .models import SecretView
|
||||
|
||||
|
||||
class ClientManagementError(Exception):
|
||||
@ -381,6 +382,11 @@ class AdminBackend:
|
||||
except Exception as e:
|
||||
raise BackendUnavailableError() from e
|
||||
|
||||
@property
|
||||
def audit(self) -> AuditAPI:
|
||||
"""Resolve audit API."""
|
||||
return self.backend.audit(SubSystem.ADMIN)
|
||||
|
||||
async def get_audit_log(
|
||||
self,
|
||||
offset: int = 0,
|
||||
@ -389,14 +395,36 @@ class AdminBackend:
|
||||
subsystem: str | None = None,
|
||||
) -> list[AuditLog]:
|
||||
"""Get audit log from backend."""
|
||||
return await self.backend.get_audit_log(offset, limit, client_name, subsystem)
|
||||
return await self.audit.get(offset, limit, client_name, subsystem)
|
||||
|
||||
async def write_audit_message(
|
||||
self,
|
||||
operation: Operation,
|
||||
message: str,
|
||||
origin: str,
|
||||
client: Client | None = None,
|
||||
secret_name: str | None = None,
|
||||
**data: str,
|
||||
) -> None:
|
||||
"""Write an audit message."""
|
||||
await self.audit.write_async(
|
||||
operation=operation,
|
||||
message=message,
|
||||
origin=origin,
|
||||
client=client,
|
||||
secret=None,
|
||||
secret_name=secret_name,
|
||||
**data,
|
||||
)
|
||||
|
||||
async def write_audit_log(self, entry: AuditLog) -> None:
|
||||
"""Write to the audit log."""
|
||||
if not entry.subsystem:
|
||||
entry.subsystem = "admin"
|
||||
await self.backend.add_audit_log(entry)
|
||||
entry.subsystem = SubSystem.ADMIN
|
||||
|
||||
await self.audit.write_model_async(entry)
|
||||
#await self.backend.add_audit_log(entry)
|
||||
|
||||
async def get_audit_log_count(self) -> int:
|
||||
"""Get audit log count."""
|
||||
return await self.backend.get_audit_log_count()
|
||||
return await self.audit.count()
|
||||
@ -8,7 +8,7 @@ from typing import cast
|
||||
|
||||
import pykeepass
|
||||
from .master_password import decrypt_master_password
|
||||
from .settings import AdminServerSettings
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -8,7 +8,7 @@ from sshecret.crypto import (
|
||||
encrypt_string,
|
||||
decode_string,
|
||||
)
|
||||
from .settings import AdminServerSettings
|
||||
from sshecret_admin.core.settings import AdminServerSettings
|
||||
|
||||
KEY_FILENAME = "sshecret-admin-key"
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
"""Models for the API."""
|
||||
|
||||
import secrets
|
||||
from typing import Annotated, Literal, Self, Union
|
||||
from typing import Annotated, Literal
|
||||
from pydantic import (
|
||||
AfterValidator,
|
||||
BaseModel,
|
||||
@ -9,7 +9,6 @@ from pydantic import (
|
||||
Field,
|
||||
IPvAnyAddress,
|
||||
IPvAnyNetwork,
|
||||
model_validator,
|
||||
)
|
||||
from sshecret.crypto import validate_public_key
|
||||
|
||||
@ -1,10 +0,0 @@
|
||||
{% extends "/dashboard/_base.html" %} {% block content %}
|
||||
|
||||
<div
|
||||
class="p-4 bg-white block sm:flex items-center justify-between border-b border-gray-200 lg:mt-1.5 dark:bg-gray-800 dark:border-gray-700"
|
||||
>
|
||||
<h1 class="mb-4 text-4xl font-extrabold leading-none tracking-tight text-gray-900 md:text-5xl lg:text-6xl dark:text-white">Welcome to Sshecret</h1>
|
||||
</div>
|
||||
|
||||
|
||||
{% endblock %}
|
||||
@ -4,8 +4,7 @@ import os
|
||||
|
||||
import bcrypt
|
||||
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session, select
|
||||
from sqlmodel import Session
|
||||
from .auth_models import User
|
||||
|
||||
|
||||
|
||||
@ -6,16 +6,10 @@ from fastapi import Request
|
||||
from sqlmodel import Session
|
||||
from sshecret_admin.admin_backend import AdminBackend
|
||||
from sshecret_admin.auth_models import User
|
||||
from sshecret.backend import SshecretBackend
|
||||
from . import keepass
|
||||
|
||||
|
||||
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||
|
||||
BackendDep = Callable[[], AsyncGenerator[SshecretBackend, None]]
|
||||
|
||||
PasswdCtxDep = Callable[[DBSessionDep], AsyncGenerator[keepass.PasswordContext, None]]
|
||||
|
||||
AdminDep = Callable[[Session], AsyncGenerator[AdminBackend, None]]
|
||||
|
||||
UserTokenDep = Callable[[Request, Session], Awaitable[User]]
|
||||
|
||||
@ -1,5 +0,0 @@
|
||||
from .audit import create_audit_view
|
||||
from .clients import create_client_view
|
||||
from .secrets import create_secrets_view
|
||||
|
||||
__all__ = ["create_audit_view", "create_client_view", "create_secrets_view"]
|
||||
@ -1,24 +0,0 @@
|
||||
<!doctype html>
|
||||
<html class="no-js" lang="">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta http-equiv="x-ua-compatible" content="ie=edge" />
|
||||
<title>Untitled</title>
|
||||
<meta name="description" content="" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
|
||||
<link rel="apple-touch-icon" href="/apple-touch-icon.png" />
|
||||
<!-- Place favicon.ico in the root directory -->
|
||||
</head>
|
||||
<body>
|
||||
<!--[if lt IE 8]>
|
||||
<p class="browserupgrade">
|
||||
You are using an <strong>outdated</strong> browser. Please
|
||||
<a href="http://browsehappy.com/">upgrade your browser</a> to improve
|
||||
your experience.
|
||||
</p>
|
||||
<![endif]-->
|
||||
|
||||
<p>I am outside of the package</p>
|
||||
</body>
|
||||
</html>
|
||||
@ -3,22 +3,22 @@ from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
from sqlmodel import create_engine
|
||||
|
||||
from alembic import context
|
||||
from sshecret_backend.models import *
|
||||
|
||||
def get_database_url() -> str:
|
||||
"""Get database URL."""
|
||||
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
|
||||
return f"sqlite:///{db_file}"
|
||||
return "sqlite:///sshecret.db"
|
||||
|
||||
from sshecret_backend.models import Base
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
|
||||
def get_database_url() -> str | None:
|
||||
"""Get database URL."""
|
||||
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
|
||||
return f"sqlite:///{db_file}"
|
||||
return config.get_main_option("sqlalchemy.url")
|
||||
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
@ -28,8 +28,7 @@ if config.config_file_name is not None:
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
#target_metadata = None
|
||||
target_metadata = SQLModel.metadata
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
@ -68,7 +67,11 @@ def run_migrations_online() -> None:
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = create_engine(get_database_url())
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
|
||||
@ -9,7 +9,6 @@ from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
@ -5,13 +5,14 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from fastapi import APIRouter, Depends, Request, Query
|
||||
from sqlmodel import Session, col, func, select
|
||||
from sqlalchemy import desc
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import AuditLog
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.view_models import AuditInfo
|
||||
from sshecret_backend.view_models import AuditInfo, AuditView
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -21,7 +22,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Construct audit sub-api."""
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/audit/", response_model=list[AuditLog])
|
||||
@router.get("/audit/", response_model=list[AuditView])
|
||||
async def get_audit_logs(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
@ -29,35 +30,37 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
limit: Annotated[int, Query(le=100)] = 100,
|
||||
filter_client: Annotated[str | None, Query()] = None,
|
||||
filter_subsystem: Annotated[str | None, Query()] = None,
|
||||
) -> Sequence[AuditLog]:
|
||||
) -> Sequence[AuditView]:
|
||||
"""Get audit logs."""
|
||||
#audit.audit_access_audit_log(session, request)
|
||||
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
|
||||
statement = select(AuditLog).offset(offset).limit(limit).order_by(AuditLog.timestamp.desc())
|
||||
if filter_client:
|
||||
statement = statement.where(AuditLog.client_name == filter_client)
|
||||
|
||||
if filter_subsystem:
|
||||
statement = statement.where(AuditLog.subsystem == filter_subsystem)
|
||||
|
||||
results = session.exec(statement).all()
|
||||
return results
|
||||
LogAdapt = TypeAdapter(list[AuditView])
|
||||
results = session.scalars(statement).all()
|
||||
return LogAdapt.validate_python(results, from_attributes=True)
|
||||
|
||||
|
||||
@router.post("/audit/")
|
||||
async def add_audit_log(
|
||||
request: Request,
|
||||
session: Annotated[Session, Depends(get_db_session)],
|
||||
entry: AuditLog,
|
||||
) -> AuditLog:
|
||||
entry: AuditView,
|
||||
) -> AuditView:
|
||||
"""Add entry to audit log."""
|
||||
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
|
||||
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
|
||||
session.add(audit_log)
|
||||
session.commit()
|
||||
return audit_log
|
||||
return AuditView.model_validate(audit_log, from_attributes=True)
|
||||
|
||||
@router.get("/audit/info")
|
||||
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
|
||||
"""Get audit info."""
|
||||
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
|
||||
audit_count = session.scalars(select(func.count('*')).select_from(AuditLog)).one()
|
||||
return AuditInfo(entries=audit_count)
|
||||
|
||||
|
||||
|
||||
@ -6,11 +6,11 @@ import uuid
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from sqlmodel import Session, col, select
|
||||
from sqlalchemy import func
|
||||
from typing import Annotated, Self, TypeVar
|
||||
from typing import Annotated, Any, Self, TypeVar, cast
|
||||
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import Select
|
||||
from sshecret_backend.types import DBSessionDep
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
from sshecret_backend.view_models import (
|
||||
@ -55,8 +55,8 @@ T = TypeVar("T")
|
||||
|
||||
|
||||
def filter_client_statement(
|
||||
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> SelectOfScalar[T]:
|
||||
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||
) -> Select[Any]:
|
||||
"""Filter a statement with the provided params."""
|
||||
if params.id:
|
||||
statement = statement.where(Client.id == params.id)
|
||||
@ -64,9 +64,9 @@ def filter_client_statement(
|
||||
if params.name:
|
||||
statement = statement.where(Client.name == params.name)
|
||||
elif params.name__like:
|
||||
statement = statement.where(col(Client.name).like(params.name__like))
|
||||
statement = statement.where(Client.name.like(params.name__like))
|
||||
elif params.name__contains:
|
||||
statement = statement.where(col(Client.name).contains(params.name__contains))
|
||||
statement = statement.where(Client.name.contains(params.name__contains))
|
||||
|
||||
if ignore_limits:
|
||||
return statement
|
||||
@ -86,13 +86,13 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
"""Get clients."""
|
||||
# Get total results first
|
||||
count_statement = select(func.count("*")).select_from(Client)
|
||||
count_statement = filter_client_statement(count_statement, filter_query, True)
|
||||
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
||||
|
||||
total_results = session.exec(count_statement).one()
|
||||
total_results = session.scalars(count_statement).one()
|
||||
|
||||
statement = filter_client_statement(select(Client), filter_query, False)
|
||||
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
remainder = total_results - filter_query.offset - filter_query.limit
|
||||
if remainder < 0:
|
||||
remainder = 0
|
||||
@ -170,13 +170,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
client.public_key = client_update.public_key
|
||||
for secret in session.exec(
|
||||
for secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.refresh(client)
|
||||
@ -206,13 +205,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
public_key_updated = False
|
||||
if client_update.public_key != client.public_key:
|
||||
public_key_updated = True
|
||||
for secret in session.exec(
|
||||
for secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all():
|
||||
LOG.debug("Invalidated secret %s", secret.id)
|
||||
secret.invalidated = True
|
||||
secret.client_id = None
|
||||
secret.client = None
|
||||
|
||||
session.add(client)
|
||||
session.commit()
|
||||
|
||||
@ -4,7 +4,8 @@ import re
|
||||
import uuid
|
||||
import bcrypt
|
||||
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from sshecret_backend.models import Client
|
||||
|
||||
@ -20,13 +21,13 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.name == name)
|
||||
client_results = session.exec(client_filter)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
||||
"""Get client by name."""
|
||||
client_filter = select(Client).where(Client.id == id)
|
||||
client_results = session.exec(client_filter)
|
||||
client_results = session.scalars(client_filter)
|
||||
return client_results.first()
|
||||
|
||||
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
||||
|
||||
@ -4,7 +4,8 @@
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import ClientAccessPolicy
|
||||
@ -54,7 +55,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
status_code=404, detail="Cannot find a client with the given name."
|
||||
)
|
||||
# Remove old policies.
|
||||
policies = session.exec(
|
||||
policies = session.scalars(
|
||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||
).all()
|
||||
deleted_policies: list[ClientAccessPolicy] = []
|
||||
|
||||
@ -5,7 +5,8 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Annotated
|
||||
|
||||
from sshecret_backend.models import Client, ClientSecret
|
||||
@ -34,7 +35,7 @@ async def lookup_client_secret(
|
||||
.where(ClientSecret.client_id == client.id)
|
||||
.where(ClientSecret.name == name)
|
||||
)
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
return results.first()
|
||||
|
||||
|
||||
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> list[ClientSecretList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
if not client_secret.client:
|
||||
if client_secret.name not in client_secret_map:
|
||||
client_secret_map[client_secret.name] = []
|
||||
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> list[ClientSecretDetailList]:
|
||||
"""Get a list of all secrets and which clients have them."""
|
||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||
for client_secret in session.exec(select(ClientSecret)).all():
|
||||
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||
|
||||
if client_secret.name not in client_secrets:
|
||||
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
||||
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> ClientSecretList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
clients: list[str] = []
|
||||
for client_secret in session.exec(
|
||||
for client_secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
||||
) -> ClientSecretDetailList:
|
||||
"""Get a list of which clients has a named secret."""
|
||||
detail_list = ClientSecretDetailList(name=name)
|
||||
for client_secret in session.exec(
|
||||
for client_secret in session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.name == name)
|
||||
).all():
|
||||
if not client_secret.client:
|
||||
|
||||
@ -2,9 +2,10 @@
|
||||
|
||||
from collections.abc import Sequence
|
||||
from fastapi import Request
|
||||
from sqlmodel import Session, select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
|
||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
|
||||
|
||||
|
||||
def _get_origin(request: Request) -> str | None:
|
||||
@ -22,7 +23,7 @@ def _write_audit_log(
|
||||
"""Write the audit log."""
|
||||
origin = _get_origin(request)
|
||||
entry.origin = origin
|
||||
entry.subsystem = "backend"
|
||||
entry.subsystem = SubSystem.BACKEND
|
||||
session.add(entry)
|
||||
if commit:
|
||||
session.commit()
|
||||
@ -33,7 +34,7 @@ def audit_create_client(
|
||||
) -> None:
|
||||
"""Log the creation of a client."""
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
operation=Operation.CREATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Client Created",
|
||||
@ -46,7 +47,7 @@ def audit_delete_client(
|
||||
) -> None:
|
||||
"""Log the creation of a client."""
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
operation=Operation.CREATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Client deleted",
|
||||
@ -63,9 +64,9 @@ def audit_create_secret(
|
||||
) -> None:
|
||||
"""Audit a create secret event."""
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
operation=Operation.CREATE,
|
||||
secret_id=secret.id,
|
||||
secret_name=secret.name,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Added secret to client",
|
||||
@ -81,13 +82,13 @@ def audit_remove_policy(
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Audit removal of policy."""
|
||||
data = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
|
||||
entry = AuditLog(
|
||||
operation="DELETE",
|
||||
object="ClientAccessPolicy",
|
||||
object_id=str(policy.id),
|
||||
operation=Operation.DELETE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Deleted client policy",
|
||||
data=data,
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
@ -100,13 +101,13 @@ def audit_update_policy(
|
||||
commit: bool = True,
|
||||
) -> None:
|
||||
"""Audit update of policy."""
|
||||
data: dict[str, str] = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
|
||||
entry = AuditLog(
|
||||
operation="CREATE",
|
||||
object="ClientAccessPolicy",
|
||||
object_id=str(policy.id),
|
||||
client_id=client.id,
|
||||
operation=Operation.CREATE,
|
||||
client_name=client.name,
|
||||
client_id=client.id,
|
||||
message="Updated client policy",
|
||||
data=data,
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
@ -119,11 +120,10 @@ def audit_update_client(
|
||||
) -> None:
|
||||
"""Audit an update secret event."""
|
||||
entry = AuditLog(
|
||||
operation="UPDATE",
|
||||
object="Client",
|
||||
operation=Operation.UPDATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
message="Client updated",
|
||||
message="Client data updated",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
@ -137,11 +137,11 @@ def audit_update_secret(
|
||||
) -> None:
|
||||
"""Audit an update secret event."""
|
||||
entry = AuditLog(
|
||||
operation="UPDATE",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
operation=Operation.UPDATE,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
secret_name=secret.name,
|
||||
secret_id=secret.id,
|
||||
message="Secret value updated",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
@ -155,8 +155,7 @@ def audit_invalidate_secrets(
|
||||
) -> None:
|
||||
"""Audit Invalidate client secrets."""
|
||||
entry = AuditLog(
|
||||
operation="INVALIDATE",
|
||||
object="ClientSecret",
|
||||
operation=Operation.UPDATE,
|
||||
client_name=client.name,
|
||||
client_id=client.id,
|
||||
message="Client public-key changed. All secrets invalidated.",
|
||||
@ -173,9 +172,9 @@ def audit_delete_secret(
|
||||
) -> None:
|
||||
"""Audit Delete client secrets."""
|
||||
entry = AuditLog(
|
||||
operation="DELETE",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
operation=Operation.DELETE,
|
||||
secret_name=secret.name,
|
||||
secret_id=secret.id,
|
||||
client_name=client.name,
|
||||
client_id=client.id,
|
||||
message="Deleted secret.",
|
||||
@ -195,7 +194,7 @@ def audit_access_secrets(
|
||||
With no secrets provided, all secrets of the client will be resolved.
|
||||
"""
|
||||
if not secrets:
|
||||
secrets = session.exec(
|
||||
secrets = session.scalars(
|
||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||
).all()
|
||||
|
||||
@ -215,37 +214,21 @@ def audit_access_secret(
|
||||
) -> None:
|
||||
"""Audit that someone accessed one secrets."""
|
||||
entry = AuditLog(
|
||||
operation="ACCESS",
|
||||
operation=Operation.READ,
|
||||
message="Secret was viewed",
|
||||
object="ClientSecret",
|
||||
object_id=str(secret.id),
|
||||
secret_name=secret.name,
|
||||
secret_id=secret.id,
|
||||
client_id=client.id,
|
||||
client_name=client.name,
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
def audit_access_audit_log(
|
||||
session: Session, request: Request, commit: bool = True
|
||||
) -> None:
|
||||
"""Audit access to the audit log.
|
||||
|
||||
Because why not...
|
||||
"""
|
||||
entry = AuditLog(
|
||||
operation="ACCESS",
|
||||
message="Audit log was viewed",
|
||||
object="AuditLog",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
|
||||
def audit_client_secret_list(
|
||||
session: Session, request: Request, commit: bool = True
|
||||
) -> None:
|
||||
"""Audit a list of all secrets."""
|
||||
entry = AuditLog(
|
||||
operation="ACCESS",
|
||||
operation=Operation.READ,
|
||||
message="All secret names and their clients was viewed",
|
||||
)
|
||||
_write_audit_log(session, request, entry, commit)
|
||||
|
||||
17
packages/sshecret-backend/src/sshecret_backend/auth.py
Normal file
17
packages/sshecret-backend/src/sshecret_backend/auth.py
Normal file
@ -0,0 +1,17 @@
|
||||
"""Auth helpers."""
|
||||
|
||||
import bcrypt
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Hash a token."""
|
||||
pwbytes = token.encode("utf-8")
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt)
|
||||
return hashed_bytes.decode()
|
||||
|
||||
|
||||
def verify_token(token: str, stored_hash: str) -> bool:
|
||||
"""Verify token."""
|
||||
token_bytes = token.encode("utf-8")
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
@ -3,11 +3,11 @@
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
||||
from .auth import verify_token
|
||||
from .models import (
|
||||
APIClient,
|
||||
)
|
||||
@ -18,13 +18,6 @@ LOG = logging.getLogger(__name__)
|
||||
API_VERSION = "v1"
|
||||
|
||||
|
||||
def verify_token(token: str, stored_hash: str) -> bool:
|
||||
"""Verify token."""
|
||||
token_bytes = token.encode("utf-8")
|
||||
stored_bytes = stored_hash.encode("utf-8")
|
||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||
|
||||
|
||||
def get_backend_api(
|
||||
get_db_session: DBSessionDep,
|
||||
) -> APIRouter:
|
||||
@ -37,7 +30,7 @@ def get_backend_api(
|
||||
"""Validate token."""
|
||||
LOG.debug("Validating token %s", x_api_token)
|
||||
statement = select(APIClient)
|
||||
results = session.exec(statement)
|
||||
results = session.scalars(statement)
|
||||
valid = False
|
||||
for result in results:
|
||||
if verify_token(x_api_token, result.token):
|
||||
|
||||
@ -3,15 +3,24 @@
|
||||
import code
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from dotenv import load_dotenv
|
||||
from typing import Literal, cast
|
||||
|
||||
import click
|
||||
from sqlmodel import Session, col, func, select
|
||||
import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .db import get_engine, create_api_token
|
||||
|
||||
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db
|
||||
from .db import create_api_token, get_engine, hash_token
|
||||
from .models import (
|
||||
APIClient,
|
||||
AuditLog,
|
||||
Client,
|
||||
ClientAccessPolicy,
|
||||
ClientSecret,
|
||||
SubSystem,
|
||||
init_db,
|
||||
)
|
||||
from .settings import BackendSettings
|
||||
|
||||
DEFAULT_LISTEN = "127.0.0.1"
|
||||
@ -21,22 +30,44 @@ WORKDIR = Path(os.getcwd())
|
||||
|
||||
load_dotenv()
|
||||
|
||||
def generate_token(settings: BackendSettings) -> str:
|
||||
|
||||
def generate_token(
|
||||
settings: BackendSettings, subsystem: Literal["admin", "sshd"]
|
||||
) -> str:
|
||||
"""Generate a token."""
|
||||
engine = get_engine(settings.db_url)
|
||||
init_db(engine)
|
||||
with Session(engine) as session:
|
||||
token = create_api_token(session, True)
|
||||
token = create_api_token(session, subsystem)
|
||||
return token
|
||||
|
||||
def count_tokens(settings: BackendSettings) -> int:
|
||||
"""Count the amount of tokens created."""
|
||||
|
||||
def add_system_tokens(settings: BackendSettings) -> None:
|
||||
"""Add token for subsystems."""
|
||||
if not settings.admin_token and not settings.sshd_token:
|
||||
# Tokens should be generated manually.
|
||||
return
|
||||
|
||||
engine = get_engine(settings.db_url)
|
||||
init_db(engine)
|
||||
tokens: list[tuple[str, SubSystem]] = []
|
||||
if admin_token := settings.admin_token:
|
||||
tokens.append((admin_token, SubSystem.ADMIN))
|
||||
if sshd_token := settings.sshd_token:
|
||||
tokens.append((sshd_token, SubSystem.SSHD))
|
||||
with Session(engine) as session:
|
||||
count = session.exec(select(func.count("*")).select_from(APIClient)).one()
|
||||
for token, subsystem in tokens:
|
||||
hashed_token = hash_token(token)
|
||||
if existing := session.scalars(
|
||||
select(APIClient).where(APIClient.subsystem == subsystem)
|
||||
).first():
|
||||
existing.token = hashed_token
|
||||
else:
|
||||
new_token = APIClient(token=hashed_token, subsystem=subsystem)
|
||||
session.add(new_token)
|
||||
|
||||
return count
|
||||
session.commit()
|
||||
click.echo("Generated system tokens.")
|
||||
|
||||
|
||||
@click.group()
|
||||
@ -49,27 +80,30 @@ def cli(ctx: click.Context, database: str) -> None:
|
||||
else:
|
||||
settings = BackendSettings()
|
||||
|
||||
add_system_tokens(settings)
|
||||
|
||||
if settings.generate_initial_tokens:
|
||||
if count_tokens(settings) == 0:
|
||||
click.echo("Creating initial tokens for admin and sshd.")
|
||||
admin_token = generate_token(settings)
|
||||
sshd_token = generate_token(settings)
|
||||
click.echo(f"Admin token: {admin_token}")
|
||||
click.echo(f"SSHD token: {sshd_token}")
|
||||
# if settings.generate_initial_tokens:
|
||||
# if count_tokens(settings) == 0:
|
||||
# click.echo("Creating initial tokens for admin and sshd.")
|
||||
# admin_token = generate_token(settings)
|
||||
# sshd_token = generate_token(settings)
|
||||
# click.echo(f"Admin token: {admin_token}")
|
||||
# click.echo(f"SSHD token: {sshd_token}")
|
||||
|
||||
ctx.obj = settings
|
||||
|
||||
|
||||
@cli.command("generate-token")
|
||||
@click.argument("subsystem", type=click.Choice(["sshd", "admin"]))
|
||||
@click.pass_context
|
||||
def cli_generate_token(ctx: click.Context) -> None:
|
||||
"""Generate a token."""
|
||||
def cli_generate_token(ctx: click.Context, subsystem: Literal["sshd", "admin"]) -> None:
|
||||
"""Generate a token for a subsystem.."""
|
||||
settings = cast(BackendSettings, ctx.obj)
|
||||
token = generate_token(settings)
|
||||
token = generate_token(settings, subsystem)
|
||||
click.echo("Generated api token:")
|
||||
click.echo(token)
|
||||
|
||||
|
||||
@cli.command("run")
|
||||
@click.option("--host", default="127.0.0.1")
|
||||
@click.option("--port", default=8022, type=click.INT)
|
||||
@ -77,7 +111,10 @@ def cli_generate_token(ctx: click.Context) -> None:
|
||||
@click.option("--workers", type=click.INT)
|
||||
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
||||
"""Run the server."""
|
||||
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers)
|
||||
uvicorn.run(
|
||||
"sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers
|
||||
)
|
||||
|
||||
|
||||
@cli.command("repl")
|
||||
@click.pass_context
|
||||
|
||||
@ -2,56 +2,108 @@
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import sqlite3
|
||||
|
||||
from collections.abc import Generator, Callable
|
||||
from pathlib import Path
|
||||
from sqlalchemy import Engine
|
||||
from sqlmodel import Session, create_engine, text
|
||||
import bcrypt
|
||||
from typing import Literal
|
||||
from sqlalchemy import create_engine, Engine, event, select
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
from sqlalchemy.engine import URL
|
||||
|
||||
|
||||
from .models import APIClient
|
||||
from .auth import hash_token, verify_token
|
||||
from .models import APIClient, SubSystem
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_database(
|
||||
db_url: URL | str,
|
||||
db_url: URL,
|
||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||
"""Setup database."""
|
||||
|
||||
engine = create_engine(db_url, echo=False)
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
||||
engine = get_engine(db_url)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True)
|
||||
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
"""Get DB Session."""
|
||||
with Session(engine) as session:
|
||||
session = SessionLocal(bind=engine)
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
return engine, get_db_session
|
||||
|
||||
|
||||
def get_engine(url: URL, echo: bool = False) -> Engine:
|
||||
"""Initialize the engine."""
|
||||
engine = create_engine(url, echo=echo)
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
||||
engine = create_engine(url, echo=echo, future=True)
|
||||
if url.drivername.startswith("sqlite"):
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def create_api_token(session: Session, read_write: bool) -> str:
|
||||
"""Create API token."""
|
||||
token = secrets.token_urlsafe(32)
|
||||
pwbytes = token.encode("utf-8")
|
||||
salt = bcrypt.gensalt()
|
||||
hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt)
|
||||
hashed = hashed_bytes.decode()
|
||||
def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine:
|
||||
"""Get an async engine."""
|
||||
engine = create_async_engine(url, echo=echo, future=True)
|
||||
if url.drivername.startswith("sqlite+"):
|
||||
|
||||
api_token = APIClient(token=hashed, read_write=read_write)
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(
|
||||
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||
) -> None:
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
|
||||
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None:
|
||||
"""Create API token with a given value."""
|
||||
|
||||
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first()
|
||||
if existing:
|
||||
if verify_token(token, existing.token):
|
||||
LOG.info("Token is up to date.")
|
||||
return
|
||||
LOG.info("Updating token value for subsystem %s", subsystem)
|
||||
hashed = hash_token(token)
|
||||
existing.token=hashed
|
||||
session.commit()
|
||||
return
|
||||
|
||||
LOG.info("No existing token found. Creating new")
|
||||
hashed = hash_token(token)
|
||||
api_token = APIClient(token=hashed, subsystem=SubSystem(subsystem))
|
||||
|
||||
session.add(api_token)
|
||||
session.commit()
|
||||
|
||||
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
|
||||
"""Create API token."""
|
||||
subsys = SubSystem(subsystem)
|
||||
token = secrets.token_urlsafe(32)
|
||||
hashed = hash_token(token)
|
||||
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first():
|
||||
if not recreate:
|
||||
raise RuntimeError("Error: A token already exist for this subsystem.")
|
||||
existing.token = hashed
|
||||
else:
|
||||
api_token = APIClient(token=hashed, subsystem=subsys)
|
||||
session.add(api_token)
|
||||
session.commit()
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user