This commit is contained in:
1
admidio/__init__.py
Normal file
1
admidio/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
147
admidio/db.py
Normal file
147
admidio/db.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""Admidio db functions.
|
||||
|
||||
Here's some queries that might be relevant:
|
||||
|
||||
Special fields
|
||||
SELECT usf_id FROM adm_user_fields where usf_name =
|
||||
- PMB_PAID : Date paid
|
||||
- PMB_FEE : The amount to pay
|
||||
- PMB_DUEDATE - the due date
|
||||
|
||||
|
||||
Get user:
|
||||
'SELECT usr_login_name FROM adm_users WHERE usr_id = '
|
||||
Get Payment Date:
|
||||
|
||||
'SELECT usd_value FROM adm_user_data WHERE usd_usr_id = " + body.id + " AND usd_usf_id = 22
|
||||
Get Payment Amount:
|
||||
'SELECT usd_value FROM adm_user_data WHERE usd_usr_id = " + body.id + " AND usd_usf_id = 23'
|
||||
paymentAmount[0].usd_value
|
||||
|
||||
Success:
|
||||
'INSERT INTO adm_user_data (usd_usr_id, usd_usf_id, usd_value) VALUES (' + user_id + ', 23, "' + date + '") ON DUPLICATE KEY UPDATE usd_value = "' + date + '";'
|
||||
|
||||
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
import mariadb
|
||||
from typing import Final, Iterator
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdmidioDB:
|
||||
"""Low-level admidio class."""
|
||||
|
||||
def __init__(
|
||||
self, user: str, password: str, host: str, database: str, port: int = 3306
|
||||
) -> None:
|
||||
"""Initialize database class."""
|
||||
|
||||
def connect() -> mariadb.Connection:
|
||||
connection: mariadb.Connection = mariadb.connect(
|
||||
user=user,
|
||||
password=password,
|
||||
host=host,
|
||||
database=database,
|
||||
port=port,
|
||||
)
|
||||
return connection
|
||||
|
||||
self._connect: Final = connect
|
||||
|
||||
# self.cursor: mariadb.Cursor = connection.cursor()
|
||||
|
||||
@contextmanager
|
||||
def query(self, commit: bool = False) -> Iterator[mariadb.Cursor]:
|
||||
"""Query in context."""
|
||||
connection = self._connect()
|
||||
cursor = connection.cursor()
|
||||
yield cursor
|
||||
if commit:
|
||||
connection.commit()
|
||||
connection.close()
|
||||
|
||||
def get_custom_fields(self) -> dict[str, int]:
|
||||
"""Get a mapping of custom fields."""
|
||||
query = "SELECT usf_name, usf_id FROM adm_user_fields"
|
||||
results: dict[str, int] = {}
|
||||
with self.query() as cursor:
|
||||
cursor.execute(query)
|
||||
for row in cursor:
|
||||
usf_name, usf_id = row
|
||||
results[str(usf_name)] = int(usf_id)
|
||||
|
||||
return results
|
||||
|
||||
def get_user_id_by_field(self, field: int, value: str) -> int | None:
|
||||
"""Get user ID by lookup of custom field."""
|
||||
# SELECT usd_value, usd_usr_id from adm_user_data WHERE usd_usf_id = '11'
|
||||
query = "SELECT usd_usr_id FROM adm_user_data WHERE usd_usf_id = ? AND usd_value = ? LIMIT 1"
|
||||
with self.query() as cursor:
|
||||
cursor.execute(query, (field, value))
|
||||
result = cursor.fetchone()
|
||||
|
||||
if result:
|
||||
return int(result[0])
|
||||
|
||||
return None
|
||||
|
||||
def get_adm_user_data(self, user_id: int) -> dict[int, str]:
|
||||
"""Get adm user data.
|
||||
|
||||
Returns a list of tuples of field ID and value.
|
||||
"""
|
||||
query = f"SELECT usd_usf_id, usd_value FROM adm_user_data WHERE usd_usr_id = ?"
|
||||
with self.query() as cursor:
|
||||
cursor.execute(query, (user_id,))
|
||||
results = {int(usf_id): str(usd_value) for usf_id, usd_value in cursor}
|
||||
return results
|
||||
|
||||
def create_user_data(self, user_id: int, usf_id: int, usd_value: str) -> None:
|
||||
"""Create or update user data."""
|
||||
query = (
|
||||
"INSERT INTO adm_user_data (usd_usr_id, usd_usf_id, usd_value) VALUES (?, ?, ?)"
|
||||
" ON DUPLICATE KEY UPDATE usd_value = ?"
|
||||
)
|
||||
parameters = (user_id, usf_id, usd_value, usd_value)
|
||||
with self.query(True) as cursor:
|
||||
cursor.execute(query, parameters)
|
||||
LOG.debug(
|
||||
"Ran query %s (%r), last insert ID: %s",
|
||||
query,
|
||||
parameters,
|
||||
cursor.lastrowid,
|
||||
)
|
||||
|
||||
def get_roles(self) -> dict[int, str]:
|
||||
"""Get roles.
|
||||
|
||||
dict int for role ID, str for role name.
|
||||
"""
|
||||
query = "SELECT rol_id, rol_name FROM adm_roles"
|
||||
with self.query() as cursor:
|
||||
cursor.execute(query)
|
||||
result = {int(rol_id): str(rol_name) for (rol_id, rol_name) in cursor}
|
||||
return result
|
||||
|
||||
def get_user_roles(self, user_id: int) -> list[int]:
|
||||
"""Get Role ID for a user"""
|
||||
query = "SELECT mem_rol_id FROM adm_members WHERE mem_usr_id ?"
|
||||
with self.query() as cursor:
|
||||
cursor.execute(query, (user_id,))
|
||||
result = [int(mem_rol_id) for mem_rol_id in cursor]
|
||||
return result
|
||||
|
||||
def get_role_payments(self) -> dict[int, int]:
|
||||
"""Get payment fees for all roles."""
|
||||
query = "SELECT rol_id, rol_cost FROM adm_roles"
|
||||
role_fees: dict[int, int] = {}
|
||||
with self.query() as cursor:
|
||||
cursor.execute(query)
|
||||
for role_id, role_cost in cursor:
|
||||
if role_cost is not None:
|
||||
role_fees[int(role_id)] = role_cost
|
||||
return role_fees
|
||||
161
admidio/model.py
Normal file
161
admidio/model.py
Normal file
@ -0,0 +1,161 @@
|
||||
"""High-levle api."""
|
||||
|
||||
import logging
|
||||
from datetime import date
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
|
||||
from typing import Final
|
||||
from .db import AdmidioDB
|
||||
|
||||
|
||||
MEMBER_ROLE = "Member"
|
||||
FIELD_PAID_DATE = "PMB_PAID"
|
||||
FIELD_FEE = "PMB_FEE"
|
||||
FIELD_DUEDATE = "PMB_DUEDATE"
|
||||
FIELD_EMAIL = "SYS_EMAIL"
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdmidioDBSettings:
|
||||
"""Database settings class."""
|
||||
|
||||
user: str
|
||||
password: str
|
||||
host: str
|
||||
database: str
|
||||
port: int = 3306
|
||||
|
||||
|
||||
class AdmidioMemberFee:
|
||||
"""Admidio member fee manager."""
|
||||
|
||||
def __init__(self, user_id: int, connection: AdmidioDBSettings) -> None:
|
||||
"""Create member fee manager class."""
|
||||
self.db: AdmidioDB = AdmidioDB(
|
||||
connection.user,
|
||||
connection.password,
|
||||
connection.host,
|
||||
connection.database,
|
||||
connection.port,
|
||||
)
|
||||
self.user_id: Final = user_id
|
||||
self._user_data: dict[str, str] | None = None
|
||||
|
||||
@property
|
||||
def is_member(self) -> bool:
|
||||
"""Check if user is a member."""
|
||||
role_ids = self.db.get_user_roles(self.user_id)
|
||||
role_map = self.db.get_roles()
|
||||
roles = [role_map[role_id] for role_id in role_ids]
|
||||
return MEMBER_ROLE in roles
|
||||
|
||||
@property
|
||||
def member_fee(self) -> int:
|
||||
"""Get membership fee."""
|
||||
roles = self.db.get_roles()
|
||||
member_role_id = next(
|
||||
iter(
|
||||
[
|
||||
role_id
|
||||
for role_id, role_name in roles.items()
|
||||
if role_name == MEMBER_ROLE
|
||||
]
|
||||
)
|
||||
)
|
||||
role_payments = self.db.get_role_payments()
|
||||
fee = role_payments.get(member_role_id)
|
||||
if not fee:
|
||||
raise RuntimeError("Error: No membership cost set on Member role.")
|
||||
return int(fee)
|
||||
|
||||
@property
|
||||
def user_data(self) -> dict[str, str]:
|
||||
"""Get user data."""
|
||||
if not self._user_data:
|
||||
self._user_data = self._get_user_data()
|
||||
return self._user_data
|
||||
|
||||
def _get_user_data(self) -> dict[str, str]:
|
||||
"""Get user data."""
|
||||
user_fields = self.db.get_custom_fields()
|
||||
user_field_ids = {value: key for key, value in user_fields.items()}
|
||||
adm_user_data = self.db.get_adm_user_data(self.user_id)
|
||||
user_data: dict[str, str] = {}
|
||||
for field_id, data in adm_user_data.items():
|
||||
if field_name := user_field_ids.get(field_id):
|
||||
user_data[field_name] = data
|
||||
else:
|
||||
LOG.warning("Could not find a field definition for %s", field_id)
|
||||
return user_data
|
||||
|
||||
def _lookup_user_field(self, name: str) -> int:
|
||||
"""Lookup the ID of a user field."""
|
||||
user_fields = self.db.get_custom_fields()
|
||||
field_id = user_fields.get(name)
|
||||
if not field_id:
|
||||
raise RuntimeError(f"Unable to find field ID for field {name}")
|
||||
return field_id
|
||||
|
||||
@property
|
||||
def last_paid(self) -> date | None:
|
||||
"""Get the date of last payment."""
|
||||
paid_data = self.user_data.get(FIELD_PAID_DATE)
|
||||
if not paid_data:
|
||||
return None
|
||||
paid_date = date.fromisoformat(paid_data)
|
||||
return paid_date
|
||||
|
||||
@property
|
||||
def has_paid(self) -> bool:
|
||||
"""Check if a user has paid."""
|
||||
if not self.last_paid:
|
||||
return False
|
||||
paid_days = date.today() - self.last_paid
|
||||
if paid_days.days > 365:
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def amount_due(self) -> int:
|
||||
"""Get amount due."""
|
||||
if fee := self.user_data.get(FIELD_FEE):
|
||||
return int(fee)
|
||||
if not self.is_member:
|
||||
raise RuntimeError(
|
||||
"User does not seem to be a member, and has no due fees."
|
||||
)
|
||||
return self.member_fee
|
||||
|
||||
def register_payment(self, amount_paid: int | None = None) -> None:
|
||||
"""Register payment,"""
|
||||
date_field = self._lookup_user_field(FIELD_PAID_DATE)
|
||||
amount_field = self._lookup_user_field(FIELD_FEE)
|
||||
date_paid = date.today().isoformat()
|
||||
if not amount_paid:
|
||||
amount_paid = self.member_fee
|
||||
self.db.create_user_data(self.user_id, date_field, date_paid)
|
||||
self.db.create_user_data(self.user_id, amount_field, str(amount_paid))
|
||||
|
||||
@classmethod
|
||||
def lookup_email(
|
||||
cls, email: str, settings: AdmidioDBSettings
|
||||
) -> "AdmidioMemberFee | None":
|
||||
"""Create instance by lookup up user email address."""
|
||||
db = AdmidioDB(
|
||||
settings.user,
|
||||
settings.password,
|
||||
settings.host,
|
||||
settings.database,
|
||||
settings.port,
|
||||
)
|
||||
fields = db.get_custom_fields()
|
||||
email_field_id = fields.get(FIELD_EMAIL)
|
||||
if not email_field_id:
|
||||
raise RuntimeError("Could not resolve email address field in database.")
|
||||
if user_id := db.get_user_id_by_field(email_field_id, email):
|
||||
return cls(user_id, settings)
|
||||
|
||||
return None
|
||||
42
admidio/settings.py
Normal file
42
admidio/settings.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""Get settings."""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
|
||||
from .model import AdmidioDBSettings
|
||||
load_dotenv(find_dotenv())
|
||||
|
||||
|
||||
def get_envvar(name: str) -> str:
|
||||
"""Get environemnt variable."""
|
||||
file_var = f"{name}_FILE"
|
||||
if file_name := os.getenv(file_var):
|
||||
with open(file_name, "r") as f:
|
||||
return f.read().strip()
|
||||
if value := os.getenv(name):
|
||||
return value
|
||||
raise RuntimeError(f"Unable to read environment variable {name}")
|
||||
|
||||
def get_db_settings() -> AdmidioDBSettings:
|
||||
"""Get database settings."""
|
||||
hostname = get_envvar("ADMIDIO_DB_HOST")
|
||||
user = get_envvar("ADMIDIO_DB_USER")
|
||||
password = get_envvar("ADMIDIO_DB_PASSWORD")
|
||||
database = get_envvar("ADMIDIO_DB_NAME")
|
||||
if ":" in hostname:
|
||||
host, port = hostname.split(":")
|
||||
else:
|
||||
host = hostname
|
||||
port = 3306
|
||||
return AdmidioDBSettings(user, password, host, database, int(port))
|
||||
|
||||
def get_stripe_key() -> str:
|
||||
"""Get the stripe key."""
|
||||
return get_envvar("STRIPE_KEY")
|
||||
|
||||
def get_stripe_price() -> str:
|
||||
return get_envvar("STRIPE_PRICE_ID")
|
||||
|
||||
def get_domain() -> str:
|
||||
"""Get domain."""
|
||||
return get_envvar("DKNOG_APP_HOST")
|
||||
1
admidio/stripe.py
Normal file
1
admidio/stripe.py
Normal file
@ -0,0 +1 @@
|
||||
"""Stripe methods."""
|
||||
Reference in New Issue
Block a user