Compare commits
10 Commits
4f970a3f71
...
7c65d5bb93
| Author | SHA1 | Date | |
|---|---|---|---|
| 7c65d5bb93 | |||
| e22f9b8f10 | |||
| e766b06d5c | |||
| e28276634a | |||
| 3f5d9ea545 | |||
| 6bebbee4fa | |||
| 9bf92d6743 | |||
| 9ccd2f1d4d | |||
| d866553ac1 | |||
| 0a427b6a91 |
7
.dockerignore
Normal file
7
.dockerignore
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
.venv
|
||||||
|
.git
|
||||||
|
.github
|
||||||
|
**/.pytest_cache
|
||||||
|
**/__pycache__
|
||||||
|
.ruff_cache
|
||||||
|
**/.testing
|
||||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
122
README.md
122
README.md
@ -1,58 +1,112 @@
|
|||||||
# Sshecret - Openssh based secrets management
|
# Sshecret - Simple Secret management using SSH and RSA Keys
|
||||||
|
|
||||||
## Motivation
|
Managing secrets within a container environment or homelab can be quite complex.
|
||||||
|
|
||||||
There are many approaches to managing secrets for services, but a lot of these
|
As with container orchestration in general, many of the available tools are targetted large enterprise installations and add significant complexity to both the management of secrets themselves, and to the consumers.
|
||||||
either assume you have one of the industry-standard systems like hashicorp vault to manage them centrally.
|
|
||||||
|
|
||||||
For enthusiasts or homelabbers this becomes overkill quickly, and end up
|
For enthusiasts or homelabbers solutions like Hashicorp Vault become overkill
|
||||||
consuming a lot more time and energy than what feels justified.
|
quickly, and end up consuming a lot more time and energy than what feels
|
||||||
|
justified.
|
||||||
|
|
||||||
This system has been created to provide a centralized solution that works well-enough.
|
This system has been created to provide a centralized solution that works well-enough.
|
||||||
|
|
||||||
One clear goal was to have all the complexity on the server-side, and be able to construct a minimal client.
|
Sshecret provides a simple, centralized solution for secret storage, and requires only tools that are commonly pre-installed on most linux system.
|
||||||
|
|
||||||
## Components
|
# Concept
|
||||||
|
The system uses RSA keys, as generated by openssh, to encrypt secrets for each client.
|
||||||
|
|
||||||
This system has been designed with modularity and extensibility in mind. It has the following building blocks:
|
It uses RSA keys as it is possible to do encryption using only the public key.
|
||||||
|
|
||||||
- Password database
|
By using a custom SSH server, the consuming servers can fetch a version of a secret encrypted specifically for them.
|
||||||
- Password input handler
|
|
||||||
- Encryption and key management
|
As the secret is encrypted with RSA keys, it can be decrypted using the openssl command which is commonly available.
|
||||||
|
|
||||||
|
This means that while the backend interface can be complex, the to access and decrypt a secret, you can use ssh, and a simple bash script.
|
||||||
|
|
||||||
|
# Components
|
||||||
|
|
||||||
|
There are three components to Sshecret:
|
||||||
|
|
||||||
|
- Password database and admin interface
|
||||||
- Client secret storage backend
|
- Client secret storage backend
|
||||||
- Custom ssh server
|
- Ssh server for clients to access
|
||||||
|
|
||||||
### Password database
|
The three systems should be deployed separately for security, with the backend system as the only central component.
|
||||||
Currently a single password database is implemented: Keepass.
|
|
||||||
|
|
||||||
Sshecret can create a database, and store your secrets in it.
|
The ssh server that the clients connect to, only has access to encrypted versions of the secrets.
|
||||||
|
|
||||||
It only uses a master password for protection, so you are responsible for
|
If it or the backend should be compromised, it wouldn't be possible to extract any clear-text secrets, since only encrypted values are stored, and each encrypted with RSA public key encryption.
|
||||||
securing the password database file. In theory, the password database file can
|
|
||||||
be disconnected after encrypting the passwords for the clients, and these two
|
|
||||||
components may be disconnected.
|
|
||||||
|
|
||||||
### Password input handler
|
## Backend
|
||||||
Passwords can be randomly generated, they can be read from stdin, or from environment variables.
|
The backend system stores the definition of each client, including their public key, and the IP addresses/networks they are allowed to connect from.
|
||||||
|
It also stores a version of each secret that the client has access to, encrypted with the public key. It has no knowledge of the unencrypted secrets.
|
||||||
|
|
||||||
Other methods can be implemented in the future.
|
Both the admin interface and the ssh server access this system over a REST API using a token-based authentication method.
|
||||||
|
|
||||||
### Client secret storage backend
|
## Password database and admin interface
|
||||||
So far only a simple JSON file based backend has been implemented. It stores one file per client.
|
The sshecret password database is based on keepass, and the admin interface is available as a simple web interface as well as a REST API.
|
||||||
The interface is flexible, and can be extended to databases or anything else really.
|
|
||||||
|
|
||||||
### Custom SSH server
|
This component is primarily responsible for storing the secrets in a keepass database and populating the backend with client-specific encrypted versions.
|
||||||
A custom SSH based on paramiko is included. This is how the clients receive the encrypted password.
|
|
||||||
The client must send a single command over the SSH session equal to the name of the secret.
|
|
||||||
|
|
||||||
If permitted to access the secret, it will returned encrypted with the client RSA public key of the client, encoded as base64.
|
The admin interface must be secured. It currently supports local user accounts, but will be expanded with OIDC support in the near future.
|
||||||
|
|
||||||
|
## Custom SSH server
|
||||||
|
To make it easy to fetch secrets sshecret includes a SSH server.
|
||||||
|
|
||||||
|
To fetch a secret, simply ssh to it using the client name as username and the
|
||||||
|
registered RSA key as authorization.
|
||||||
|
|
||||||
|
Send the command `get_secret` followed by the name of the secret. The ssh server
|
||||||
|
will check the client's public key against permitted keys, and check if the
|
||||||
|
client is allowed to connect and if the secret is available to it.
|
||||||
|
|
||||||
|
The server will answer with a base64 encoded version of the secret.
|
||||||
|
|
||||||
|
See `examples/sshecret-client.bash` for an example of a bash script that can be
|
||||||
|
used to fetch and decrypt a secret.
|
||||||
|
|
||||||
|
Out of the box, only the `get_secret` command is available, however an optional
|
||||||
|
command `register` can be enabled to allow registration of a client using the
|
||||||
|
SSH interface to make it easy to onboard new clients automatically.
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
|
||||||
|
Each subsystem is set up using environment variables.
|
||||||
|
|
||||||
|
## Backend
|
||||||
|
The location and type of database may be configured.
|
||||||
|
For now, only sqlite is officially supported.
|
||||||
|
|
||||||
|
The value shown below is the default value. In a docker setup, you may want to
|
||||||
|
create a volume and configure this to a path within the volume to be able to
|
||||||
|
perform backups.
|
||||||
|
|
||||||
|
SSHECRET_BACKEND_DATABASE=/path/to/sshecret.db
|
||||||
|
|
||||||
|
|
||||||
|
While the backend can be placed behind reverse proxy and served with HTTPS, it's
|
||||||
|
probably better to have keep it inside an internal container network.
|
||||||
|
|
||||||
|
## Admin
|
||||||
|
The backend must be generated on the backend before setting up the admin.
|
||||||
|
|
||||||
|
SSHECRET_BACKEND_URL=http://backend:8022
|
||||||
|
SSHECRET_ADMIN_BACKEND_TOKEN: mySuperSecretBackendToken
|
||||||
|
|
||||||
|
|
||||||
|
## Deployment
|
||||||
|
|
||||||
|
The system can be deplyed using docker or any other container runtime.
|
||||||
|
|
||||||
|
See the examples in the `docker/` folder.
|
||||||
|
|
||||||
This allows the client to decrypt and get the clear text value easily.
|
|
||||||
|
|
||||||
# FAQ
|
# FAQ
|
||||||
## Why not use Age?
|
## Why not use Age?
|
||||||
I like age a lot, and it's ability to use more ssh key types is certainly a winner feature.
|
|
||||||
However, one goal here is to be able to construct a client with minimal dependencies, and that speaks in favor of the current solution.
|
I like age a lot, and it's ability to use more ssh key types is certainly a
|
||||||
|
winner feature. However, a clear goal of this project is to be able to construct
|
||||||
|
a client with minimal dependencies.
|
||||||
|
|
||||||
Using just RSA keys, you can construct a client using only the following tools:
|
Using just RSA keys, you can construct a client using only the following tools:
|
||||||
- base64
|
- base64
|
||||||
@ -60,3 +114,5 @@ Using just RSA keys, you can construct a client using only the following tools:
|
|||||||
- ssh
|
- ssh
|
||||||
|
|
||||||
This means that you can create a client using just a shell script.
|
This means that you can create a client using just a shell script.
|
||||||
|
|
||||||
|
If age were to be used, the age tool would have to be installed on each client sytem.
|
||||||
|
|||||||
30
docker/Dockerfile.admin
Normal file
30
docker/Dockerfile.admin
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# this Dockerfile should be built from the repo root
|
||||||
|
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
|
||||||
|
|
||||||
|
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy
|
||||||
|
ENV UV_PYTHON_DOWNLOADS=0
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY . /build
|
||||||
|
|
||||||
|
RUN uv build --package sshecret
|
||||||
|
RUN uv build --package sshecret-admin
|
||||||
|
|
||||||
|
|
||||||
|
FROM python:3.13-slim-bookworm
|
||||||
|
|
||||||
|
COPY --from=builder --chown=app:app /build/dist /opt/sshecret
|
||||||
|
|
||||||
|
RUN pip install /opt/sshecret/sshecret-*.whl
|
||||||
|
RUN pip install /opt/sshecret/sshecret_admin-*.whl
|
||||||
|
|
||||||
|
EXPOSE 8822
|
||||||
|
|
||||||
|
VOLUME /opt/sshecret-admin
|
||||||
|
|
||||||
|
WORKDIR /opt/sshecret-admin
|
||||||
|
|
||||||
|
ENTRYPOINT [ "sshecret-admin" ]
|
||||||
|
|
||||||
|
CMD ["run", "--host", "0.0.0.0"]
|
||||||
30
docker/Dockerfile.backend
Normal file
30
docker/Dockerfile.backend
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# this Dockerfile should be built from the repo root
|
||||||
|
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
|
||||||
|
|
||||||
|
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy
|
||||||
|
ENV UV_PYTHON_DOWNLOADS=0
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY . /build
|
||||||
|
|
||||||
|
RUN uv build --package sshecret
|
||||||
|
RUN uv build --package sshecret-backend
|
||||||
|
|
||||||
|
|
||||||
|
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
|
||||||
|
|
||||||
|
COPY --from=builder --chown=app:app /build/dist /opt/sshecret
|
||||||
|
|
||||||
|
RUN uv pip install --system /opt/sshecret/sshecret-*.whl
|
||||||
|
RUN uv pip install --system /opt/sshecret/sshecret_backend-*.whl
|
||||||
|
|
||||||
|
COPY packages/sshecret-backend /opt/sshecret-backend
|
||||||
|
COPY docker/backend.entrypoint.sh /entrypoint.sh
|
||||||
|
|
||||||
|
WORKDIR /opt/sshecret-backend
|
||||||
|
|
||||||
|
VOLUME /opt/sshecret-backend-db
|
||||||
|
|
||||||
|
EXPOSE 8022
|
||||||
|
|
||||||
|
CMD ["/entrypoint.sh"]
|
||||||
26
docker/Dockerfile.sshd
Normal file
26
docker/Dockerfile.sshd
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
# this Dockerfile should be built from the repo root
|
||||||
|
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
|
||||||
|
|
||||||
|
ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy
|
||||||
|
ENV UV_PYTHON_DOWNLOADS=0
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY . /build
|
||||||
|
|
||||||
|
RUN uv build --package sshecret
|
||||||
|
RUN uv build --package sshecret-sshd
|
||||||
|
|
||||||
|
FROM python:3.13-slim-bookworm
|
||||||
|
|
||||||
|
COPY --from=builder --chown=app:app /build/dist /opt/sshecret
|
||||||
|
|
||||||
|
RUN pip install /opt/sshecret/sshecret-*.whl
|
||||||
|
RUN pip install /opt/sshecret/sshecret_sshd-*.whl
|
||||||
|
|
||||||
|
WORKDIR /opt/sshecret-sshd
|
||||||
|
|
||||||
|
VOLUME /opt/sshecret-sshd
|
||||||
|
|
||||||
|
EXPOSE 2222
|
||||||
|
|
||||||
|
CMD ["sshecret-sshd", "run", "--host", "0.0.0.0"]
|
||||||
15
docker/backend.entrypoint.sh
Executable file
15
docker/backend.entrypoint.sh
Executable file
@ -0,0 +1,15 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
fail() {
|
||||||
|
printf '%s\n' "$1" >&2 ## Send message to stderr.
|
||||||
|
exit "${2-1}" ## Return a code specified by $2, or 1 by default.
|
||||||
|
}
|
||||||
|
|
||||||
|
[[ -d migrations ]] || fail "Error: Must be run from the backend directory."
|
||||||
|
[[ -d /opt/sshecret-backend-db ]] || mkdir /opt/sshecret-backend-db
|
||||||
|
|
||||||
|
export SSHECRET_BACKEND_DATABASE="/opt/sshecret-backend-db/sshecret.db"
|
||||||
|
|
||||||
|
alembic upgrade head
|
||||||
|
|
||||||
|
sshecret-backend run --host 0.0.0.0
|
||||||
19
docker/docker-compose.yml
Normal file
19
docker/docker-compose.yml
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
---
|
||||||
|
|
||||||
|
services:
|
||||||
|
backend:
|
||||||
|
image: sshecret-backend
|
||||||
|
container_name: sshecret_backend
|
||||||
|
build:
|
||||||
|
context: ../
|
||||||
|
dockerfile: dockerfile.backend
|
||||||
|
networks:
|
||||||
|
- common
|
||||||
|
volumes:
|
||||||
|
- backend_data
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
backend_data:
|
||||||
|
|
||||||
|
networks:
|
||||||
|
common:
|
||||||
@ -19,10 +19,14 @@ dependencies = [
|
|||||||
"pyjwt>=2.10.1",
|
"pyjwt>=2.10.1",
|
||||||
"pykeepass>=4.1.1.post1",
|
"pykeepass>=4.1.1.post1",
|
||||||
"sqlmodel>=0.0.24",
|
"sqlmodel>=0.0.24",
|
||||||
|
"sshecret",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[tool.uv.sources]
|
||||||
|
sshecret = { workspace = true }
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
sshecret-admin = "sshecret_admin.cli:cli"
|
sshecret-admin = "sshecret_admin.core.cli:cli"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
@ -31,4 +35,5 @@ build-backend = "hatchling.build"
|
|||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
"pytailwindcss>=0.2.0",
|
"pytailwindcss>=0.2.0",
|
||||||
|
"types-pyjwt>=1.7.1",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,284 +0,0 @@
|
|||||||
"""Admin API."""
|
|
||||||
|
|
||||||
# pyright: reportUnusedFunction=false
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections import defaultdict
|
|
||||||
from datetime import timedelta
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import jwt
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
from sshecret.backend import Client, SshecretBackend
|
|
||||||
from sshecret.backend.models import Secret
|
|
||||||
|
|
||||||
from .admin_backend import AdminBackend
|
|
||||||
from .auth_models import (
|
|
||||||
PasswordDB,
|
|
||||||
Token,
|
|
||||||
TokenData,
|
|
||||||
User,
|
|
||||||
create_access_token,
|
|
||||||
verify_password,
|
|
||||||
)
|
|
||||||
from .settings import AdminServerSettings
|
|
||||||
from .types import DBSessionDep
|
|
||||||
from .view_models import (
|
|
||||||
ClientCreate,
|
|
||||||
SecretCreate,
|
|
||||||
SecretUpdate,
|
|
||||||
SecretView,
|
|
||||||
UpdateKeyModel,
|
|
||||||
UpdateKeyResponse,
|
|
||||||
UpdatePoliciesRequest,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
API_VERSION = "v1"
|
|
||||||
JWT_ALGORITHM = "HS256"
|
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
||||||
|
|
||||||
|
|
||||||
def authenticate_user(session: Session, username: str, password: str) -> User | None:
|
|
||||||
"""Authenticate user."""
|
|
||||||
user = session.exec(select(User).where(User.username == username)).first()
|
|
||||||
if not user:
|
|
||||||
return None
|
|
||||||
if not verify_password(password, user.hashed_password):
|
|
||||||
return None
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
async def map_secrets_to_clients(
|
|
||||||
backend: SshecretBackend,
|
|
||||||
) -> defaultdict[str, list[str]]:
|
|
||||||
"""Map secrets to clients."""
|
|
||||||
clients = await backend.get_clients()
|
|
||||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
|
||||||
for client in clients:
|
|
||||||
for secret in client.secrets:
|
|
||||||
client_secret_map[secret].append(client.name)
|
|
||||||
return client_secret_map
|
|
||||||
|
|
||||||
|
|
||||||
def get_admin_api(
|
|
||||||
get_db_session: DBSessionDep, settings: AdminServerSettings
|
|
||||||
) -> APIRouter:
|
|
||||||
"""Get Admin API."""
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
||||||
|
|
||||||
async def get_admin_backend(session: Annotated[Session, Depends(get_db_session)]):
|
|
||||||
"""Get admin backend API."""
|
|
||||||
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
|
||||||
if not password_db:
|
|
||||||
raise HTTPException(
|
|
||||||
500, detail="Error: The password manager has not yet been set up."
|
|
||||||
)
|
|
||||||
admin = AdminBackend(settings, password_db.encrypted_password)
|
|
||||||
yield admin
|
|
||||||
|
|
||||||
async def get_current_user(
|
|
||||||
token: Annotated[str, Depends(oauth2_scheme)],
|
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
|
||||||
) -> User:
|
|
||||||
"""Get current user from token."""
|
|
||||||
credentials_exception = HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Could not validate credentials",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
|
|
||||||
username = payload.get("sub")
|
|
||||||
if not username:
|
|
||||||
raise credentials_exception
|
|
||||||
token_data = TokenData(username=username)
|
|
||||||
except jwt.InvalidTokenError:
|
|
||||||
raise credentials_exception
|
|
||||||
|
|
||||||
user = session.exec(
|
|
||||||
select(User).where(User.username == token_data.username)
|
|
||||||
).first()
|
|
||||||
if not user:
|
|
||||||
raise credentials_exception
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def get_current_active_user(
|
|
||||||
current_user: Annotated[User, Depends(get_current_user)],
|
|
||||||
) -> User:
|
|
||||||
"""Get current active user."""
|
|
||||||
if current_user.disabled:
|
|
||||||
raise HTTPException(status_code=400, detail="Inactive or disabled user")
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
app = APIRouter(
|
|
||||||
prefix=f"/api/{API_VERSION}", dependencies=[Depends(get_current_active_user)]
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.post("/token")
|
|
||||||
async def login_for_access_token(
|
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
|
||||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
|
||||||
) -> Token:
|
|
||||||
"""Login user and generate token."""
|
|
||||||
user = authenticate_user(session, form_data.username, form_data.password)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Incorrect username or password",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
||||||
access_token = create_access_token(
|
|
||||||
settings,
|
|
||||||
data={"sub": user.username},
|
|
||||||
expires_delta=access_token_expires,
|
|
||||||
)
|
|
||||||
return Token(access_token=access_token, token_type="bearer")
|
|
||||||
|
|
||||||
@app.get("/clients/")
|
|
||||||
async def get_clients(
|
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)]
|
|
||||||
) -> list[Client]:
|
|
||||||
"""Get clients."""
|
|
||||||
clients = await admin.get_clients()
|
|
||||||
return clients
|
|
||||||
|
|
||||||
@app.post("/clients/")
|
|
||||||
async def create_client(
|
|
||||||
new_client: ClientCreate,
|
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
|
||||||
) -> Client:
|
|
||||||
"""Create a new client."""
|
|
||||||
sources: list[str] | None = None
|
|
||||||
if new_client.sources:
|
|
||||||
sources = [str(source) for source in new_client.sources]
|
|
||||||
client = await admin.create_client(
|
|
||||||
new_client.name, new_client.public_key, sources
|
|
||||||
)
|
|
||||||
return client
|
|
||||||
|
|
||||||
@app.delete("/clients/{name}")
|
|
||||||
async def delete_client(
|
|
||||||
name: str, admin: Annotated[AdminBackend, Depends(get_admin_backend)]
|
|
||||||
) -> None:
|
|
||||||
"""Delete a client."""
|
|
||||||
await admin.delete_client(name)
|
|
||||||
|
|
||||||
@app.delete("/clients/{name}/secrets/{secret_name}")
|
|
||||||
async def delete_secret_from_client(
|
|
||||||
name: str,
|
|
||||||
secret_name: str,
|
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
|
||||||
) -> None:
|
|
||||||
"""Delete a secret from a client."""
|
|
||||||
client = await admin.get_client(name)
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
if secret_name not in client.secrets:
|
|
||||||
LOG.debug("Client does not have requested secret. No action to perform.")
|
|
||||||
return None
|
|
||||||
|
|
||||||
await admin.delete_client_secret(name, secret_name)
|
|
||||||
|
|
||||||
@app.put("/clients/{name}/policies")
|
|
||||||
async def update_client_policies(
|
|
||||||
name: str,
|
|
||||||
updated: UpdatePoliciesRequest,
|
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
|
||||||
) -> Client:
|
|
||||||
"""Update the client access policies."""
|
|
||||||
client = await admin.get_client(name)
|
|
||||||
if not client:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.debug("Old policies: %r. New: %r", client.policies, updated.sources)
|
|
||||||
|
|
||||||
addresses: list[str] = [str(source) for source in updated.sources]
|
|
||||||
await admin.update_client_sources(name, addresses)
|
|
||||||
client = await admin.get_client(name)
|
|
||||||
|
|
||||||
assert client is not None, "Critical: The client disappeared after update!"
|
|
||||||
|
|
||||||
return client
|
|
||||||
|
|
||||||
@app.get("/secrets/")
|
|
||||||
async def get_secret_names(
|
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)]
|
|
||||||
) -> list[Secret]:
|
|
||||||
"""Get Secret Names."""
|
|
||||||
return await admin.get_secrets()
|
|
||||||
|
|
||||||
@app.post("/secrets/")
|
|
||||||
async def add_secret(
|
|
||||||
secret: SecretCreate, admin: Annotated[AdminBackend, Depends(get_admin_backend)]
|
|
||||||
) -> None:
|
|
||||||
"""Create a secret."""
|
|
||||||
await admin.add_secret(secret.name, secret.get_secret(), secret.clients)
|
|
||||||
|
|
||||||
@app.get("/secrets/{name}")
|
|
||||||
async def get_secret(
|
|
||||||
name: str, admin: Annotated[AdminBackend, Depends(get_admin_backend)]
|
|
||||||
) -> SecretView:
|
|
||||||
"""Get a secret."""
|
|
||||||
secret_view = await admin.get_secret(name)
|
|
||||||
|
|
||||||
if not secret_view:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found."
|
|
||||||
)
|
|
||||||
return secret_view
|
|
||||||
|
|
||||||
@app.put("/secrets/{name}")
|
|
||||||
async def update_secret(
|
|
||||||
name: str,
|
|
||||||
value: SecretUpdate,
|
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
|
||||||
) -> None:
|
|
||||||
new_value = value.get_secret()
|
|
||||||
await admin.update_secret(name, new_value)
|
|
||||||
|
|
||||||
@app.delete("/secrets/{name}")
|
|
||||||
async def delete_secret(
|
|
||||||
name: str,
|
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
|
||||||
) -> None:
|
|
||||||
"""Delete secret."""
|
|
||||||
await admin.delete_secret(name)
|
|
||||||
|
|
||||||
@app.put("/clients/{name}/public-key")
|
|
||||||
async def update_client_public_key(
|
|
||||||
name: str,
|
|
||||||
updated: UpdateKeyModel,
|
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
|
||||||
) -> UpdateKeyResponse:
|
|
||||||
"""Update client public key.
|
|
||||||
|
|
||||||
Updating the public key will invalidate the current secrets, so these well
|
|
||||||
be resolved first, and re-encrypted using the new key.
|
|
||||||
"""
|
|
||||||
# Let's first ensure that the key is actually updated.
|
|
||||||
updated_secrets = await admin.update_client_public_key(name, updated.public_key)
|
|
||||||
return UpdateKeyResponse(
|
|
||||||
public_key=updated.public_key, updated_secrets=updated_secrets
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.put("/clients/{name}/secrets/{secret_name}")
|
|
||||||
async def add_secret_to_client(
|
|
||||||
name: str,
|
|
||||||
secret_name: str,
|
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
|
||||||
) -> None:
|
|
||||||
"""Add secret to a client."""
|
|
||||||
await admin.create_client_secret(name, secret_name)
|
|
||||||
|
|
||||||
return app
|
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
"""Admin REST API."""
|
||||||
|
|
||||||
|
from .router import create_router as create_api_router
|
||||||
|
|
||||||
|
__all__ = ["create_api_router"]
|
||||||
@ -0,0 +1 @@
|
|||||||
|
"""API Endpoints."""
|
||||||
@ -0,0 +1,39 @@
|
|||||||
|
"""Authentication related endpoints factory."""
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
|
import logging
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from sqlmodel import Session
|
||||||
|
|
||||||
|
from sshecret_admin.auth import Token, authenticate_user, create_access_token
|
||||||
|
from sshecret_admin.core.dependencies import AdminDependencies
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||||
|
"""Create auth router."""
|
||||||
|
app = APIRouter()
|
||||||
|
|
||||||
|
@app.post("/token")
|
||||||
|
async def login_for_access_token(
|
||||||
|
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||||
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
|
) -> Token:
|
||||||
|
"""Login user and generate token."""
|
||||||
|
user = authenticate_user(session, form_data.username, form_data.password)
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect username or password",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
access_token = create_access_token(
|
||||||
|
dependencies.settings,
|
||||||
|
data={"sub": user.username},
|
||||||
|
)
|
||||||
|
return Token(access_token=access_token, token_type="bearer")
|
||||||
|
|
||||||
|
|
||||||
|
return app
|
||||||
@ -0,0 +1,124 @@
|
|||||||
|
"""Client-related endpoints factory."""
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
|
from sshecret.backend import Client
|
||||||
|
from sshecret_admin.core.dependencies import AdminDependencies
|
||||||
|
from sshecret_admin.services import AdminBackend
|
||||||
|
from sshecret_admin.services.models import (
|
||||||
|
ClientCreate,
|
||||||
|
UpdateKeyModel,
|
||||||
|
UpdateKeyResponse,
|
||||||
|
UpdatePoliciesRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||||
|
"""Create clients router."""
|
||||||
|
app = APIRouter()
|
||||||
|
|
||||||
|
@app.get("/clients/")
|
||||||
|
async def get_clients(
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)]
|
||||||
|
) -> list[Client]:
|
||||||
|
"""Get clients."""
|
||||||
|
clients = await admin.get_clients()
|
||||||
|
return clients
|
||||||
|
|
||||||
|
@app.post("/clients/")
|
||||||
|
async def create_client(
|
||||||
|
new_client: ClientCreate,
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
) -> Client:
|
||||||
|
"""Create a new client."""
|
||||||
|
sources: list[str] | None = None
|
||||||
|
if new_client.sources:
|
||||||
|
sources = [str(source) for source in new_client.sources]
|
||||||
|
client = await admin.create_client(
|
||||||
|
new_client.name, new_client.public_key, sources=sources
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
@app.delete("/clients/{name}")
|
||||||
|
async def delete_client(
|
||||||
|
name: str,
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
) -> None:
|
||||||
|
"""Delete a client."""
|
||||||
|
await admin.delete_client(name)
|
||||||
|
|
||||||
|
@app.delete("/clients/{name}/secrets/{secret_name}")
|
||||||
|
async def delete_secret_from_client(
|
||||||
|
name: str,
|
||||||
|
secret_name: str,
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
) -> None:
|
||||||
|
"""Delete a secret from a client."""
|
||||||
|
client = await admin.get_client(name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
if secret_name not in client.secrets:
|
||||||
|
LOG.debug("Client does not have requested secret. No action to perform.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
await admin.delete_client_secret(name, secret_name)
|
||||||
|
|
||||||
|
@app.put("/clients/{name}/policies")
|
||||||
|
async def update_client_policies(
|
||||||
|
name: str,
|
||||||
|
updated: UpdatePoliciesRequest,
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
) -> Client:
|
||||||
|
"""Update the client access policies."""
|
||||||
|
client = await admin.get_client(name)
|
||||||
|
if not client:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.debug("Old policies: %r. New: %r", client.policies, updated.sources)
|
||||||
|
|
||||||
|
addresses: list[str] = [str(source) for source in updated.sources]
|
||||||
|
await admin.update_client_sources(name, addresses)
|
||||||
|
client = await admin.get_client(name)
|
||||||
|
|
||||||
|
assert client is not None, "Critical: The client disappeared after update!"
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
@app.put("/clients/{name}/public-key")
|
||||||
|
async def update_client_public_key(
|
||||||
|
name: str,
|
||||||
|
updated: UpdateKeyModel,
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
) -> UpdateKeyResponse:
|
||||||
|
"""Update client public key.
|
||||||
|
|
||||||
|
Updating the public key will invalidate the current secrets, so these well
|
||||||
|
be resolved first, and re-encrypted using the new key.
|
||||||
|
"""
|
||||||
|
# Let's first ensure that the key is actually updated.
|
||||||
|
updated_secrets = await admin.update_client_public_key(name, updated.public_key)
|
||||||
|
return UpdateKeyResponse(
|
||||||
|
public_key=updated.public_key, updated_secrets=updated_secrets
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.put("/clients/{name}/secrets/{secret_name}")
|
||||||
|
async def add_secret_to_client(
|
||||||
|
name: str,
|
||||||
|
secret_name: str,
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
) -> None:
|
||||||
|
"""Add secret to a client."""
|
||||||
|
await admin.create_client_secret(name, secret_name)
|
||||||
|
|
||||||
|
return app
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
"""Secrets related endpoints factory."""
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
|
import logging
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
|
from sshecret.backend.models import Secret
|
||||||
|
from sshecret_admin.core.dependencies import AdminDependencies
|
||||||
|
from sshecret_admin.services import AdminBackend
|
||||||
|
from sshecret_admin.services.models import (
|
||||||
|
SecretCreate,
|
||||||
|
SecretUpdate,
|
||||||
|
SecretView,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_router(dependencies: AdminDependencies) -> APIRouter:
|
||||||
|
"""Create secrets router."""
|
||||||
|
app = APIRouter()
|
||||||
|
|
||||||
|
@app.get("/secrets/")
|
||||||
|
async def get_secret_names(
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)]
|
||||||
|
) -> list[Secret]:
|
||||||
|
"""Get Secret Names."""
|
||||||
|
return await admin.get_secrets()
|
||||||
|
|
||||||
|
@app.post("/secrets/")
|
||||||
|
async def add_secret(
|
||||||
|
secret: SecretCreate,
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
) -> None:
|
||||||
|
"""Create a secret."""
|
||||||
|
await admin.add_secret(secret.name, secret.get_secret(), secret.clients)
|
||||||
|
|
||||||
|
@app.get("/secrets/{name}")
|
||||||
|
async def get_secret(
|
||||||
|
name: str,
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
) -> SecretView:
|
||||||
|
"""Get a secret."""
|
||||||
|
secret_view = await admin.get_secret(name)
|
||||||
|
|
||||||
|
if not secret_view:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND, detail="Item not found."
|
||||||
|
)
|
||||||
|
return secret_view
|
||||||
|
|
||||||
|
@app.put("/secrets/{name}")
|
||||||
|
async def update_secret(
|
||||||
|
name: str,
|
||||||
|
value: SecretUpdate,
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
) -> None:
|
||||||
|
new_value = value.get_secret()
|
||||||
|
await admin.update_secret(name, new_value)
|
||||||
|
|
||||||
|
@app.delete("/secrets/{name}")
|
||||||
|
async def delete_secret(
|
||||||
|
name: str,
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
) -> None:
|
||||||
|
"""Delete secret."""
|
||||||
|
await admin.delete_secret(name)
|
||||||
|
|
||||||
|
return app
|
||||||
78
packages/sshecret-admin/src/sshecret_admin/api/router.py
Normal file
78
packages/sshecret-admin/src/sshecret_admin/api/router.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
"""Main API Router."""
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
|
from sshecret_admin.services.admin_backend import AdminBackend
|
||||||
|
from sshecret_admin.core.dependencies import BaseDependencies, AdminDependencies
|
||||||
|
from sshecret_admin.auth import PasswordDB, User, decode_token
|
||||||
|
|
||||||
|
from .endpoints import auth, clients, secrets
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
API_VERSION = "v1"
|
||||||
|
|
||||||
|
|
||||||
|
def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||||
|
"""Create clients router."""
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: Annotated[str, Depends(oauth2_scheme)],
|
||||||
|
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||||
|
) -> User:
|
||||||
|
"""Get current user from token."""
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
token_data = decode_token(dependencies.settings, token)
|
||||||
|
if not token_data:
|
||||||
|
raise credentials_exception
|
||||||
|
|
||||||
|
user = session.exec(
|
||||||
|
select(User).where(User.username == token_data.username)
|
||||||
|
).first()
|
||||||
|
if not user:
|
||||||
|
raise credentials_exception
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_current_active_user(
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
|
) -> User:
|
||||||
|
"""Get current active user."""
|
||||||
|
if current_user.disabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Inactive or disabled user")
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
async def get_admin_backend(session: Annotated[Session, Depends(dependencies.get_db_session)]):
|
||||||
|
"""Get admin backend API."""
|
||||||
|
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||||
|
if not password_db:
|
||||||
|
raise HTTPException(
|
||||||
|
500, detail="Error: The password manager has not yet been set up."
|
||||||
|
)
|
||||||
|
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
||||||
|
yield admin
|
||||||
|
|
||||||
|
app = APIRouter(
|
||||||
|
prefix=f"/api/{API_VERSION}", dependencies=[Depends(get_current_active_user)]
|
||||||
|
)
|
||||||
|
|
||||||
|
endpoint_deps = AdminDependencies.create(dependencies, get_admin_backend)
|
||||||
|
|
||||||
|
app.include_router(auth.create_router(endpoint_deps))
|
||||||
|
app.include_router(clients.create_router(endpoint_deps))
|
||||||
|
app.include_router(secrets.create_router(endpoint_deps))
|
||||||
|
|
||||||
|
return app
|
||||||
24
packages/sshecret-admin/src/sshecret_admin/auth/__init__.py
Normal file
24
packages/sshecret-admin/src/sshecret_admin/auth/__init__.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
"""Authentication related module."""
|
||||||
|
|
||||||
|
from .authentication import (
|
||||||
|
authenticate_user,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
check_password,
|
||||||
|
decode_token,
|
||||||
|
verify_password,
|
||||||
|
)
|
||||||
|
from .models import User, Token, PasswordDB
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PasswordDB",
|
||||||
|
"Token",
|
||||||
|
"User",
|
||||||
|
"authenticate_user",
|
||||||
|
"check_password",
|
||||||
|
"create_access_token",
|
||||||
|
"create_refresh_token",
|
||||||
|
"decode_token",
|
||||||
|
"verify_password",
|
||||||
|
]
|
||||||
@ -0,0 +1,95 @@
|
|||||||
|
"""Authentication utilities."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone, timedelta
|
||||||
|
from typing import cast, Any
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
import jwt
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
|
||||||
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
|
from .models import User, TokenData
|
||||||
|
from .exceptions import AuthenticationFailedError
|
||||||
|
|
||||||
|
JWT_ALGORITHM = "HS256"
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||||
|
# I know refresh tokens are supposed to be long-lived, but 6 hours for a
|
||||||
|
# sensitive application, seems reasonable.
|
||||||
|
REFRESH_TOKEN_EXPIRE_HOURS = 6
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def create_token(
|
||||||
|
settings: AdminServerSettings,
|
||||||
|
data: dict[str, Any],
|
||||||
|
expires_delta: timedelta,
|
||||||
|
) -> str:
|
||||||
|
"""Create access token."""
|
||||||
|
to_encode = data.copy()
|
||||||
|
expire = datetime.now(timezone.utc) + expires_delta
|
||||||
|
to_encode.update({"exp": expire})
|
||||||
|
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=JWT_ALGORITHM)
|
||||||
|
return str(encoded_jwt)
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(
|
||||||
|
settings: AdminServerSettings,
|
||||||
|
data: dict[str, Any],
|
||||||
|
expires_delta: timedelta | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create access token."""
|
||||||
|
if not expires_delta:
|
||||||
|
expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||||
|
return create_token(settings, data, expires_delta)
|
||||||
|
|
||||||
|
|
||||||
|
def create_refresh_token(
|
||||||
|
settings: AdminServerSettings,
|
||||||
|
data: dict[str, Any],
|
||||||
|
expires_delta: timedelta | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create access token."""
|
||||||
|
if not expires_delta:
|
||||||
|
expires_delta = timedelta(hours=REFRESH_TOKEN_EXPIRE_HOURS)
|
||||||
|
return create_token(settings, data, expires_delta)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
|
"""Verify password against stored hash."""
|
||||||
|
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def check_password(plain_password: str, hashed_password: str) -> None:
|
||||||
|
"""Check password.
|
||||||
|
|
||||||
|
If password doesn't match, throw AuthenticationFailedError.
|
||||||
|
"""
|
||||||
|
if not verify_password(plain_password, hashed_password):
|
||||||
|
raise AuthenticationFailedError()
|
||||||
|
|
||||||
|
|
||||||
|
def authenticate_user(session: Session, username: str, password: str) -> User | None:
|
||||||
|
"""Authenticate user."""
|
||||||
|
user = session.exec(select(User).where(User.username == username)).first()
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
if not verify_password(password, user.hashed_password):
|
||||||
|
return None
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def decode_token(settings: AdminServerSettings, token: str) -> TokenData | None:
|
||||||
|
"""Decode token."""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
|
||||||
|
username = cast("str | None", payload.get("sub"))
|
||||||
|
if not username:
|
||||||
|
return None
|
||||||
|
|
||||||
|
token_data = TokenData(username=username)
|
||||||
|
return token_data
|
||||||
|
except jwt.InvalidTokenError as e:
|
||||||
|
LOG.debug("Could not decode token: %s", e, exc_info=True)
|
||||||
|
return None
|
||||||
@ -0,0 +1,30 @@
|
|||||||
|
"""Authentication related exceptions."""
|
||||||
|
from typing import override
|
||||||
|
|
||||||
|
from .models import LoginError
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationFailedError(Exception):
|
||||||
|
"""Authentication failed."""
|
||||||
|
|
||||||
|
@override
|
||||||
|
def __init__(self, message: str | None = None) -> None:
|
||||||
|
"""Initialize exception class."""
|
||||||
|
if not message:
|
||||||
|
message = "Invalid user or password."
|
||||||
|
super().__init__(message)
|
||||||
|
self.login_error: LoginError = LoginError(
|
||||||
|
title="Authentication Failed", message=message
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationNeededError(Exception):
|
||||||
|
"""Authentication needed error."""
|
||||||
|
|
||||||
|
@override
|
||||||
|
def __init__(self, message: str | None = None) -> None:
|
||||||
|
"""Initialize exception class."""
|
||||||
|
if not message:
|
||||||
|
message = "You need to be logged in to continue."
|
||||||
|
super().__init__(message)
|
||||||
|
self.login_error: LoginError = LoginError(title="Unauthorized", message=message)
|
||||||
71
packages/sshecret-admin/src/sshecret_admin/auth/models.py
Normal file
71
packages/sshecret-admin/src/sshecret_admin/auth/models.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
"""Models for authentication."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlmodel import SQLModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
JWT_ALGORITHM = "HS256"
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||||||
|
# I know refresh tokens are supposed to be long-lived, but 6 hours for a
|
||||||
|
# sensitive application, seems reasonable.
|
||||||
|
REFRESH_TOKEN_EXPIRE_HOURS = 6
|
||||||
|
|
||||||
|
|
||||||
|
class User(SQLModel, table=True):
|
||||||
|
"""Users."""
|
||||||
|
|
||||||
|
username: str = Field(unique=True, primary_key=True)
|
||||||
|
hashed_password: str
|
||||||
|
disabled: bool = Field(default=False)
|
||||||
|
created_at: datetime | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_type=sa.DateTime(timezone=True),
|
||||||
|
sa_column_kwargs={"server_default": sa.func.now()},
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PasswordDB(SQLModel, table=True):
|
||||||
|
"""Password database."""
|
||||||
|
|
||||||
|
id: int | None = Field(default=None, primary_key=True)
|
||||||
|
encrypted_password: str
|
||||||
|
|
||||||
|
created_at: datetime | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_type=sa.DateTime(timezone=True),
|
||||||
|
sa_column_kwargs={"server_default": sa.func.now()},
|
||||||
|
nullable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_at: datetime | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_type=sa.DateTime(timezone=True),
|
||||||
|
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_db(engine: sa.Engine) -> None:
|
||||||
|
"""Create database."""
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenData(SQLModel):
|
||||||
|
"""Token data."""
|
||||||
|
|
||||||
|
username: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Token(SQLModel):
|
||||||
|
access_token: str
|
||||||
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
|
class LoginError(SQLModel):
|
||||||
|
"""Login Error model."""
|
||||||
|
# TODO: Remove this.
|
||||||
|
|
||||||
|
title: str
|
||||||
|
message: str
|
||||||
|
|
||||||
@ -1,125 +0,0 @@
|
|||||||
"""Models for authentication."""
|
|
||||||
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
import bcrypt
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from typing import Any, override
|
|
||||||
import jwt
|
|
||||||
from sqlmodel import SQLModel, Field
|
|
||||||
from sshecret_admin.settings import AdminServerSettings
|
|
||||||
|
|
||||||
|
|
||||||
JWT_ALGORITHM = "HS256"
|
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
||||||
|
|
||||||
|
|
||||||
class User(SQLModel, table=True):
|
|
||||||
"""Users."""
|
|
||||||
|
|
||||||
username: str = Field(unique=True, primary_key=True)
|
|
||||||
hashed_password: str
|
|
||||||
disabled: bool = Field(default=False)
|
|
||||||
created_at: datetime | None = Field(
|
|
||||||
default=None,
|
|
||||||
sa_type=sa.DateTime(timezone=True),
|
|
||||||
sa_column_kwargs={"server_default": sa.func.now()},
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PasswordDB(SQLModel, table=True):
|
|
||||||
"""Password database."""
|
|
||||||
|
|
||||||
id: int | None = Field(default=None, primary_key=True)
|
|
||||||
encrypted_password: str
|
|
||||||
|
|
||||||
created_at: datetime | None = Field(
|
|
||||||
default=None,
|
|
||||||
sa_type=sa.DateTime(timezone=True),
|
|
||||||
sa_column_kwargs={"server_default": sa.func.now()},
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_at: datetime | None = Field(
|
|
||||||
default=None,
|
|
||||||
sa_type=sa.DateTime(timezone=True),
|
|
||||||
sa_column_kwargs={"onupdate": sa.func.now(), "server_default": sa.func.now()},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def init_db(engine: sa.Engine) -> None:
|
|
||||||
"""Create database."""
|
|
||||||
SQLModel.metadata.create_all(engine)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenData(SQLModel):
|
|
||||||
"""Token data."""
|
|
||||||
|
|
||||||
username: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class Token(SQLModel):
|
|
||||||
access_token: str
|
|
||||||
token_type: str
|
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(
|
|
||||||
settings: AdminServerSettings,
|
|
||||||
data: dict[str, Any],
|
|
||||||
expires_delta: timedelta | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Create access token."""
|
|
||||||
to_encode = data.copy()
|
|
||||||
expire = datetime.now(timezone.utc) + timedelta(minutes=15)
|
|
||||||
if expires_delta:
|
|
||||||
expire = datetime.now(timezone.utc) + expires_delta
|
|
||||||
to_encode.update({"exp": expire})
|
|
||||||
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=JWT_ALGORITHM)
|
|
||||||
return encoded_jwt
|
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
|
||||||
"""Verify password against stored hash."""
|
|
||||||
return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
|
|
||||||
|
|
||||||
|
|
||||||
def check_password(plain_password: str, hashed_password: str) -> None:
|
|
||||||
"""Check password.
|
|
||||||
|
|
||||||
If password doesn't match, throw AuthenticationFailedError.
|
|
||||||
"""
|
|
||||||
if not verify_password(plain_password, hashed_password):
|
|
||||||
raise AuthenticationFailedError()
|
|
||||||
|
|
||||||
|
|
||||||
class LoginError(SQLModel):
|
|
||||||
"""Login Error model."""
|
|
||||||
|
|
||||||
title: str
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationFailedError(Exception):
|
|
||||||
"""Authentication failed."""
|
|
||||||
|
|
||||||
@override
|
|
||||||
def __init__(self, message: str | None = None) -> None:
|
|
||||||
"""Initialize exception class."""
|
|
||||||
if not message:
|
|
||||||
message = "Invalid user or password."
|
|
||||||
super().__init__(message)
|
|
||||||
self.login_error: LoginError = LoginError(
|
|
||||||
title="Authentication Failed", message=message
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationNeededError(Exception):
|
|
||||||
"""Authentication needed error."""
|
|
||||||
|
|
||||||
@override
|
|
||||||
def __init__(self, message: str | None = None) -> None:
|
|
||||||
"""Initialize exception class."""
|
|
||||||
if not message:
|
|
||||||
message = "You need to be logged in to continue."
|
|
||||||
super().__init__(message)
|
|
||||||
self.login_error: LoginError = LoginError(title="Unauthorized", message=message)
|
|
||||||
@ -5,24 +5,22 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, status
|
from fastapi import FastAPI, Request, Response, status
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, RedirectResponse
|
from fastapi.responses import JSONResponse, RedirectResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
|
from sshecret_admin import api, frontend
|
||||||
|
from sshecret_admin.auth.models import PasswordDB, init_db
|
||||||
|
from sshecret_admin.core.db import setup_database
|
||||||
|
from sshecret_admin.frontend.exceptions import RedirectException
|
||||||
|
from sshecret_admin.services.master_password import setup_master_password
|
||||||
|
|
||||||
from .admin_api import get_admin_api
|
from .dependencies import BaseDependencies
|
||||||
from .auth_models import init_db, PasswordDB, AuthenticationFailedError, AuthenticationNeededError
|
|
||||||
from .db import setup_database
|
|
||||||
from .master_password import setup_master_password
|
|
||||||
from .settings import AdminServerSettings
|
from .settings import AdminServerSettings
|
||||||
from .frontend import create_frontend
|
|
||||||
from .types import DBSessionDep
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -30,15 +28,14 @@ LOG = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def setup_frontend(
|
def setup_frontend(
|
||||||
app: FastAPI, settings: AdminServerSettings, get_db_session: DBSessionDep
|
app: FastAPI, dependencies: BaseDependencies
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Setup frontend."""
|
"""Setup frontend."""
|
||||||
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
|
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
|
||||||
static_path = script_path / "static"
|
static_path = script_path.parent / "static"
|
||||||
|
|
||||||
app.mount("/static", StaticFiles(directory=static_path), name="static")
|
app.mount("/static", StaticFiles(directory=static_path), name="static")
|
||||||
frontend = create_frontend(settings, get_db_session)
|
app.include_router(frontend.create_frontend_router(dependencies))
|
||||||
app.include_router(frontend)
|
|
||||||
|
|
||||||
|
|
||||||
def create_admin_app(
|
def create_admin_app(
|
||||||
@ -88,19 +85,15 @@ def create_admin_app(
|
|||||||
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}),
|
||||||
)
|
)
|
||||||
|
|
||||||
@app.exception_handler(AuthenticationNeededError)
|
@app.exception_handler(RedirectException)
|
||||||
async def authentication_needed_handler(
|
async def redirect_handler(request: Request, exc: RedirectException) -> Response:
|
||||||
request: Request, exc: AuthenticationNeededError,
|
"""Handle redirect exceptions."""
|
||||||
):
|
if "hx-request" in request.headers:
|
||||||
qs = f"error_title={exc.login_error.title}&error_message={exc.login_error.message}"
|
response = Response()
|
||||||
return RedirectResponse(f"/?{qs}")
|
response.headers["HX-Redirect"] = str(exc.to)
|
||||||
|
return response
|
||||||
|
return RedirectResponse(url=str(exc.to))
|
||||||
|
|
||||||
@app.exception_handler(AuthenticationFailedError)
|
|
||||||
async def authentication_failed_handler(
|
|
||||||
request: Request, exc: AuthenticationNeededError,
|
|
||||||
):
|
|
||||||
qs = f"error_title={exc.login_error.title}&error_message={exc.login_error.message}"
|
|
||||||
return RedirectResponse(f"/?{qs}")
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def get_health() -> JSONResponse:
|
async def get_health() -> JSONResponse:
|
||||||
@ -109,10 +102,11 @@ def create_admin_app(
|
|||||||
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
|
status_code=status.HTTP_200_OK, content=jsonable_encoder({"status": "LIVE"})
|
||||||
)
|
)
|
||||||
|
|
||||||
admin_api = get_admin_api(get_db_session, settings)
|
dependencies = BaseDependencies(settings, get_db_session)
|
||||||
|
|
||||||
app.include_router(admin_api)
|
|
||||||
|
app.include_router(api.create_api_router(dependencies))
|
||||||
if with_frontend:
|
if with_frontend:
|
||||||
setup_frontend(app, settings, get_db_session)
|
setup_frontend(app, dependencies)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
@ -7,29 +7,30 @@ import logging
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import click
|
import click
|
||||||
from sshecret_admin.admin_backend import AdminBackend
|
from sshecret_admin.services.admin_backend import AdminBackend
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from sqlmodel import Session, create_engine, select
|
from sqlmodel import Session, create_engine, select
|
||||||
from .auth_models import init_db, User, PasswordDB
|
from sshecret_admin.auth.models import init_db, User, PasswordDB
|
||||||
from .settings import AdminServerSettings
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
|
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
formatter = logging.Formatter("%(asctime)s [%(processName)s: %(process)d] [%(threadName)s: %(thread)d] [%(levelname)s] %(name)s: %(message)s")
|
formatter = logging.Formatter(
|
||||||
|
"%(asctime)s [%(processName)s: %(process)d] [%(threadName)s: %(thread)d] [%(levelname)s] %(name)s: %(message)s"
|
||||||
|
)
|
||||||
handler.setFormatter(formatter)
|
handler.setFormatter(formatter)
|
||||||
LOG = logging.getLogger()
|
LOG = logging.getLogger()
|
||||||
LOG.addHandler(handler)
|
LOG.addHandler(handler)
|
||||||
LOG.setLevel(logging.INFO)
|
LOG.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
def hash_password(password: str) -> str:
|
||||||
"""Hash password."""
|
"""Hash password."""
|
||||||
salt = bcrypt.gensalt()
|
salt = bcrypt.gensalt()
|
||||||
hashed_password = bcrypt.hashpw(password.encode(), salt)
|
hashed_password = bcrypt.hashpw(password.encode(), salt)
|
||||||
return hashed_password.decode()
|
return hashed_password.decode()
|
||||||
|
|
||||||
|
|
||||||
def create_user(session: Session, username: str, password: str) -> None:
|
def create_user(session: Session, username: str, password: str) -> None:
|
||||||
"""Create a user."""
|
"""Create a user."""
|
||||||
hashed_password = hash_password(password)
|
hashed_password = hash_password(password)
|
||||||
@ -48,7 +49,9 @@ def cli(ctx: click.Context, debug: bool) -> None:
|
|||||||
try:
|
try:
|
||||||
settings = AdminServerSettings() # pyright: ignore[reportCallIssue]
|
settings = AdminServerSettings() # pyright: ignore[reportCallIssue]
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise click.ClickException("Error: One or more required environment options are missing.") from e
|
raise click.ClickException(
|
||||||
|
"Error: One or more required environment options are missing."
|
||||||
|
) from e
|
||||||
ctx.obj = settings
|
ctx.obj = settings
|
||||||
|
|
||||||
|
|
||||||
@ -66,6 +69,7 @@ def cli_create_user(ctx: click.Context, username: str, password: str) -> None:
|
|||||||
|
|
||||||
click.echo("User created.")
|
click.echo("User created.")
|
||||||
|
|
||||||
|
|
||||||
@cli.command("passwd")
|
@cli.command("passwd")
|
||||||
@click.argument("username")
|
@click.argument("username")
|
||||||
@click.password_option()
|
@click.password_option()
|
||||||
@ -85,6 +89,7 @@ def cli_change_user_passwd(ctx: click.Context, username: str, password: str) ->
|
|||||||
session.commit()
|
session.commit()
|
||||||
click.echo("Password updated.")
|
click.echo("Password updated.")
|
||||||
|
|
||||||
|
|
||||||
@cli.command("deluser")
|
@cli.command("deluser")
|
||||||
@click.argument("username")
|
@click.argument("username")
|
||||||
@click.confirmation_option()
|
@click.confirmation_option()
|
||||||
@ -112,7 +117,9 @@ def cli_delete_user(ctx: click.Context, username: str) -> None:
|
|||||||
@click.option("--workers", type=click.INT)
|
@click.option("--workers", type=click.INT)
|
||||||
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
||||||
"""Run the server."""
|
"""Run the server."""
|
||||||
uvicorn.run("sshecret_admin.main:app", host=host, port=port, reload=dev, workers=workers)
|
uvicorn.run(
|
||||||
|
"sshecret_admin.core.main:app", host=host, port=port, reload=dev, workers=workers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@cli.command("repl")
|
@cli.command("repl")
|
||||||
@ -126,7 +133,9 @@ def cli_repl(ctx: click.Context) -> None:
|
|||||||
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||||
|
|
||||||
if not password_db:
|
if not password_db:
|
||||||
raise click.ClickException("Error: Password database has not yet been setup. Start the server to finish setup.")
|
raise click.ClickException(
|
||||||
|
"Error: Password database has not yet been setup. Start the server to finish setup."
|
||||||
|
)
|
||||||
|
|
||||||
def run(func: Awaitable[Any]) -> Any:
|
def run(func: Awaitable[Any]) -> Any:
|
||||||
"""Run an async function."""
|
"""Run an async function."""
|
||||||
@ -0,0 +1,37 @@
|
|||||||
|
"""Common type definitions."""
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator, Callable, Generator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
from sqlmodel import Session
|
||||||
|
from sshecret_admin.services import AdminBackend
|
||||||
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
|
|
||||||
|
|
||||||
|
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||||
|
|
||||||
|
AdminDep = Callable[[Session], AsyncGenerator[AdminBackend, None]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BaseDependencies:
|
||||||
|
"""Base level dependencies."""
|
||||||
|
|
||||||
|
settings: AdminServerSettings
|
||||||
|
get_db_session: DBSessionDep
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdminDependencies(BaseDependencies):
|
||||||
|
"""Dependency class with admin."""
|
||||||
|
|
||||||
|
get_admin_backend: AdminDep
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, deps: BaseDependencies, get_admin_backend: AdminDep) -> Self:
|
||||||
|
"""Create from base dependencies."""
|
||||||
|
return cls(
|
||||||
|
settings=deps.settings,
|
||||||
|
get_db_session=deps.get_db_session,
|
||||||
|
get_admin_backend=get_admin_backend,
|
||||||
|
)
|
||||||
@ -1,6 +1,5 @@
|
|||||||
"""Main server app."""
|
"""Main server app."""
|
||||||
import sys
|
import sys
|
||||||
import uvicorn
|
|
||||||
import click
|
import click
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
@ -2,11 +2,12 @@
|
|||||||
|
|
||||||
from pydantic import AnyHttpUrl, Field
|
from pydantic import AnyHttpUrl, Field
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
from sqlalchemy import URL
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_LISTEN_PORT = 8822
|
DEFAULT_LISTEN_PORT = 8822
|
||||||
|
|
||||||
DEFAULT_DATABASE = "sqlite:///ssh_admin.db"
|
DEFAULT_DATABASE = "ssh_admin.db"
|
||||||
|
|
||||||
|
|
||||||
class AdminServerSettings(BaseSettings):
|
class AdminServerSettings(BaseSettings):
|
||||||
@ -21,5 +22,12 @@ class AdminServerSettings(BaseSettings):
|
|||||||
listen_address: str = Field(default="")
|
listen_address: str = Field(default="")
|
||||||
secret_key: str
|
secret_key: str
|
||||||
port: int = DEFAULT_LISTEN_PORT
|
port: int = DEFAULT_LISTEN_PORT
|
||||||
admin_db: str = Field(default=DEFAULT_DATABASE)
|
|
||||||
|
database: str = Field(default=DEFAULT_DATABASE)
|
||||||
|
#admin_db: str = Field(default=DEFAULT_DATABASE)
|
||||||
debug: bool = False
|
debug: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def admin_db(self) -> URL:
|
||||||
|
"""Construct database url."""
|
||||||
|
return URL.create(drivername="sqlite", database=self.database)
|
||||||
@ -1,240 +0,0 @@
|
|||||||
"""Frontend methods."""
|
|
||||||
|
|
||||||
# pyright: reportUnusedFunction=false
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from datetime import timedelta
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import jwt
|
|
||||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request, Response, status
|
|
||||||
from fastapi.responses import RedirectResponse
|
|
||||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
|
||||||
from sqlmodel import Session, select
|
|
||||||
from sshecret_admin.settings import AdminServerSettings
|
|
||||||
from sshecret.backend import SshecretBackend
|
|
||||||
from .admin_backend import AdminBackend
|
|
||||||
from .auth_models import (
|
|
||||||
JWT_ALGORITHM,
|
|
||||||
AuthenticationFailedError,
|
|
||||||
AuthenticationNeededError,
|
|
||||||
LoginError,
|
|
||||||
PasswordDB,
|
|
||||||
User,
|
|
||||||
TokenData,
|
|
||||||
create_access_token,
|
|
||||||
verify_password,
|
|
||||||
)
|
|
||||||
from .types import DBSessionDep
|
|
||||||
from .views import create_audit_view, create_client_view, create_secrets_view
|
|
||||||
|
|
||||||
|
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = 45
|
|
||||||
LOG = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def login_error(templates: Jinja2Blocks, request: Request):
|
|
||||||
"""Return a login error."""
|
|
||||||
return templates.TemplateResponse(
|
|
||||||
request,
|
|
||||||
"login.html",
|
|
||||||
{
|
|
||||||
"page_title": "Login",
|
|
||||||
"page_description": "Login Page",
|
|
||||||
"error": "Invalid Login.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_frontend(
|
|
||||||
settings: AdminServerSettings, get_db_session: DBSessionDep
|
|
||||||
) -> APIRouter:
|
|
||||||
"""Create frontend."""
|
|
||||||
app = APIRouter(include_in_schema=False)
|
|
||||||
|
|
||||||
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
|
|
||||||
|
|
||||||
template_path = script_path / "templates"
|
|
||||||
|
|
||||||
templates = Jinja2Blocks(directory=template_path)
|
|
||||||
|
|
||||||
# @app.exception_handler(AuthenticationFailedError)
|
|
||||||
# async def handle_authentication_failed(request: Request, exc: AuthenticationFailedError):
|
|
||||||
# """Handle authentication failed error."""
|
|
||||||
# return templates.TemplateResponse(request, "login.html")
|
|
||||||
|
|
||||||
async def get_backend():
|
|
||||||
"""Get backend client."""
|
|
||||||
backend_client = SshecretBackend(
|
|
||||||
str(settings.backend_url), settings.backend_token
|
|
||||||
)
|
|
||||||
yield backend_client
|
|
||||||
|
|
||||||
async def get_admin_backend(session: Annotated[Session, Depends(get_db_session)]):
|
|
||||||
"""Get admin backend API."""
|
|
||||||
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
|
||||||
if not password_db:
|
|
||||||
raise HTTPException(
|
|
||||||
500, detail="Error: The password manager has not yet been set up."
|
|
||||||
)
|
|
||||||
admin = AdminBackend(settings, password_db.encrypted_password)
|
|
||||||
yield admin
|
|
||||||
|
|
||||||
async def get_login_status(
|
|
||||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
|
||||||
) -> bool:
|
|
||||||
"""Get login status."""
|
|
||||||
token = request.cookies.get("access_token")
|
|
||||||
if not token:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
|
|
||||||
username = payload.get("sub")
|
|
||||||
if not username:
|
|
||||||
return False
|
|
||||||
except jwt.InvalidTokenError:
|
|
||||||
return False
|
|
||||||
token_data = TokenData(username=username)
|
|
||||||
user = session.exec(
|
|
||||||
select(User).where(User.username == token_data.username)
|
|
||||||
).first()
|
|
||||||
if not user:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def get_current_user_from_token(
|
|
||||||
request: Request, session: Annotated[Session, Depends(get_db_session)]
|
|
||||||
) -> User:
|
|
||||||
credentials_exception = AuthenticationNeededError()
|
|
||||||
"""Get current user from token."""
|
|
||||||
token = request.cookies.get("access_token")
|
|
||||||
if not token:
|
|
||||||
raise credentials_exception
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, settings.secret_key, algorithms=[JWT_ALGORITHM])
|
|
||||||
username = payload.get("sub")
|
|
||||||
if not username:
|
|
||||||
raise credentials_exception
|
|
||||||
except jwt.InvalidTokenError:
|
|
||||||
raise credentials_exception
|
|
||||||
token_data = TokenData(username=username)
|
|
||||||
user = session.exec(
|
|
||||||
select(User).where(User.username == token_data.username)
|
|
||||||
).first()
|
|
||||||
if not user:
|
|
||||||
raise credentials_exception
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
async def get_index(
|
|
||||||
request: Request,
|
|
||||||
login_status: Annotated[bool, Depends(get_login_status)],
|
|
||||||
error_title: str | None = None,
|
|
||||||
error_message: str | None = None,
|
|
||||||
):
|
|
||||||
"""Get index."""
|
|
||||||
if login_status:
|
|
||||||
return RedirectResponse("/dashboard")
|
|
||||||
login_error: LoginError | None = None
|
|
||||||
if error_title and error_message:
|
|
||||||
login_error = LoginError(title=error_title, message=error_message)
|
|
||||||
return templates.TemplateResponse(
|
|
||||||
request,
|
|
||||||
"login.html",
|
|
||||||
{
|
|
||||||
"page_title": "Login",
|
|
||||||
"page_description": "Login page.",
|
|
||||||
"login_error": login_error,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.post("/")
|
|
||||||
async def post_index(
|
|
||||||
request: Request,
|
|
||||||
error_title: str | None = None,
|
|
||||||
error_message: str | None = None,
|
|
||||||
):
|
|
||||||
"""Get index."""
|
|
||||||
login_error: LoginError | None = None
|
|
||||||
if error_title and error_message:
|
|
||||||
login_error = LoginError(title=error_title, message=error_message)
|
|
||||||
return templates.TemplateResponse(
|
|
||||||
request,
|
|
||||||
"login.html",
|
|
||||||
{
|
|
||||||
"page_title": "Login",
|
|
||||||
"page_description": "Login page.",
|
|
||||||
"login_error": login_error,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.post("/login")
|
|
||||||
async def login_user(
|
|
||||||
response: Response,
|
|
||||||
request: Request,
|
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
|
||||||
username: Annotated[str, Form()],
|
|
||||||
password: Annotated[str, Form()],
|
|
||||||
):
|
|
||||||
"""Log in user."""
|
|
||||||
user = session.exec(select(User).where(User.username == username)).first()
|
|
||||||
auth_error = AuthenticationFailedError()
|
|
||||||
if not user:
|
|
||||||
raise auth_error
|
|
||||||
|
|
||||||
if not verify_password(password, user.hashed_password):
|
|
||||||
raise auth_error
|
|
||||||
|
|
||||||
token_data = {"sub": user.username}
|
|
||||||
expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
||||||
token = create_access_token(settings, token_data, expires_delta=expires)
|
|
||||||
response = RedirectResponse(url="/dashboard", status_code=status.HTTP_302_FOUND)
|
|
||||||
response.set_cookie(
|
|
||||||
key="access_token", value=token, httponly=True, secure=False, samesite="lax"
|
|
||||||
)
|
|
||||||
return response
|
|
||||||
|
|
||||||
@app.get("/success")
|
|
||||||
async def success_page(
|
|
||||||
request: Request,
|
|
||||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
|
||||||
):
|
|
||||||
"""Display a success page."""
|
|
||||||
return templates.TemplateResponse(
|
|
||||||
request, "success.html", {"page_title": "Success!", "user": current_user}
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.get("/dashboard")
|
|
||||||
async def get_dashboard(
|
|
||||||
request: Request,
|
|
||||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
|
||||||
):
|
|
||||||
"""Dashboard for mocking up the dashboard."""
|
|
||||||
# secrets = await admin.get_secrets()
|
|
||||||
return templates.TemplateResponse(
|
|
||||||
request,
|
|
||||||
"dashboard.html",
|
|
||||||
{
|
|
||||||
"page_title": "sshecret",
|
|
||||||
"user": current_user.username,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stop adding routes here.
|
|
||||||
|
|
||||||
app.include_router(
|
|
||||||
create_client_view(templates, get_current_user_from_token, get_admin_backend)
|
|
||||||
)
|
|
||||||
|
|
||||||
app.include_router(
|
|
||||||
create_secrets_view(templates, get_current_user_from_token, get_admin_backend)
|
|
||||||
)
|
|
||||||
|
|
||||||
app.include_router(
|
|
||||||
create_audit_view(templates, get_current_user_from_token, get_admin_backend)
|
|
||||||
)
|
|
||||||
|
|
||||||
return app
|
|
||||||
@ -0,0 +1,5 @@
|
|||||||
|
"""Frontend app."""
|
||||||
|
|
||||||
|
from .router import create_router as create_frontend_router
|
||||||
|
|
||||||
|
__all__ = ["create_frontend_router"]
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
"""Custom oauth2 class."""
|
||||||
|
|
||||||
|
from fastapi.security import OAuth2
|
||||||
|
|
||||||
|
|
||||||
|
class Oauth2TokenInCookies(OAuth2):
|
||||||
|
"""TODO: Create this."""
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
"""Frontend dependencies."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from collections.abc import Callable, Awaitable
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
|
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||||
|
from fastapi import Request
|
||||||
|
from sqlmodel import Session
|
||||||
|
|
||||||
|
from sshecret_admin.core.dependencies import AdminDep, BaseDependencies
|
||||||
|
|
||||||
|
from sshecret_admin.auth.models import User
|
||||||
|
|
||||||
|
UserTokenDep = Callable[[Request, Session], Awaitable[User]]
|
||||||
|
UserLoginDep = Callable[[Request, Session], Awaitable[bool]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FrontendDependencies(BaseDependencies):
|
||||||
|
"""Frontend dependencies."""
|
||||||
|
|
||||||
|
get_admin_backend: AdminDep
|
||||||
|
templates: Jinja2Blocks
|
||||||
|
get_user_from_access_token: UserTokenDep
|
||||||
|
get_user_from_refresh_token: UserTokenDep
|
||||||
|
get_login_status: UserLoginDep
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
deps: BaseDependencies,
|
||||||
|
get_admin_backend: AdminDep,
|
||||||
|
templates: Jinja2Blocks,
|
||||||
|
get_user_from_access_token: UserTokenDep,
|
||||||
|
get_user_from_refresh_token: UserTokenDep,
|
||||||
|
get_login_status: UserLoginDep,
|
||||||
|
) -> Self:
|
||||||
|
"""Create from base dependencies."""
|
||||||
|
return cls(
|
||||||
|
settings=deps.settings,
|
||||||
|
get_db_session=deps.get_db_session,
|
||||||
|
get_admin_backend=get_admin_backend,
|
||||||
|
templates=templates,
|
||||||
|
get_user_from_access_token=get_user_from_access_token,
|
||||||
|
get_user_from_refresh_token=get_user_from_refresh_token,
|
||||||
|
get_login_status=get_login_status,
|
||||||
|
)
|
||||||
@ -0,0 +1,13 @@
|
|||||||
|
"""Frontend exceptions."""
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
|
||||||
|
|
||||||
|
class RedirectException(Exception):
|
||||||
|
"""Exception that initiates a redirect flow."""
|
||||||
|
|
||||||
|
def __init__(self, to: str | URL) -> None: # pyright: ignore[reportMissingSuperCall]
|
||||||
|
"""Raise exception that redirects."""
|
||||||
|
if isinstance(to, str):
|
||||||
|
to = URL(to)
|
||||||
|
|
||||||
|
self.to: URL = to
|
||||||
133
packages/sshecret-admin/src/sshecret_admin/frontend/router.py
Normal file
133
packages/sshecret-admin/src/sshecret_admin/frontend/router.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
"""Frontend router."""
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
|
||||||
|
from jinja2_fragments.fastapi import Jinja2Blocks
|
||||||
|
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
|
||||||
|
|
||||||
|
from sshecret_admin.auth import PasswordDB, User, decode_token
|
||||||
|
from sshecret_admin.core.dependencies import BaseDependencies
|
||||||
|
from sshecret_admin.services.admin_backend import AdminBackend
|
||||||
|
|
||||||
|
from .dependencies import FrontendDependencies
|
||||||
|
from .exceptions import RedirectException
|
||||||
|
from .views import audit, auth, clients, index, secrets
|
||||||
|
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
access_token = "access_token"
|
||||||
|
refresh_token = "refresh_token"
|
||||||
|
|
||||||
|
|
||||||
|
def create_router(dependencies: BaseDependencies) -> APIRouter:
|
||||||
|
"""Create frontend router."""
|
||||||
|
|
||||||
|
app = APIRouter(include_in_schema=False)
|
||||||
|
|
||||||
|
script_path = Path(os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
|
||||||
|
template_path = script_path / "templates"
|
||||||
|
|
||||||
|
templates = Jinja2Blocks(directory=template_path)
|
||||||
|
|
||||||
|
async def get_admin_backend(
|
||||||
|
session: Annotated[Session, Depends(dependencies.get_db_session)]
|
||||||
|
):
|
||||||
|
"""Get admin backend API."""
|
||||||
|
password_db = session.exec(select(PasswordDB).where(PasswordDB.id == 1)).first()
|
||||||
|
if not password_db:
|
||||||
|
raise HTTPException(
|
||||||
|
500, detail="Error: The password manager has not yet been set up."
|
||||||
|
)
|
||||||
|
admin = AdminBackend(dependencies.settings, password_db.encrypted_password)
|
||||||
|
yield admin
|
||||||
|
|
||||||
|
async def get_user_from_token(
|
||||||
|
token: str,
|
||||||
|
session: Session,
|
||||||
|
) -> User | None:
|
||||||
|
"""Get user from a token."""
|
||||||
|
token_data = decode_token(dependencies.settings, token)
|
||||||
|
if not token_data:
|
||||||
|
return None
|
||||||
|
user = session.exec(
|
||||||
|
select(User).where(User.username == token_data.username)
|
||||||
|
).first()
|
||||||
|
if not user or user.disabled:
|
||||||
|
return None
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_user_from_refresh_token(
|
||||||
|
request: Request,
|
||||||
|
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||||
|
) -> User:
|
||||||
|
"""Get user from refresh token."""
|
||||||
|
next = URL("/login").include_query_params(next=request.url.path)
|
||||||
|
credentials_error = RedirectException(to=next)
|
||||||
|
token = request.cookies.get("refresh_token")
|
||||||
|
if not token:
|
||||||
|
raise credentials_error
|
||||||
|
|
||||||
|
user = await get_user_from_token(token, session)
|
||||||
|
if not user:
|
||||||
|
raise credentials_error
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_user_from_access_token(
|
||||||
|
request: Request,
|
||||||
|
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||||
|
) -> User:
|
||||||
|
"""Get user from access token."""
|
||||||
|
token = request.cookies.get("access_token")
|
||||||
|
next = URL("/refresh").include_query_params(next=request.url.path)
|
||||||
|
credentials_error = RedirectException(to=next)
|
||||||
|
if not token:
|
||||||
|
raise credentials_error
|
||||||
|
|
||||||
|
user = await get_user_from_token(token, session)
|
||||||
|
if not user:
|
||||||
|
raise credentials_error
|
||||||
|
return user
|
||||||
|
|
||||||
|
async def get_login_status(
|
||||||
|
request: Request,
|
||||||
|
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||||
|
) -> bool:
|
||||||
|
"""Get login status."""
|
||||||
|
token = request.cookies.get("access_token")
|
||||||
|
if not token:
|
||||||
|
return False
|
||||||
|
|
||||||
|
user = await get_user_from_token(token, session)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
view_dependencies = FrontendDependencies.create(
|
||||||
|
dependencies,
|
||||||
|
get_admin_backend,
|
||||||
|
templates,
|
||||||
|
get_user_from_access_token,
|
||||||
|
get_user_from_refresh_token,
|
||||||
|
get_login_status,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(audit.create_router(view_dependencies))
|
||||||
|
app.include_router(auth.create_router(view_dependencies))
|
||||||
|
app.include_router(clients.create_router(view_dependencies))
|
||||||
|
app.include_router(index.create_router(view_dependencies))
|
||||||
|
app.include_router(secrets.create_router(view_dependencies))
|
||||||
|
|
||||||
|
return app
|
||||||
@ -0,0 +1,89 @@
|
|||||||
|
{% extends "/dashboard/_base.html" %} {% block content %}
|
||||||
|
|
||||||
|
<div
|
||||||
|
class="p-4 bg-white block sm:flex items-center justify-between border-b border-gray-200 lg:mt-1.5 dark:bg-gray-800 dark:border-gray-700"
|
||||||
|
>
|
||||||
|
<h1 class="mb-4 text-4xl font-extrabold leading-none tracking-tight text-gray-900 md:text-5xl lg:text-6xl dark:text-white">Welcome to Sshecret</h1>
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="grid w-full grid-cols-1 gap-4 mt-4 xl:grid-cols-2 2xl:grid-cols-3">
|
||||||
|
<div class="items-center justify-between p-4 bg-white border border-gray-200 rounded-lg shadow-sm sm:flex dark:border-gray-700 sm:p-6 dark:bg-gray-800">
|
||||||
|
<div class="w-full">
|
||||||
|
<h3 class="text-base font-normal text-gray-500 dark:text-gray-400">Clients</h3>
|
||||||
|
<span class="text-2xl font-bold leading-none text-gray-900 sm:text-3xl dark:text-white">{{ stats.clients }}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- reference -->
|
||||||
|
|
||||||
|
<div class="grid w-full grid-cols-1 gap-4 mt-4 xl:grid-cols-2 2xl:grid-cols-3">
|
||||||
|
<div class="items-center justify-between p-4 bg-white border border-gray-200 rounded-lg shadow-sm sm:flex dark:border-gray-700 sm:p-6 dark:bg-gray-800">
|
||||||
|
<div class="w-full">
|
||||||
|
<h3 class="text-base font-normal text-gray-500 dark:text-gray-400">New products</h3>
|
||||||
|
<span class="text-2xl font-bold leading-none text-gray-900 sm:text-3xl dark:text-white">2,340</span>
|
||||||
|
<p class="flex items-center text-base font-normal text-gray-500 dark:text-gray-400">
|
||||||
|
<span class="flex items-center mr-1.5 text-sm text-green-500 dark:text-green-400">
|
||||||
|
<svg class="w-4 h-4" fill="currentColor" viewBox="0 0 20 20" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
|
||||||
|
<path clip-rule="evenodd" fill-rule="evenodd" d="M10 17a.75.75 0 01-.75-.75V5.612L5.29 9.77a.75.75 0 01-1.08-1.04l5.25-5.5a.75.75 0 011.08 0l5.25 5.5a.75.75 0 11-1.08 1.04l-3.96-4.158V16.25A.75.75 0 0110 17z"></path>
|
||||||
|
</svg>
|
||||||
|
12.5%
|
||||||
|
</span>
|
||||||
|
Since last month
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="w-full" id="new-products-chart"></div>
|
||||||
|
</div>
|
||||||
|
<div class="items-center justify-between p-4 bg-white border border-gray-200 rounded-lg shadow-sm sm:flex dark:border-gray-700 sm:p-6 dark:bg-gray-800">
|
||||||
|
<div class="w-full">
|
||||||
|
<h3 class="text-base font-normal text-gray-500 dark:text-gray-400">Users</h3>
|
||||||
|
<span class="text-2xl font-bold leading-none text-gray-900 sm:text-3xl dark:text-white">2,340</span>
|
||||||
|
<p class="flex items-center text-base font-normal text-gray-500 dark:text-gray-400">
|
||||||
|
<span class="flex items-center mr-1.5 text-sm text-green-500 dark:text-green-400">
|
||||||
|
<svg class="w-4 h-4" fill="currentColor" viewBox="0 0 20 20" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
|
||||||
|
<path clip-rule="evenodd" fill-rule="evenodd" d="M10 17a.75.75 0 01-.75-.75V5.612L5.29 9.77a.75.75 0 01-1.08-1.04l5.25-5.5a.75.75 0 011.08 0l5.25 5.5a.75.75 0 11-1.08 1.04l-3.96-4.158V16.25A.75.75 0 0110 17z"></path>
|
||||||
|
</svg>
|
||||||
|
3,4%
|
||||||
|
</span>
|
||||||
|
Since last month
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<div class="w-full" id="week-signups-chart"></div>
|
||||||
|
</div>
|
||||||
|
<div class="p-4 bg-white border border-gray-200 rounded-lg shadow-sm dark:border-gray-700 sm:p-6 dark:bg-gray-800">
|
||||||
|
<div class="w-full">
|
||||||
|
<h3 class="mb-2 text-base font-normal text-gray-500 dark:text-gray-400">Audience by age</h3>
|
||||||
|
<div class="flex items-center mb-2">
|
||||||
|
<div class="w-16 text-sm font-medium dark:text-white">50+</div>
|
||||||
|
<div class="w-full bg-gray-200 rounded-full h-2.5 dark:bg-gray-700">
|
||||||
|
<div class="bg-primary-600 h-2.5 rounded-full dark:bg-primary-500" style="width: 18%"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="flex items-center mb-2">
|
||||||
|
<div class="w-16 text-sm font-medium dark:text-white">40+</div>
|
||||||
|
<div class="w-full bg-gray-200 rounded-full h-2.5 dark:bg-gray-700">
|
||||||
|
<div class="bg-primary-600 h-2.5 rounded-full dark:bg-primary-500" style="width: 15%"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="flex items-center mb-2">
|
||||||
|
<div class="w-16 text-sm font-medium dark:text-white">30+</div>
|
||||||
|
<div class="w-full bg-gray-200 rounded-full h-2.5 dark:bg-gray-700">
|
||||||
|
<div class="bg-primary-600 h-2.5 rounded-full dark:bg-primary-500" style="width: 60%"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="flex items-center mb-2">
|
||||||
|
<div class="w-16 text-sm font-medium dark:text-white">20+</div>
|
||||||
|
<div class="w-full bg-gray-200 rounded-full h-2.5 dark:bg-gray-700">
|
||||||
|
<div class="bg-primary-600 h-2.5 rounded-full dark:bg-primary-500" style="width: 30%"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div id="traffic-channels-chart" class="w-full"></div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
{% endblock %}
|
||||||
@ -0,0 +1 @@
|
|||||||
|
"""Frontend views."""
|
||||||
@ -1,19 +1,20 @@
|
|||||||
"""Audit view."""
|
"""Audit view factory."""
|
||||||
# pyright: reportUnusedFunction=false
|
|
||||||
|
|
||||||
import math
|
# pyright: reportUnusedFunction=false
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import APIRouter, Depends, Request, Response
|
from fastapi import APIRouter, Depends, Request, Response
|
||||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from sshecret_admin.admin_backend import AdminBackend
|
from sshecret_admin.auth import User
|
||||||
from sshecret_admin.types import UserTokenDep, AdminDep
|
from sshecret_admin.services import AdminBackend
|
||||||
from sshecret_admin.auth_models import User
|
|
||||||
|
from ..dependencies import FrontendDependencies
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PagingInfo(BaseModel):
|
class PagingInfo(BaseModel):
|
||||||
|
|
||||||
page: int
|
page: int
|
||||||
@ -36,20 +37,15 @@ class PagingInfo(BaseModel):
|
|||||||
"""Return total pages."""
|
"""Return total pages."""
|
||||||
return math.ceil(self.total / self.limit)
|
return math.ceil(self.total / self.limit)
|
||||||
|
|
||||||
def create_audit_view(
|
|
||||||
templates: Jinja2Blocks,
|
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||||
get_current_user_from_token: UserTokenDep,
|
"""Create clients router."""
|
||||||
get_admin_backend: AdminDep,
|
|
||||||
) -> APIRouter:
|
|
||||||
"""Create client view."""
|
|
||||||
|
|
||||||
app = APIRouter()
|
app = APIRouter()
|
||||||
|
templates = dependencies.templates
|
||||||
|
|
||||||
async def resolve_audit_entries(
|
async def resolve_audit_entries(
|
||||||
request: Request,
|
request: Request, current_user: User, admin: AdminBackend, page: int
|
||||||
current_user: User,
|
|
||||||
admin: AdminBackend,
|
|
||||||
page: int
|
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""Resolve audit entries."""
|
"""Resolve audit entries."""
|
||||||
LOG.info("Page: %r", page)
|
LOG.info("Page: %r", page)
|
||||||
@ -61,7 +57,9 @@ def create_audit_view(
|
|||||||
|
|
||||||
entries = await admin.get_audit_log(offset=offset, limit=per_page)
|
entries = await admin.get_audit_log(offset=offset, limit=per_page)
|
||||||
LOG.info("Entries: %r", entries)
|
LOG.info("Entries: %r", entries)
|
||||||
page_info = PagingInfo(page=page, limit=per_page, total=total_messages, offset=offset)
|
page_info = PagingInfo(
|
||||||
|
page=page, limit=per_page, total=total_messages, offset=offset
|
||||||
|
)
|
||||||
if request.headers.get("HX-Request"):
|
if request.headers.get("HX-Request"):
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
@ -69,8 +67,7 @@ def create_audit_view(
|
|||||||
{
|
{
|
||||||
"entries": entries,
|
"entries": entries,
|
||||||
"page_info": page_info,
|
"page_info": page_info,
|
||||||
}
|
},
|
||||||
|
|
||||||
)
|
)
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
request,
|
request,
|
||||||
@ -80,34 +77,27 @@ def create_audit_view(
|
|||||||
"entries": entries,
|
"entries": entries,
|
||||||
"user": current_user.username,
|
"user": current_user.username,
|
||||||
"page_info": page_info,
|
"page_info": page_info,
|
||||||
|
},
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/audit/")
|
@app.get("/audit/")
|
||||||
async def get_audit_entries(
|
async def get_audit_entries(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
):
|
) -> Response:
|
||||||
"""Get audit entries."""
|
"""Get audit entries."""
|
||||||
return await resolve_audit_entries(request, current_user, admin, 1)
|
return await resolve_audit_entries(request, current_user, admin, 1)
|
||||||
|
|
||||||
@app.get("/audit/page/{page}")
|
@app.get("/audit/page/{page}")
|
||||||
async def get_audit_entries_page(
|
async def get_audit_entries_page(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
page: int,
|
page: int,
|
||||||
):
|
) -> Response:
|
||||||
"""Get audit entries."""
|
"""Get audit entries."""
|
||||||
LOG.info("Get audit entries page: %r", page)
|
LOG.info("Get audit entries page: %r", page)
|
||||||
return await resolve_audit_entries(request, current_user, admin, page)
|
return await resolve_audit_entries(request, current_user, admin, page)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# --------------#
|
|
||||||
# END OF ROUTES #
|
|
||||||
# --------------#
|
|
||||||
return app
|
return app
|
||||||
@ -0,0 +1,143 @@
|
|||||||
|
"""Authentication related views factory."""
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
|
import logging
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import APIRouter, Depends, Query, Request, Response, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from sqlmodel import Session
|
||||||
|
from starlette.datastructures import URL
|
||||||
|
|
||||||
|
from sshecret_admin.auth import (
|
||||||
|
User,
|
||||||
|
authenticate_user,
|
||||||
|
create_access_token,
|
||||||
|
create_refresh_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..dependencies import FrontendDependencies
|
||||||
|
from ..exceptions import RedirectException
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LoginError(BaseModel):
|
||||||
|
"""Login error."""
|
||||||
|
|
||||||
|
title: str
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||||
|
"""Create auth router."""
|
||||||
|
|
||||||
|
app = APIRouter()
|
||||||
|
templates = dependencies.templates
|
||||||
|
|
||||||
|
@app.get("/login")
|
||||||
|
async def get_login(
|
||||||
|
request: Request,
|
||||||
|
login_status: Annotated[bool, Depends(dependencies.get_login_status)],
|
||||||
|
error_title: str | None = None,
|
||||||
|
error_message: str | None = None,
|
||||||
|
):
|
||||||
|
"""Get index."""
|
||||||
|
if login_status:
|
||||||
|
return RedirectResponse("/dashboard")
|
||||||
|
login_error: LoginError | None = None
|
||||||
|
if error_title and error_message:
|
||||||
|
login_error = LoginError(title=error_title, message=error_message)
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
request,
|
||||||
|
"login.html",
|
||||||
|
{
|
||||||
|
"page_title": "Login",
|
||||||
|
"page_description": "Login page.",
|
||||||
|
"login_error": login_error,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.post("/login")
|
||||||
|
async def login_user(
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
session: Annotated[Session, Depends(dependencies.get_db_session)],
|
||||||
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
|
next: Annotated[str, Query()] = "/dashboard",
|
||||||
|
error_title: str | None = None,
|
||||||
|
error_message: str | None = None,
|
||||||
|
):
|
||||||
|
"""Log in user."""
|
||||||
|
if error_title and error_message:
|
||||||
|
login_error = LoginError(title=error_title, message=error_message)
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
request,
|
||||||
|
"login.html",
|
||||||
|
{
|
||||||
|
"page_title": "Login",
|
||||||
|
"page_description": "Login page.",
|
||||||
|
"login_error": login_error,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
user = authenticate_user(session, form_data.username, form_data.password)
|
||||||
|
login_failed = RedirectException(
|
||||||
|
to=URL("/login").include_query_params(
|
||||||
|
error_title="Login Error", error_message="Invalid username or password"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if not user:
|
||||||
|
raise login_failed
|
||||||
|
token_data: dict[str, str] = {"sub": user.username}
|
||||||
|
access_token = create_access_token(dependencies.settings, data=token_data)
|
||||||
|
refresh_token = create_refresh_token(dependencies.settings, data=token_data)
|
||||||
|
response = RedirectResponse(url=next, status_code=status.HTTP_302_FOUND)
|
||||||
|
response.set_cookie(
|
||||||
|
"access_token",
|
||||||
|
value=access_token,
|
||||||
|
httponly=True,
|
||||||
|
secure=False,
|
||||||
|
samesite="strict",
|
||||||
|
)
|
||||||
|
response.set_cookie(
|
||||||
|
"refresh_token",
|
||||||
|
value=refresh_token,
|
||||||
|
httponly=True,
|
||||||
|
secure=False,
|
||||||
|
samesite="strict",
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
@app.get("/refresh")
|
||||||
|
async def get_refresh_token(
|
||||||
|
response: Response,
|
||||||
|
user: Annotated[User, Depends(dependencies.get_user_from_refresh_token)],
|
||||||
|
next: Annotated[str, Query()],
|
||||||
|
):
|
||||||
|
"""Refresh tokens.
|
||||||
|
|
||||||
|
We might as well refresh the long-lived one here.
|
||||||
|
"""
|
||||||
|
token_data: dict[str, str] = {"sub": user.username}
|
||||||
|
access_token = create_access_token(dependencies.settings, data=token_data)
|
||||||
|
refresh_token = create_refresh_token(dependencies.settings, data=token_data)
|
||||||
|
response = RedirectResponse(url=next, status_code=status.HTTP_302_FOUND)
|
||||||
|
response.set_cookie(
|
||||||
|
"access_token",
|
||||||
|
value=access_token,
|
||||||
|
httponly=True,
|
||||||
|
secure=False,
|
||||||
|
samesite="strict",
|
||||||
|
)
|
||||||
|
response.set_cookie(
|
||||||
|
"refresh_token",
|
||||||
|
value=refresh_token,
|
||||||
|
httponly=True,
|
||||||
|
secure=False,
|
||||||
|
samesite="strict",
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
return app
|
||||||
@ -1,21 +1,20 @@
|
|||||||
"""Client views."""
|
"""clients view factory."""
|
||||||
|
|
||||||
# pyright: reportUnusedFunction=false
|
# pyright: reportUnusedFunction=false
|
||||||
|
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import APIRouter, Depends, Request, Form
|
from fastapi import APIRouter, Depends, Form, Request, Response
|
||||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
|
||||||
|
|
||||||
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
|
from pydantic import BaseModel, IPvAnyAddress, IPvAnyNetwork
|
||||||
from sshecret_admin.admin_backend import AdminBackend
|
|
||||||
from sshecret.backend import ClientFilter
|
from sshecret.backend import ClientFilter
|
||||||
from sshecret.backend.models import FilterType
|
from sshecret.backend.models import FilterType
|
||||||
from sshecret.crypto import validate_public_key
|
from sshecret.crypto import validate_public_key
|
||||||
from sshecret_admin.types import UserTokenDep, AdminDep
|
from sshecret_admin.auth import User
|
||||||
from sshecret_admin.auth_models import User
|
from sshecret_admin.services import AdminBackend
|
||||||
|
|
||||||
|
from ..dependencies import FrontendDependencies
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -37,21 +36,18 @@ class ClientCreate(BaseModel):
|
|||||||
sources: str | None
|
sources: str | None
|
||||||
|
|
||||||
|
|
||||||
def create_client_view(
|
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||||
templates: Jinja2Blocks,
|
"""Create clients router."""
|
||||||
get_current_user_from_token: UserTokenDep,
|
|
||||||
get_admin_backend: AdminDep,
|
|
||||||
) -> APIRouter:
|
|
||||||
"""Create client view."""
|
|
||||||
|
|
||||||
app = APIRouter()
|
app = APIRouter()
|
||||||
|
templates = dependencies.templates
|
||||||
|
|
||||||
@app.get("/clients")
|
@app.get("/clients")
|
||||||
async def get_clients(
|
async def get_clients(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
):
|
) -> Response:
|
||||||
"""Get clients."""
|
"""Get clients."""
|
||||||
clients = await admin.get_clients()
|
clients = await admin.get_clients()
|
||||||
LOG.info("Clients %r", clients)
|
LOG.info("Clients %r", clients)
|
||||||
@ -68,10 +64,12 @@ def create_client_view(
|
|||||||
@app.post("/clients/query")
|
@app.post("/clients/query")
|
||||||
async def query_clients(
|
async def query_clients(
|
||||||
request: Request,
|
request: Request,
|
||||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
_current_user: Annotated[
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
User, Depends(dependencies.get_user_from_access_token)
|
||||||
|
],
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
query: Annotated[str, Form()],
|
query: Annotated[str, Form()],
|
||||||
):
|
) -> Response:
|
||||||
"""Query for a client."""
|
"""Query for a client."""
|
||||||
query_filter: ClientFilter | None = None
|
query_filter: ClientFilter | None = None
|
||||||
if query:
|
if query:
|
||||||
@ -90,8 +88,10 @@ def create_client_view(
|
|||||||
async def update_client(
|
async def update_client(
|
||||||
request: Request,
|
request: Request,
|
||||||
id: str,
|
id: str,
|
||||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
_current_user: Annotated[
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
User, Depends(dependencies.get_user_from_access_token)
|
||||||
|
],
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
client: Annotated[ClientUpdate, Form()],
|
client: Annotated[ClientUpdate, Form()],
|
||||||
):
|
):
|
||||||
"""Update a client."""
|
"""Update a client."""
|
||||||
@ -135,9 +135,11 @@ def create_client_view(
|
|||||||
async def delete_client(
|
async def delete_client(
|
||||||
request: Request,
|
request: Request,
|
||||||
id: str,
|
id: str,
|
||||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
_current_user: Annotated[
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
User, Depends(dependencies.get_user_from_access_token)
|
||||||
):
|
],
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
) -> Response:
|
||||||
"""Delete a client."""
|
"""Delete a client."""
|
||||||
await admin.delete_client(id)
|
await admin.delete_client(id)
|
||||||
clients = await admin.get_clients()
|
clients = await admin.get_clients()
|
||||||
@ -154,10 +156,12 @@ def create_client_view(
|
|||||||
@app.post("/clients/")
|
@app.post("/clients/")
|
||||||
async def create_client(
|
async def create_client(
|
||||||
request: Request,
|
request: Request,
|
||||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
_current_user: Annotated[
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
User, Depends(dependencies.get_user_from_access_token)
|
||||||
|
],
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
client: Annotated[ClientCreate, Form()],
|
client: Annotated[ClientCreate, Form()],
|
||||||
):
|
) -> Response:
|
||||||
"""Create client."""
|
"""Create client."""
|
||||||
sources: list[str] | None = None
|
sources: list[str] | None = None
|
||||||
if client.sources:
|
if client.sources:
|
||||||
@ -179,9 +183,11 @@ def create_client_view(
|
|||||||
@app.post("/clients/validate/source")
|
@app.post("/clients/validate/source")
|
||||||
async def validate_client_source(
|
async def validate_client_source(
|
||||||
request: Request,
|
request: Request,
|
||||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
_current_user: Annotated[
|
||||||
|
User, Depends(dependencies.get_user_from_access_token)
|
||||||
|
],
|
||||||
sources: Annotated[str, Form()],
|
sources: Annotated[str, Form()],
|
||||||
):
|
) -> Response:
|
||||||
"""Validate source."""
|
"""Validate source."""
|
||||||
source_str = sources.split(",")
|
source_str = sources.split(",")
|
||||||
for source in source_str:
|
for source in source_str:
|
||||||
@ -211,9 +217,11 @@ def create_client_view(
|
|||||||
@app.post("/clients/validate/public_key")
|
@app.post("/clients/validate/public_key")
|
||||||
async def validate_client_public_key(
|
async def validate_client_public_key(
|
||||||
request: Request,
|
request: Request,
|
||||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
_current_user: Annotated[
|
||||||
|
User, Depends(dependencies.get_user_from_access_token)
|
||||||
|
],
|
||||||
public_key: Annotated[str, Form()],
|
public_key: Annotated[str, Form()],
|
||||||
):
|
) -> Response:
|
||||||
"""Validate source."""
|
"""Validate source."""
|
||||||
if validate_public_key(public_key.rstrip()):
|
if validate_public_key(public_key.rstrip()):
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
"""Front page view factory."""
|
||||||
|
|
||||||
|
# pyright: reportUnusedFunction=false
|
||||||
|
import logging
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sshecret_admin.auth import User
|
||||||
|
from sshecret_admin.services import AdminBackend
|
||||||
|
|
||||||
|
from ..dependencies import FrontendDependencies
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
START_PAGE = "/dashboard"
|
||||||
|
LOGIN_PAGE = "/login"
|
||||||
|
|
||||||
|
|
||||||
|
class StatsView(BaseModel):
|
||||||
|
"""Stats for the frontend."""
|
||||||
|
|
||||||
|
clients: int = 0
|
||||||
|
secrets: int = 0
|
||||||
|
audit_events: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
async def get_stats(admin: AdminBackend) -> StatsView:
|
||||||
|
"""Get stats for the frontpage."""
|
||||||
|
clients = await admin.get_clients()
|
||||||
|
secrets = await admin.get_secrets()
|
||||||
|
audit = await admin.get_audit_log_count()
|
||||||
|
return StatsView(clients=len(clients), secrets=len(secrets), audit_events=audit)
|
||||||
|
|
||||||
|
|
||||||
|
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||||
|
"""Create auth router."""
|
||||||
|
|
||||||
|
app = APIRouter()
|
||||||
|
templates = dependencies.templates
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def get_index(logged_in: Annotated[bool, Depends(dependencies.get_login_status)]):
|
||||||
|
"""Get the index."""
|
||||||
|
next = LOGIN_PAGE
|
||||||
|
if logged_in:
|
||||||
|
next = START_PAGE
|
||||||
|
|
||||||
|
return RedirectResponse(url=next)
|
||||||
|
|
||||||
|
@app.get("/dashboard")
|
||||||
|
async def get_dashboard(
|
||||||
|
request: Request,
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
|
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||||
|
):
|
||||||
|
"""Dashboard for mocking up the dashboard."""
|
||||||
|
stats = await get_stats(admin)
|
||||||
|
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
request,
|
||||||
|
"dashboard.html",
|
||||||
|
{
|
||||||
|
"page_title": "sshecret",
|
||||||
|
"user": current_user.username,
|
||||||
|
"stats": stats,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return app
|
||||||
@ -1,24 +1,25 @@
|
|||||||
"""Secrets view."""
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
|
||||||
# pyright: reportUnusedFunction=false
|
# pyright: reportUnusedFunction=false
|
||||||
import logging
|
import logging
|
||||||
import secrets as pysecrets
|
import secrets as pysecrets
|
||||||
from typing import Annotated, Any
|
from typing import Annotated, Any
|
||||||
from fastapi import APIRouter, Depends, Request, Form
|
from fastapi import APIRouter, Depends, Form, Request
|
||||||
from jinja2_fragments.fastapi import Jinja2Blocks
|
|
||||||
|
|
||||||
from pydantic import BaseModel, BeforeValidator, Field
|
from pydantic import BaseModel, BeforeValidator, Field
|
||||||
from sshecret_admin.admin_backend import AdminBackend
|
|
||||||
from sshecret_admin.types import UserTokenDep, AdminDep
|
from sshecret_admin.auth import User
|
||||||
from sshecret_admin.auth_models import User
|
from sshecret_admin.services import AdminBackend
|
||||||
|
|
||||||
|
from ..dependencies import FrontendDependencies
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def split_clients(clients: Any) -> Any:
|
def split_clients(clients: Any) -> Any: # pyright: ignore[reportAny]
|
||||||
"""Split clients."""
|
"""Split clients."""
|
||||||
if isinstance(clients, list):
|
if isinstance(clients, list):
|
||||||
return clients
|
return clients # pyright: ignore[reportUnknownVariableType]
|
||||||
if not isinstance(clients, str):
|
if not isinstance(clients, str):
|
||||||
raise ValueError("Invalid type for clients.")
|
raise ValueError("Invalid type for clients.")
|
||||||
if not clients:
|
if not clients:
|
||||||
@ -26,7 +27,7 @@ def split_clients(clients: Any) -> Any:
|
|||||||
return [client.rstrip() for client in clients.split(",")]
|
return [client.rstrip() for client in clients.split(",")]
|
||||||
|
|
||||||
|
|
||||||
def handle_select_bool(value: Any) -> Any:
|
def handle_select_bool(value: Any) -> Any: # pyright: ignore[reportAny]
|
||||||
"""Handle boolean from select."""
|
"""Handle boolean from select."""
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
return value
|
return value
|
||||||
@ -47,20 +48,17 @@ class CreateSecret(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_secrets_view(
|
def create_router(dependencies: FrontendDependencies) -> APIRouter:
|
||||||
templates: Jinja2Blocks,
|
"""Create secrets router."""
|
||||||
get_current_user_from_token: UserTokenDep,
|
|
||||||
get_admin_backend: AdminDep,
|
|
||||||
) -> APIRouter:
|
|
||||||
"""Create secrets view."""
|
|
||||||
|
|
||||||
app = APIRouter()
|
app = APIRouter()
|
||||||
|
templates = dependencies.templates
|
||||||
|
|
||||||
@app.get("/secrets/")
|
@app.get("/secrets/")
|
||||||
async def get_secrets(
|
async def get_secrets(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
current_user: Annotated[User, Depends(dependencies.get_user_from_access_token)],
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
):
|
):
|
||||||
"""Get secrets index page."""
|
"""Get secrets index page."""
|
||||||
secrets = await admin.get_detailed_secrets()
|
secrets = await admin.get_detailed_secrets()
|
||||||
@ -79,8 +77,10 @@ def create_secrets_view(
|
|||||||
@app.post("/secrets/")
|
@app.post("/secrets/")
|
||||||
async def add_secret(
|
async def add_secret(
|
||||||
request: Request,
|
request: Request,
|
||||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
_current_user: Annotated[
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
User, Depends(dependencies.get_user_from_access_token)
|
||||||
|
],
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
secret: Annotated[CreateSecret, Form()],
|
secret: Annotated[CreateSecret, Form()],
|
||||||
):
|
):
|
||||||
"""Add secret."""
|
"""Add secret."""
|
||||||
@ -108,8 +108,10 @@ def create_secrets_view(
|
|||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
id: str,
|
id: str,
|
||||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
_current_user: Annotated[
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
User, Depends(dependencies.get_user_from_access_token)
|
||||||
|
],
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
):
|
):
|
||||||
"""Remove a client's access to a secret."""
|
"""Remove a client's access to a secret."""
|
||||||
await admin.delete_client_secret(id, name)
|
await admin.delete_client_secret(id, name)
|
||||||
@ -130,8 +132,10 @@ def create_secrets_view(
|
|||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
client: Annotated[str, Form()],
|
client: Annotated[str, Form()],
|
||||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
_current_user: Annotated[
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
User, Depends(dependencies.get_user_from_access_token)
|
||||||
|
],
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
):
|
):
|
||||||
"""Add a secret to a client."""
|
"""Add a secret to a client."""
|
||||||
await admin.create_client_secret(client, name)
|
await admin.create_client_secret(client, name)
|
||||||
@ -153,8 +157,10 @@ def create_secrets_view(
|
|||||||
async def delete_secret(
|
async def delete_secret(
|
||||||
request: Request,
|
request: Request,
|
||||||
name: str,
|
name: str,
|
||||||
_current_user: Annotated[User, Depends(get_current_user_from_token)],
|
_current_user: Annotated[
|
||||||
admin: Annotated[AdminBackend, Depends(get_admin_backend)],
|
User, Depends(dependencies.get_user_from_access_token)
|
||||||
|
],
|
||||||
|
admin: Annotated[AdminBackend, Depends(dependencies.get_admin_backend)],
|
||||||
):
|
):
|
||||||
"""Delete a secret."""
|
"""Delete a secret."""
|
||||||
await admin.delete_secret(name)
|
await admin.delete_secret(name)
|
||||||
@ -172,7 +178,4 @@ def create_secrets_view(
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
|
|
||||||
# --------------#
|
|
||||||
# END OF ROUTES #
|
|
||||||
# --------------#
|
|
||||||
return app
|
return app
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
"""Services module.
|
||||||
|
|
||||||
|
This module contains business logic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .admin_backend import AdminBackend
|
||||||
|
|
||||||
|
__all__ = ["AdminBackend"]
|
||||||
@ -7,13 +7,14 @@ import logging
|
|||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
from sshecret.backend import AuditLog, Client, ClientFilter, Secret, SshecretBackend
|
from sshecret.backend import AuditLog, Client, ClientFilter, Secret, SshecretBackend, Operation, SubSystem
|
||||||
from sshecret.backend.models import DetailedSecrets
|
from sshecret.backend.models import DetailedSecrets
|
||||||
|
from sshecret.backend.api import AuditAPI
|
||||||
from sshecret.crypto import encrypt_string, load_public_key
|
from sshecret.crypto import encrypt_string, load_public_key
|
||||||
|
|
||||||
from .keepass import PasswordContext, load_password_manager
|
from .keepass import PasswordContext, load_password_manager
|
||||||
from .settings import AdminServerSettings
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
from .view_models import SecretView
|
from .models import SecretView
|
||||||
|
|
||||||
|
|
||||||
class ClientManagementError(Exception):
|
class ClientManagementError(Exception):
|
||||||
@ -381,6 +382,11 @@ class AdminBackend:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise BackendUnavailableError() from e
|
raise BackendUnavailableError() from e
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audit(self) -> AuditAPI:
|
||||||
|
"""Resolve audit API."""
|
||||||
|
return self.backend.audit(SubSystem.ADMIN)
|
||||||
|
|
||||||
async def get_audit_log(
|
async def get_audit_log(
|
||||||
self,
|
self,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
@ -389,14 +395,36 @@ class AdminBackend:
|
|||||||
subsystem: str | None = None,
|
subsystem: str | None = None,
|
||||||
) -> list[AuditLog]:
|
) -> list[AuditLog]:
|
||||||
"""Get audit log from backend."""
|
"""Get audit log from backend."""
|
||||||
return await self.backend.get_audit_log(offset, limit, client_name, subsystem)
|
return await self.audit.get(offset, limit, client_name, subsystem)
|
||||||
|
|
||||||
|
async def write_audit_message(
|
||||||
|
self,
|
||||||
|
operation: Operation,
|
||||||
|
message: str,
|
||||||
|
origin: str,
|
||||||
|
client: Client | None = None,
|
||||||
|
secret_name: str | None = None,
|
||||||
|
**data: str,
|
||||||
|
) -> None:
|
||||||
|
"""Write an audit message."""
|
||||||
|
await self.audit.write_async(
|
||||||
|
operation=operation,
|
||||||
|
message=message,
|
||||||
|
origin=origin,
|
||||||
|
client=client,
|
||||||
|
secret=None,
|
||||||
|
secret_name=secret_name,
|
||||||
|
**data,
|
||||||
|
)
|
||||||
|
|
||||||
async def write_audit_log(self, entry: AuditLog) -> None:
|
async def write_audit_log(self, entry: AuditLog) -> None:
|
||||||
"""Write to the audit log."""
|
"""Write to the audit log."""
|
||||||
if not entry.subsystem:
|
if not entry.subsystem:
|
||||||
entry.subsystem = "admin"
|
entry.subsystem = SubSystem.ADMIN
|
||||||
await self.backend.add_audit_log(entry)
|
|
||||||
|
await self.audit.write_model_async(entry)
|
||||||
|
#await self.backend.add_audit_log(entry)
|
||||||
|
|
||||||
async def get_audit_log_count(self) -> int:
|
async def get_audit_log_count(self) -> int:
|
||||||
"""Get audit log count."""
|
"""Get audit log count."""
|
||||||
return await self.backend.get_audit_log_count()
|
return await self.audit.count()
|
||||||
@ -8,7 +8,7 @@ from typing import cast
|
|||||||
|
|
||||||
import pykeepass
|
import pykeepass
|
||||||
from .master_password import decrypt_master_password
|
from .master_password import decrypt_master_password
|
||||||
from .settings import AdminServerSettings
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -8,7 +8,7 @@ from sshecret.crypto import (
|
|||||||
encrypt_string,
|
encrypt_string,
|
||||||
decode_string,
|
decode_string,
|
||||||
)
|
)
|
||||||
from .settings import AdminServerSettings
|
from sshecret_admin.core.settings import AdminServerSettings
|
||||||
|
|
||||||
KEY_FILENAME = "sshecret-admin-key"
|
KEY_FILENAME = "sshecret-admin-key"
|
||||||
|
|
||||||
@ -1,7 +1,7 @@
|
|||||||
"""Models for the API."""
|
"""Models for the API."""
|
||||||
|
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Annotated, Literal, Self, Union
|
from typing import Annotated, Literal
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
AfterValidator,
|
AfterValidator,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
@ -9,7 +9,6 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
IPvAnyAddress,
|
IPvAnyAddress,
|
||||||
IPvAnyNetwork,
|
IPvAnyNetwork,
|
||||||
model_validator,
|
|
||||||
)
|
)
|
||||||
from sshecret.crypto import validate_public_key
|
from sshecret.crypto import validate_public_key
|
||||||
|
|
||||||
@ -1,10 +0,0 @@
|
|||||||
{% extends "/dashboard/_base.html" %} {% block content %}
|
|
||||||
|
|
||||||
<div
|
|
||||||
class="p-4 bg-white block sm:flex items-center justify-between border-b border-gray-200 lg:mt-1.5 dark:bg-gray-800 dark:border-gray-700"
|
|
||||||
>
|
|
||||||
<h1 class="mb-4 text-4xl font-extrabold leading-none tracking-tight text-gray-900 md:text-5xl lg:text-6xl dark:text-white">Welcome to Sshecret</h1>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
|
|
||||||
{% endblock %}
|
|
||||||
@ -4,8 +4,7 @@ import os
|
|||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
|
||||||
from sqlalchemy import Engine
|
from sqlmodel import Session
|
||||||
from sqlmodel import Session, select
|
|
||||||
from .auth_models import User
|
from .auth_models import User
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -6,16 +6,10 @@ from fastapi import Request
|
|||||||
from sqlmodel import Session
|
from sqlmodel import Session
|
||||||
from sshecret_admin.admin_backend import AdminBackend
|
from sshecret_admin.admin_backend import AdminBackend
|
||||||
from sshecret_admin.auth_models import User
|
from sshecret_admin.auth_models import User
|
||||||
from sshecret.backend import SshecretBackend
|
|
||||||
from . import keepass
|
|
||||||
|
|
||||||
|
|
||||||
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
DBSessionDep = Callable[[], Generator[Session, None, None]]
|
||||||
|
|
||||||
BackendDep = Callable[[], AsyncGenerator[SshecretBackend, None]]
|
|
||||||
|
|
||||||
PasswdCtxDep = Callable[[DBSessionDep], AsyncGenerator[keepass.PasswordContext, None]]
|
|
||||||
|
|
||||||
AdminDep = Callable[[Session], AsyncGenerator[AdminBackend, None]]
|
AdminDep = Callable[[Session], AsyncGenerator[AdminBackend, None]]
|
||||||
|
|
||||||
UserTokenDep = Callable[[Request, Session], Awaitable[User]]
|
UserTokenDep = Callable[[Request, Session], Awaitable[User]]
|
||||||
|
|||||||
@ -1,5 +0,0 @@
|
|||||||
from .audit import create_audit_view
|
|
||||||
from .clients import create_client_view
|
|
||||||
from .secrets import create_secrets_view
|
|
||||||
|
|
||||||
__all__ = ["create_audit_view", "create_client_view", "create_secrets_view"]
|
|
||||||
@ -1,24 +0,0 @@
|
|||||||
<!doctype html>
|
|
||||||
<html class="no-js" lang="">
|
|
||||||
<head>
|
|
||||||
<meta charset="utf-8" />
|
|
||||||
<meta http-equiv="x-ua-compatible" content="ie=edge" />
|
|
||||||
<title>Untitled</title>
|
|
||||||
<meta name="description" content="" />
|
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
|
||||||
|
|
||||||
<link rel="apple-touch-icon" href="/apple-touch-icon.png" />
|
|
||||||
<!-- Place favicon.ico in the root directory -->
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<!--[if lt IE 8]>
|
|
||||||
<p class="browserupgrade">
|
|
||||||
You are using an <strong>outdated</strong> browser. Please
|
|
||||||
<a href="http://browsehappy.com/">upgrade your browser</a> to improve
|
|
||||||
your experience.
|
|
||||||
</p>
|
|
||||||
<![endif]-->
|
|
||||||
|
|
||||||
<p>I am outside of the package</p>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@ -3,22 +3,22 @@ from logging.config import fileConfig
|
|||||||
|
|
||||||
from sqlalchemy import engine_from_config
|
from sqlalchemy import engine_from_config
|
||||||
from sqlalchemy import pool
|
from sqlalchemy import pool
|
||||||
from sqlmodel import create_engine
|
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from sshecret_backend.models import *
|
from sshecret_backend.models import Base
|
||||||
|
|
||||||
def get_database_url() -> str:
|
|
||||||
"""Get database URL."""
|
|
||||||
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
|
|
||||||
return f"sqlite:///{db_file}"
|
|
||||||
return "sqlite:///sshecret.db"
|
|
||||||
|
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
# access to the values within the .ini file in use.
|
# access to the values within the .ini file in use.
|
||||||
config = context.config
|
config = context.config
|
||||||
|
|
||||||
|
|
||||||
|
def get_database_url() -> str | None:
|
||||||
|
"""Get database URL."""
|
||||||
|
if db_file := os.getenv("SSHECRET_BACKEND_DB"):
|
||||||
|
return f"sqlite:///{db_file}"
|
||||||
|
return config.get_main_option("sqlalchemy.url")
|
||||||
|
|
||||||
|
|
||||||
# Interpret the config file for Python logging.
|
# Interpret the config file for Python logging.
|
||||||
# This line sets up loggers basically.
|
# This line sets up loggers basically.
|
||||||
if config.config_file_name is not None:
|
if config.config_file_name is not None:
|
||||||
@ -28,8 +28,7 @@ if config.config_file_name is not None:
|
|||||||
# for 'autogenerate' support
|
# for 'autogenerate' support
|
||||||
# from myapp import mymodel
|
# from myapp import mymodel
|
||||||
# target_metadata = mymodel.Base.metadata
|
# target_metadata = mymodel.Base.metadata
|
||||||
#target_metadata = None
|
target_metadata = Base.metadata
|
||||||
target_metadata = SQLModel.metadata
|
|
||||||
|
|
||||||
# other values from the config, defined by the needs of env.py,
|
# other values from the config, defined by the needs of env.py,
|
||||||
# can be acquired:
|
# can be acquired:
|
||||||
@ -68,7 +67,11 @@ def run_migrations_online() -> None:
|
|||||||
and associate a connection with the context.
|
and associate a connection with the context.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
connectable = create_engine(get_database_url())
|
connectable = engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
with connectable.connect() as connection:
|
with connectable.connect() as connection:
|
||||||
context.configure(
|
context.configure(
|
||||||
|
|||||||
@ -9,7 +9,6 @@ from typing import Sequence, Union
|
|||||||
|
|
||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
import sqlmodel
|
|
||||||
${imports if imports else ""}
|
${imports if imports else ""}
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
|
|||||||
@ -5,13 +5,14 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from fastapi import APIRouter, Depends, Request, Query
|
from fastapi import APIRouter, Depends, Request, Query
|
||||||
from sqlmodel import Session, col, func, select
|
from pydantic import TypeAdapter
|
||||||
from sqlalchemy import desc
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from sshecret_backend.models import AuditLog
|
from sshecret_backend.models import AuditLog
|
||||||
from sshecret_backend.types import DBSessionDep
|
from sshecret_backend.types import DBSessionDep
|
||||||
from sshecret_backend.view_models import AuditInfo
|
from sshecret_backend.view_models import AuditInfo, AuditView
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -21,7 +22,7 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
"""Construct audit sub-api."""
|
"""Construct audit sub-api."""
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.get("/audit/", response_model=list[AuditLog])
|
@router.get("/audit/", response_model=list[AuditView])
|
||||||
async def get_audit_logs(
|
async def get_audit_logs(
|
||||||
request: Request,
|
request: Request,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
@ -29,35 +30,37 @@ def get_audit_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
limit: Annotated[int, Query(le=100)] = 100,
|
limit: Annotated[int, Query(le=100)] = 100,
|
||||||
filter_client: Annotated[str | None, Query()] = None,
|
filter_client: Annotated[str | None, Query()] = None,
|
||||||
filter_subsystem: Annotated[str | None, Query()] = None,
|
filter_subsystem: Annotated[str | None, Query()] = None,
|
||||||
) -> Sequence[AuditLog]:
|
) -> Sequence[AuditView]:
|
||||||
"""Get audit logs."""
|
"""Get audit logs."""
|
||||||
#audit.audit_access_audit_log(session, request)
|
#audit.audit_access_audit_log(session, request)
|
||||||
statement = select(AuditLog).offset(offset).limit(limit).order_by(desc(col(AuditLog.timestamp)))
|
statement = select(AuditLog).offset(offset).limit(limit).order_by(AuditLog.timestamp.desc())
|
||||||
if filter_client:
|
if filter_client:
|
||||||
statement = statement.where(AuditLog.client_name == filter_client)
|
statement = statement.where(AuditLog.client_name == filter_client)
|
||||||
|
|
||||||
if filter_subsystem:
|
if filter_subsystem:
|
||||||
statement = statement.where(AuditLog.subsystem == filter_subsystem)
|
statement = statement.where(AuditLog.subsystem == filter_subsystem)
|
||||||
|
|
||||||
results = session.exec(statement).all()
|
LogAdapt = TypeAdapter(list[AuditView])
|
||||||
return results
|
results = session.scalars(statement).all()
|
||||||
|
return LogAdapt.validate_python(results, from_attributes=True)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/audit/")
|
@router.post("/audit/")
|
||||||
async def add_audit_log(
|
async def add_audit_log(
|
||||||
request: Request,
|
request: Request,
|
||||||
session: Annotated[Session, Depends(get_db_session)],
|
session: Annotated[Session, Depends(get_db_session)],
|
||||||
entry: AuditLog,
|
entry: AuditView,
|
||||||
) -> AuditLog:
|
) -> AuditView:
|
||||||
"""Add entry to audit log."""
|
"""Add entry to audit log."""
|
||||||
audit_log = AuditLog.model_validate(entry.model_dump(exclude_none=True))
|
audit_log = AuditLog(**entry.model_dump(exclude_none=True))
|
||||||
session.add(audit_log)
|
session.add(audit_log)
|
||||||
session.commit()
|
session.commit()
|
||||||
return audit_log
|
return AuditView.model_validate(audit_log, from_attributes=True)
|
||||||
|
|
||||||
@router.get("/audit/info")
|
@router.get("/audit/info")
|
||||||
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
|
async def get_audit_info(request: Request, session: Annotated[Session, Depends(get_db_session)]) -> AuditInfo:
|
||||||
"""Get audit info."""
|
"""Get audit info."""
|
||||||
audit_count = session.exec(select(func.count('*')).select_from(AuditLog)).one()
|
audit_count = session.scalars(select(func.count('*')).select_from(AuditLog)).one()
|
||||||
return AuditInfo(entries=audit_count)
|
return AuditInfo(entries=audit_count)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -6,11 +6,11 @@ import uuid
|
|||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from sqlmodel import Session, col, select
|
from typing import Annotated, Any, Self, TypeVar, cast
|
||||||
from sqlalchemy import func
|
|
||||||
from typing import Annotated, Self, TypeVar
|
|
||||||
|
|
||||||
from sqlmodel.sql.expression import SelectOfScalar
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy.sql import Select
|
||||||
from sshecret_backend.types import DBSessionDep
|
from sshecret_backend.types import DBSessionDep
|
||||||
from sshecret_backend.models import Client, ClientSecret
|
from sshecret_backend.models import Client, ClientSecret
|
||||||
from sshecret_backend.view_models import (
|
from sshecret_backend.view_models import (
|
||||||
@ -55,8 +55,8 @@ T = TypeVar("T")
|
|||||||
|
|
||||||
|
|
||||||
def filter_client_statement(
|
def filter_client_statement(
|
||||||
statement: SelectOfScalar[T], params: ClientListParams, ignore_limits: bool = False
|
statement: Select[Any], params: ClientListParams, ignore_limits: bool = False
|
||||||
) -> SelectOfScalar[T]:
|
) -> Select[Any]:
|
||||||
"""Filter a statement with the provided params."""
|
"""Filter a statement with the provided params."""
|
||||||
if params.id:
|
if params.id:
|
||||||
statement = statement.where(Client.id == params.id)
|
statement = statement.where(Client.id == params.id)
|
||||||
@ -64,9 +64,9 @@ def filter_client_statement(
|
|||||||
if params.name:
|
if params.name:
|
||||||
statement = statement.where(Client.name == params.name)
|
statement = statement.where(Client.name == params.name)
|
||||||
elif params.name__like:
|
elif params.name__like:
|
||||||
statement = statement.where(col(Client.name).like(params.name__like))
|
statement = statement.where(Client.name.like(params.name__like))
|
||||||
elif params.name__contains:
|
elif params.name__contains:
|
||||||
statement = statement.where(col(Client.name).contains(params.name__contains))
|
statement = statement.where(Client.name.contains(params.name__contains))
|
||||||
|
|
||||||
if ignore_limits:
|
if ignore_limits:
|
||||||
return statement
|
return statement
|
||||||
@ -86,13 +86,13 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
"""Get clients."""
|
"""Get clients."""
|
||||||
# Get total results first
|
# Get total results first
|
||||||
count_statement = select(func.count("*")).select_from(Client)
|
count_statement = select(func.count("*")).select_from(Client)
|
||||||
count_statement = filter_client_statement(count_statement, filter_query, True)
|
count_statement = cast(Select[tuple[int]], filter_client_statement(count_statement, filter_query, True))
|
||||||
|
|
||||||
total_results = session.exec(count_statement).one()
|
total_results = session.scalars(count_statement).one()
|
||||||
|
|
||||||
statement = filter_client_statement(select(Client), filter_query, False)
|
statement = filter_client_statement(select(Client), filter_query, False)
|
||||||
|
|
||||||
results = session.exec(statement)
|
results = session.scalars(statement)
|
||||||
remainder = total_results - filter_query.offset - filter_query.limit
|
remainder = total_results - filter_query.offset - filter_query.limit
|
||||||
if remainder < 0:
|
if remainder < 0:
|
||||||
remainder = 0
|
remainder = 0
|
||||||
@ -170,13 +170,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
status_code=404, detail="Cannot find a client with the given name."
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
)
|
)
|
||||||
client.public_key = client_update.public_key
|
client.public_key = client_update.public_key
|
||||||
for secret in session.exec(
|
for secret in session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||||
).all():
|
).all():
|
||||||
LOG.debug("Invalidated secret %s", secret.id)
|
LOG.debug("Invalidated secret %s", secret.id)
|
||||||
secret.invalidated = True
|
secret.invalidated = True
|
||||||
secret.client_id = None
|
secret.client_id = None
|
||||||
secret.client = None
|
|
||||||
|
|
||||||
session.add(client)
|
session.add(client)
|
||||||
session.refresh(client)
|
session.refresh(client)
|
||||||
@ -206,13 +205,12 @@ def get_clients_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
public_key_updated = False
|
public_key_updated = False
|
||||||
if client_update.public_key != client.public_key:
|
if client_update.public_key != client.public_key:
|
||||||
public_key_updated = True
|
public_key_updated = True
|
||||||
for secret in session.exec(
|
for secret in session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||||
).all():
|
).all():
|
||||||
LOG.debug("Invalidated secret %s", secret.id)
|
LOG.debug("Invalidated secret %s", secret.id)
|
||||||
secret.invalidated = True
|
secret.invalidated = True
|
||||||
secret.client_id = None
|
secret.client_id = None
|
||||||
secret.client = None
|
|
||||||
|
|
||||||
session.add(client)
|
session.add(client)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|||||||
@ -4,7 +4,8 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from sshecret_backend.models import Client
|
from sshecret_backend.models import Client
|
||||||
|
|
||||||
@ -20,13 +21,13 @@ def verify_token(token: str, stored_hash: str) -> bool:
|
|||||||
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
async def get_client_by_name(session: Session, name: str) -> Client | None:
|
||||||
"""Get client by name."""
|
"""Get client by name."""
|
||||||
client_filter = select(Client).where(Client.name == name)
|
client_filter = select(Client).where(Client.name == name)
|
||||||
client_results = session.exec(client_filter)
|
client_results = session.scalars(client_filter)
|
||||||
return client_results.first()
|
return client_results.first()
|
||||||
|
|
||||||
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
async def get_client_by_id(session: Session, id: uuid.UUID) -> Client | None:
|
||||||
"""Get client by name."""
|
"""Get client by name."""
|
||||||
client_filter = select(Client).where(Client.id == id)
|
client_filter = select(Client).where(Client.id == id)
|
||||||
client_results = session.exec(client_filter)
|
client_results = session.scalars(client_filter)
|
||||||
return client_results.first()
|
return client_results.first()
|
||||||
|
|
||||||
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
async def get_client_by_id_or_name(session: Session, id_or_name: str) -> Client | None:
|
||||||
|
|||||||
@ -4,7 +4,8 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from sshecret_backend.models import ClientAccessPolicy
|
from sshecret_backend.models import ClientAccessPolicy
|
||||||
@ -54,7 +55,7 @@ def get_policy_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
status_code=404, detail="Cannot find a client with the given name."
|
status_code=404, detail="Cannot find a client with the given name."
|
||||||
)
|
)
|
||||||
# Remove old policies.
|
# Remove old policies.
|
||||||
policies = session.exec(
|
policies = session.scalars(
|
||||||
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
select(ClientAccessPolicy).where(ClientAccessPolicy.client_id == client.id)
|
||||||
).all()
|
).all()
|
||||||
deleted_policies: list[ClientAccessPolicy] = []
|
deleted_policies: list[ClientAccessPolicy] = []
|
||||||
|
|||||||
@ -5,7 +5,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from sshecret_backend.models import Client, ClientSecret
|
from sshecret_backend.models import Client, ClientSecret
|
||||||
@ -34,7 +35,7 @@ async def lookup_client_secret(
|
|||||||
.where(ClientSecret.client_id == client.id)
|
.where(ClientSecret.client_id == client.id)
|
||||||
.where(ClientSecret.name == name)
|
.where(ClientSecret.name == name)
|
||||||
)
|
)
|
||||||
results = session.exec(statement)
|
results = session.scalars(statement)
|
||||||
return results.first()
|
return results.first()
|
||||||
|
|
||||||
|
|
||||||
@ -165,7 +166,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
) -> list[ClientSecretList]:
|
) -> list[ClientSecretList]:
|
||||||
"""Get a list of all secrets and which clients have them."""
|
"""Get a list of all secrets and which clients have them."""
|
||||||
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
client_secret_map: defaultdict[str, list[str]] = defaultdict(list)
|
||||||
for client_secret in session.exec(select(ClientSecret)).all():
|
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||||
if not client_secret.client:
|
if not client_secret.client:
|
||||||
if client_secret.name not in client_secret_map:
|
if client_secret.name not in client_secret_map:
|
||||||
client_secret_map[client_secret.name] = []
|
client_secret_map[client_secret.name] = []
|
||||||
@ -182,7 +183,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
) -> list[ClientSecretDetailList]:
|
) -> list[ClientSecretDetailList]:
|
||||||
"""Get a list of all secrets and which clients have them."""
|
"""Get a list of all secrets and which clients have them."""
|
||||||
client_secrets: dict[str, ClientSecretDetailList] = {}
|
client_secrets: dict[str, ClientSecretDetailList] = {}
|
||||||
for client_secret in session.exec(select(ClientSecret)).all():
|
for client_secret in session.scalars(select(ClientSecret)).all():
|
||||||
|
|
||||||
if client_secret.name not in client_secrets:
|
if client_secret.name not in client_secrets:
|
||||||
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
client_secrets[client_secret.name] = ClientSecretDetailList(name=client_secret.name)
|
||||||
@ -202,7 +203,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
) -> ClientSecretList:
|
) -> ClientSecretList:
|
||||||
"""Get a list of which clients has a named secret."""
|
"""Get a list of which clients has a named secret."""
|
||||||
clients: list[str] = []
|
clients: list[str] = []
|
||||||
for client_secret in session.exec(
|
for client_secret in session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.name == name)
|
select(ClientSecret).where(ClientSecret.name == name)
|
||||||
).all():
|
).all():
|
||||||
if not client_secret.client:
|
if not client_secret.client:
|
||||||
@ -219,7 +220,7 @@ def get_secrets_api(get_db_session: DBSessionDep) -> APIRouter:
|
|||||||
) -> ClientSecretDetailList:
|
) -> ClientSecretDetailList:
|
||||||
"""Get a list of which clients has a named secret."""
|
"""Get a list of which clients has a named secret."""
|
||||||
detail_list = ClientSecretDetailList(name=name)
|
detail_list = ClientSecretDetailList(name=name)
|
||||||
for client_secret in session.exec(
|
for client_secret in session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.name == name)
|
select(ClientSecret).where(ClientSecret.name == name)
|
||||||
).all():
|
).all():
|
||||||
if not client_secret.client:
|
if not client_secret.client:
|
||||||
|
|||||||
@ -2,9 +2,10 @@
|
|||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy
|
from .models import AuditLog, Client, ClientSecret, ClientAccessPolicy, Operation, SubSystem
|
||||||
|
|
||||||
|
|
||||||
def _get_origin(request: Request) -> str | None:
|
def _get_origin(request: Request) -> str | None:
|
||||||
@ -22,7 +23,7 @@ def _write_audit_log(
|
|||||||
"""Write the audit log."""
|
"""Write the audit log."""
|
||||||
origin = _get_origin(request)
|
origin = _get_origin(request)
|
||||||
entry.origin = origin
|
entry.origin = origin
|
||||||
entry.subsystem = "backend"
|
entry.subsystem = SubSystem.BACKEND
|
||||||
session.add(entry)
|
session.add(entry)
|
||||||
if commit:
|
if commit:
|
||||||
session.commit()
|
session.commit()
|
||||||
@ -33,7 +34,7 @@ def audit_create_client(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Log the creation of a client."""
|
"""Log the creation of a client."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="CREATE",
|
operation=Operation.CREATE,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Client Created",
|
message="Client Created",
|
||||||
@ -46,7 +47,7 @@ def audit_delete_client(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Log the creation of a client."""
|
"""Log the creation of a client."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="CREATE",
|
operation=Operation.CREATE,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Client deleted",
|
message="Client deleted",
|
||||||
@ -63,9 +64,9 @@ def audit_create_secret(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit a create secret event."""
|
"""Audit a create secret event."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="CREATE",
|
operation=Operation.CREATE,
|
||||||
object="ClientSecret",
|
secret_id=secret.id,
|
||||||
object_id=str(secret.id),
|
secret_name=secret.name,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Added secret to client",
|
message="Added secret to client",
|
||||||
@ -81,13 +82,13 @@ def audit_remove_policy(
|
|||||||
commit: bool = True,
|
commit: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Audit removal of policy."""
|
"""Audit removal of policy."""
|
||||||
|
data = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="DELETE",
|
operation=Operation.DELETE,
|
||||||
object="ClientAccessPolicy",
|
|
||||||
object_id=str(policy.id),
|
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Deleted client policy",
|
message="Deleted client policy",
|
||||||
|
data=data,
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
@ -100,13 +101,13 @@ def audit_update_policy(
|
|||||||
commit: bool = True,
|
commit: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Audit update of policy."""
|
"""Audit update of policy."""
|
||||||
|
data: dict[str, str] = {"object": "ClientAccessPolicy", "object_id": str(policy.id)}
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="CREATE",
|
operation=Operation.CREATE,
|
||||||
object="ClientAccessPolicy",
|
|
||||||
object_id=str(policy.id),
|
|
||||||
client_id=client.id,
|
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
|
client_id=client.id,
|
||||||
message="Updated client policy",
|
message="Updated client policy",
|
||||||
|
data=data,
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
@ -119,11 +120,10 @@ def audit_update_client(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit an update secret event."""
|
"""Audit an update secret event."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="UPDATE",
|
operation=Operation.UPDATE,
|
||||||
object="Client",
|
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
message="Client updated",
|
message="Client data updated",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
@ -137,11 +137,11 @@ def audit_update_secret(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit an update secret event."""
|
"""Audit an update secret event."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="UPDATE",
|
operation=Operation.UPDATE,
|
||||||
object="ClientSecret",
|
|
||||||
object_id=str(secret.id),
|
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
|
secret_name=secret.name,
|
||||||
|
secret_id=secret.id,
|
||||||
message="Secret value updated",
|
message="Secret value updated",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
@ -155,8 +155,7 @@ def audit_invalidate_secrets(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit Invalidate client secrets."""
|
"""Audit Invalidate client secrets."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="INVALIDATE",
|
operation=Operation.UPDATE,
|
||||||
object="ClientSecret",
|
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
message="Client public-key changed. All secrets invalidated.",
|
message="Client public-key changed. All secrets invalidated.",
|
||||||
@ -173,9 +172,9 @@ def audit_delete_secret(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit Delete client secrets."""
|
"""Audit Delete client secrets."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="DELETE",
|
operation=Operation.DELETE,
|
||||||
object="ClientSecret",
|
secret_name=secret.name,
|
||||||
object_id=str(secret.id),
|
secret_id=secret.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
message="Deleted secret.",
|
message="Deleted secret.",
|
||||||
@ -195,7 +194,7 @@ def audit_access_secrets(
|
|||||||
With no secrets provided, all secrets of the client will be resolved.
|
With no secrets provided, all secrets of the client will be resolved.
|
||||||
"""
|
"""
|
||||||
if not secrets:
|
if not secrets:
|
||||||
secrets = session.exec(
|
secrets = session.scalars(
|
||||||
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
select(ClientSecret).where(ClientSecret.client_id == client.id)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
@ -215,37 +214,21 @@ def audit_access_secret(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Audit that someone accessed one secrets."""
|
"""Audit that someone accessed one secrets."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="ACCESS",
|
operation=Operation.READ,
|
||||||
message="Secret was viewed",
|
message="Secret was viewed",
|
||||||
object="ClientSecret",
|
secret_name=secret.name,
|
||||||
object_id=str(secret.id),
|
secret_id=secret.id,
|
||||||
client_id=client.id,
|
client_id=client.id,
|
||||||
client_name=client.name,
|
client_name=client.name,
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|
||||||
|
|
||||||
def audit_access_audit_log(
|
|
||||||
session: Session, request: Request, commit: bool = True
|
|
||||||
) -> None:
|
|
||||||
"""Audit access to the audit log.
|
|
||||||
|
|
||||||
Because why not...
|
|
||||||
"""
|
|
||||||
entry = AuditLog(
|
|
||||||
operation="ACCESS",
|
|
||||||
message="Audit log was viewed",
|
|
||||||
object="AuditLog",
|
|
||||||
)
|
|
||||||
_write_audit_log(session, request, entry, commit)
|
|
||||||
|
|
||||||
|
|
||||||
def audit_client_secret_list(
|
def audit_client_secret_list(
|
||||||
session: Session, request: Request, commit: bool = True
|
session: Session, request: Request, commit: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Audit a list of all secrets."""
|
"""Audit a list of all secrets."""
|
||||||
entry = AuditLog(
|
entry = AuditLog(
|
||||||
operation="ACCESS",
|
operation=Operation.READ,
|
||||||
message="All secret names and their clients was viewed",
|
message="All secret names and their clients was viewed",
|
||||||
)
|
)
|
||||||
_write_audit_log(session, request, entry, commit)
|
_write_audit_log(session, request, entry, commit)
|
||||||
|
|||||||
17
packages/sshecret-backend/src/sshecret_backend/auth.py
Normal file
17
packages/sshecret-backend/src/sshecret_backend/auth.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
"""Auth helpers."""
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
|
||||||
|
def hash_token(token: str) -> str:
|
||||||
|
"""Hash a token."""
|
||||||
|
pwbytes = token.encode("utf-8")
|
||||||
|
salt = bcrypt.gensalt()
|
||||||
|
hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt)
|
||||||
|
return hashed_bytes.decode()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_token(token: str, stored_hash: str) -> bool:
|
||||||
|
"""Verify token."""
|
||||||
|
token_bytes = token.encode("utf-8")
|
||||||
|
stored_bytes = stored_hash.encode("utf-8")
|
||||||
|
return bcrypt.checkpw(token_bytes, stored_bytes)
|
||||||
@ -3,11 +3,11 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException
|
from fastapi import APIRouter, Depends, Header, HTTPException
|
||||||
from sqlmodel import Session, select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
from .api import get_audit_api, get_clients_api, get_policy_api, get_secrets_api
|
||||||
|
from .auth import verify_token
|
||||||
from .models import (
|
from .models import (
|
||||||
APIClient,
|
APIClient,
|
||||||
)
|
)
|
||||||
@ -18,13 +18,6 @@ LOG = logging.getLogger(__name__)
|
|||||||
API_VERSION = "v1"
|
API_VERSION = "v1"
|
||||||
|
|
||||||
|
|
||||||
def verify_token(token: str, stored_hash: str) -> bool:
|
|
||||||
"""Verify token."""
|
|
||||||
token_bytes = token.encode("utf-8")
|
|
||||||
stored_bytes = stored_hash.encode("utf-8")
|
|
||||||
return bcrypt.checkpw(token_bytes, stored_bytes)
|
|
||||||
|
|
||||||
|
|
||||||
def get_backend_api(
|
def get_backend_api(
|
||||||
get_db_session: DBSessionDep,
|
get_db_session: DBSessionDep,
|
||||||
) -> APIRouter:
|
) -> APIRouter:
|
||||||
@ -37,7 +30,7 @@ def get_backend_api(
|
|||||||
"""Validate token."""
|
"""Validate token."""
|
||||||
LOG.debug("Validating token %s", x_api_token)
|
LOG.debug("Validating token %s", x_api_token)
|
||||||
statement = select(APIClient)
|
statement = select(APIClient)
|
||||||
results = session.exec(statement)
|
results = session.scalars(statement)
|
||||||
valid = False
|
valid = False
|
||||||
for result in results:
|
for result in results:
|
||||||
if verify_token(x_api_token, result.token):
|
if verify_token(x_api_token, result.token):
|
||||||
|
|||||||
@ -3,15 +3,24 @@
|
|||||||
import code
|
import code
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import cast
|
from typing import Literal, cast
|
||||||
from dotenv import load_dotenv
|
|
||||||
import click
|
import click
|
||||||
from sqlmodel import Session, col, func, select
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from .db import get_engine, create_api_token
|
from .db import create_api_token, get_engine, hash_token
|
||||||
|
from .models import (
|
||||||
from .models import Client, ClientSecret, ClientAccessPolicy, AuditLog, APIClient, init_db
|
APIClient,
|
||||||
|
AuditLog,
|
||||||
|
Client,
|
||||||
|
ClientAccessPolicy,
|
||||||
|
ClientSecret,
|
||||||
|
SubSystem,
|
||||||
|
init_db,
|
||||||
|
)
|
||||||
from .settings import BackendSettings
|
from .settings import BackendSettings
|
||||||
|
|
||||||
DEFAULT_LISTEN = "127.0.0.1"
|
DEFAULT_LISTEN = "127.0.0.1"
|
||||||
@ -21,22 +30,44 @@ WORKDIR = Path(os.getcwd())
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
def generate_token(settings: BackendSettings) -> str:
|
|
||||||
|
def generate_token(
|
||||||
|
settings: BackendSettings, subsystem: Literal["admin", "sshd"]
|
||||||
|
) -> str:
|
||||||
"""Generate a token."""
|
"""Generate a token."""
|
||||||
engine = get_engine(settings.db_url)
|
engine = get_engine(settings.db_url)
|
||||||
init_db(engine)
|
init_db(engine)
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
token = create_api_token(session, True)
|
token = create_api_token(session, subsystem)
|
||||||
return token
|
return token
|
||||||
|
|
||||||
def count_tokens(settings: BackendSettings) -> int:
|
|
||||||
"""Count the amount of tokens created."""
|
def add_system_tokens(settings: BackendSettings) -> None:
|
||||||
|
"""Add token for subsystems."""
|
||||||
|
if not settings.admin_token and not settings.sshd_token:
|
||||||
|
# Tokens should be generated manually.
|
||||||
|
return
|
||||||
|
|
||||||
engine = get_engine(settings.db_url)
|
engine = get_engine(settings.db_url)
|
||||||
init_db(engine)
|
init_db(engine)
|
||||||
|
tokens: list[tuple[str, SubSystem]] = []
|
||||||
|
if admin_token := settings.admin_token:
|
||||||
|
tokens.append((admin_token, SubSystem.ADMIN))
|
||||||
|
if sshd_token := settings.sshd_token:
|
||||||
|
tokens.append((sshd_token, SubSystem.SSHD))
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
count = session.exec(select(func.count("*")).select_from(APIClient)).one()
|
for token, subsystem in tokens:
|
||||||
|
hashed_token = hash_token(token)
|
||||||
|
if existing := session.scalars(
|
||||||
|
select(APIClient).where(APIClient.subsystem == subsystem)
|
||||||
|
).first():
|
||||||
|
existing.token = hashed_token
|
||||||
|
else:
|
||||||
|
new_token = APIClient(token=hashed_token, subsystem=subsystem)
|
||||||
|
session.add(new_token)
|
||||||
|
|
||||||
return count
|
session.commit()
|
||||||
|
click.echo("Generated system tokens.")
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@ -49,27 +80,30 @@ def cli(ctx: click.Context, database: str) -> None:
|
|||||||
else:
|
else:
|
||||||
settings = BackendSettings()
|
settings = BackendSettings()
|
||||||
|
|
||||||
|
add_system_tokens(settings)
|
||||||
|
|
||||||
if settings.generate_initial_tokens:
|
# if settings.generate_initial_tokens:
|
||||||
if count_tokens(settings) == 0:
|
# if count_tokens(settings) == 0:
|
||||||
click.echo("Creating initial tokens for admin and sshd.")
|
# click.echo("Creating initial tokens for admin and sshd.")
|
||||||
admin_token = generate_token(settings)
|
# admin_token = generate_token(settings)
|
||||||
sshd_token = generate_token(settings)
|
# sshd_token = generate_token(settings)
|
||||||
click.echo(f"Admin token: {admin_token}")
|
# click.echo(f"Admin token: {admin_token}")
|
||||||
click.echo(f"SSHD token: {sshd_token}")
|
# click.echo(f"SSHD token: {sshd_token}")
|
||||||
|
|
||||||
ctx.obj = settings
|
ctx.obj = settings
|
||||||
|
|
||||||
|
|
||||||
@cli.command("generate-token")
|
@cli.command("generate-token")
|
||||||
|
@click.argument("subsystem", type=click.Choice(["sshd", "admin"]))
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def cli_generate_token(ctx: click.Context) -> None:
|
def cli_generate_token(ctx: click.Context, subsystem: Literal["sshd", "admin"]) -> None:
|
||||||
"""Generate a token."""
|
"""Generate a token for a subsystem.."""
|
||||||
settings = cast(BackendSettings, ctx.obj)
|
settings = cast(BackendSettings, ctx.obj)
|
||||||
token = generate_token(settings)
|
token = generate_token(settings, subsystem)
|
||||||
click.echo("Generated api token:")
|
click.echo("Generated api token:")
|
||||||
click.echo(token)
|
click.echo(token)
|
||||||
|
|
||||||
|
|
||||||
@cli.command("run")
|
@cli.command("run")
|
||||||
@click.option("--host", default="127.0.0.1")
|
@click.option("--host", default="127.0.0.1")
|
||||||
@click.option("--port", default=8022, type=click.INT)
|
@click.option("--port", default=8022, type=click.INT)
|
||||||
@ -77,7 +111,10 @@ def cli_generate_token(ctx: click.Context) -> None:
|
|||||||
@click.option("--workers", type=click.INT)
|
@click.option("--workers", type=click.INT)
|
||||||
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
def cli_run(host: str, port: int, dev: bool, workers: int | None) -> None:
|
||||||
"""Run the server."""
|
"""Run the server."""
|
||||||
uvicorn.run("sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers)
|
uvicorn.run(
|
||||||
|
"sshecret_backend.main:app", host=host, port=port, reload=dev, workers=workers
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@cli.command("repl")
|
@cli.command("repl")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
|
|||||||
@ -2,57 +2,109 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
from collections.abc import Generator, Callable
|
from collections.abc import Generator, Callable
|
||||||
from pathlib import Path
|
from typing import Literal
|
||||||
from sqlalchemy import Engine
|
from sqlalchemy import create_engine, Engine, event, select
|
||||||
from sqlmodel import Session, create_engine, text
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||||
import bcrypt
|
|
||||||
|
from sqlalchemy.orm import sessionmaker, Session
|
||||||
|
|
||||||
from sqlalchemy.engine import URL
|
from sqlalchemy.engine import URL
|
||||||
|
|
||||||
|
from .auth import hash_token, verify_token
|
||||||
from .models import APIClient
|
from .models import APIClient, SubSystem
|
||||||
|
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def setup_database(
|
def setup_database(
|
||||||
db_url: URL | str,
|
db_url: URL,
|
||||||
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
) -> tuple[Engine, Callable[[], Generator[Session, None, None]]]:
|
||||||
"""Setup database."""
|
"""Setup database."""
|
||||||
|
|
||||||
engine = create_engine(db_url, echo=False)
|
engine = get_engine(db_url)
|
||||||
with engine.connect() as connection:
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, future=True)
|
||||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
|
||||||
|
|
||||||
def get_db_session() -> Generator[Session, None, None]:
|
def get_db_session() -> Generator[Session, None, None]:
|
||||||
"""Get DB Session."""
|
"""Get DB Session."""
|
||||||
with Session(engine) as session:
|
session = SessionLocal(bind=engine)
|
||||||
|
try:
|
||||||
yield session
|
yield session
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
return engine, get_db_session
|
return engine, get_db_session
|
||||||
|
|
||||||
|
|
||||||
def get_engine(url: URL, echo: bool = False) -> Engine:
|
def get_engine(url: URL, echo: bool = False) -> Engine:
|
||||||
"""Initialize the engine."""
|
"""Initialize the engine."""
|
||||||
engine = create_engine(url, echo=echo)
|
engine = create_engine(url, echo=echo, future=True)
|
||||||
with engine.connect() as connection:
|
if url.drivername.startswith("sqlite"):
|
||||||
connection.execute(text("PRAGMA foreign_keys=ON")) # for SQLite only
|
|
||||||
|
@event.listens_for(engine, "connect")
|
||||||
|
def set_sqlite_pragma(
|
||||||
|
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||||
|
) -> None:
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
def create_api_token(session: Session, read_write: bool) -> str:
|
def get_async_engine(url: URL, echo: bool = False) -> AsyncEngine:
|
||||||
"""Create API token."""
|
"""Get an async engine."""
|
||||||
token = secrets.token_urlsafe(32)
|
engine = create_async_engine(url, echo=echo, future=True)
|
||||||
pwbytes = token.encode("utf-8")
|
if url.drivername.startswith("sqlite+"):
|
||||||
salt = bcrypt.gensalt()
|
|
||||||
hashed_bytes = bcrypt.hashpw(password=pwbytes, salt=salt)
|
@event.listens_for(engine, "connect")
|
||||||
hashed = hashed_bytes.decode()
|
def set_sqlite_pragma(
|
||||||
|
dbapi_connection: sqlite3.Connection, _connection_record: object
|
||||||
|
) -> None:
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def create_api_token_with_value(session: Session, token: str, subsystem: Literal["admin", "sshd"]) -> None:
|
||||||
|
"""Create API token with a given value."""
|
||||||
|
|
||||||
|
existing = session.scalars(select(APIClient).where(APIClient.subsystem == SubSystem(subsystem))).first()
|
||||||
|
if existing:
|
||||||
|
if verify_token(token, existing.token):
|
||||||
|
LOG.info("Token is up to date.")
|
||||||
|
return
|
||||||
|
LOG.info("Updating token value for subsystem %s", subsystem)
|
||||||
|
hashed = hash_token(token)
|
||||||
|
existing.token=hashed
|
||||||
|
session.commit()
|
||||||
|
return
|
||||||
|
|
||||||
|
LOG.info("No existing token found. Creating new")
|
||||||
|
hashed = hash_token(token)
|
||||||
|
api_token = APIClient(token=hashed, subsystem=SubSystem(subsystem))
|
||||||
|
|
||||||
api_token = APIClient(token=hashed, read_write=read_write)
|
|
||||||
session.add(api_token)
|
session.add(api_token)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
def create_api_token(session: Session, subsystem: Literal["admin", "sshd", "test"], recreate: bool = False) -> str:
|
||||||
|
"""Create API token."""
|
||||||
|
subsys = SubSystem(subsystem)
|
||||||
|
token = secrets.token_urlsafe(32)
|
||||||
|
hashed = hash_token(token)
|
||||||
|
if existing := session.scalars(select(APIClient).where(APIClient.subsystem == subsys)).first():
|
||||||
|
if not recreate:
|
||||||
|
raise RuntimeError("Error: A token already exist for this subsystem.")
|
||||||
|
existing.token = hashed
|
||||||
|
else:
|
||||||
|
api_token = APIClient(token=hashed, subsystem=subsys)
|
||||||
|
session.add(api_token)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
return token
|
return token
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user