From 9f5cec705e0391d33ed702a2505a39f3bcf9e075 Mon Sep 17 00:00:00 2001 From: Johnny Gear Date: Thu, 13 Nov 2025 22:03:21 -0600 Subject: [PATCH] Use peewee ORM --- Pipfile | 1 + Pipfile.lock | 9 ++- db.py | 152 ------------------------------------------------- db/__init__.py | 0 db/models.py | 42 ++++++++++++++ db/queries.py | 72 +++++++++++++++++++++++ generate.py | 17 +++--- main.py | 4 ++ orders.py | 55 ++++++++++-------- scheduling.py | 8 +-- settings.py | 2 + telegram.py | 5 +- 12 files changed, 174 insertions(+), 193 deletions(-) delete mode 100644 db.py create mode 100644 db/__init__.py create mode 100644 db/models.py create mode 100644 db/queries.py diff --git a/Pipfile b/Pipfile index aa8a23b..3d5833e 100644 --- a/Pipfile +++ b/Pipfile @@ -8,6 +8,7 @@ pyyaml = "*" aiohttp = "*" scheduler = "*" pytz = "*" +peewee = "*" [dev-packages] diff --git a/Pipfile.lock b/Pipfile.lock index f555222..6e89182 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "97e23fc67642a95bffea8782f93d340d26a92a440feae5e4b7d10dedbec5cccb" + "sha256": "3189468ccffbb23866656384675ad9f79f061393b93965ce86aa2c54483e4776" }, "pipfile-spec": 6, "requires": { @@ -463,6 +463,13 @@ "markers": "python_version >= '3.9'", "version": "==6.7.0" }, + "peewee": { + "hashes": [ + "sha256:62c3d93315b1a909360c4b43c3a573b47557a1ec7a4583a71286df2a28d4b72e" + ], + "index": "pypi", + "version": "==3.18.3" + }, "propcache": { "hashes": [ "sha256:0002004213ee1f36cfb3f9a42b5066100c44276b9b72b4e1504cddd3d692e86e", diff --git a/db.py b/db.py deleted file mode 100644 index c4f9007..0000000 --- a/db.py +++ /dev/null @@ -1,152 +0,0 @@ -import sqlite3 - -from settings import SQLITE_DB - -TABLE_REPEAT = ''' - CREATE TABLE IF NOT EXISTS repeat ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - probability FLOAT NOT NULL, - orders TEXT NOT NULL, - count INTEGER DEFAULT 0 - ); -''' - -TABLE_SKIP_DAY = ''' - CREATE TABLE IF NOT EXISTS skip_day ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - date TIMESTAMP UNIQUE NOT NULL - ); -''' - -TABLE_ORDER_STATUS = ''' - CREATE TABLE IF NOT EXISTS order_status ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - mastodon_id TEXT NOT NULL, - created_at TIMESTAMP NOT NULL, - due_at TIMESTAMP NOT NULL, - text TEXT NOT NULL, - confirmed_at TIMESTAMP, - punishment_id INTEGER, - FOREIGN KEY(punishment_id) REFERENCES punishment_status(id) - ); -''' - -TABLE_PUNISHMENT_STATUS = ''' - CREATE TABLE IF NOT EXISTS punishment_status ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - mastodon_id TEXT NOT NULL, - created_at TIMESTAMP NOT NULL, - text TEXT NOT NULL, - confirmed_at TIMESTAMP - ); -''' - -class Database: - def __init__(self): - self.conn = sqlite3.connect(SQLITE_DB) - self.conn.row_factory = sqlite3.Row - self.table_init(TABLE_REPEAT) - self.table_init(TABLE_SKIP_DAY) - self.table_init(TABLE_PUNISHMENT_STATUS) - self.table_init(TABLE_ORDER_STATUS) - - def table_init(self, table_sql): - c = self.conn.cursor() - c.execute(table_sql) - self.conn.commit() - - def update(self, sql, args=[]): - c = self.conn.cursor() - c.execute(sql, args) - self.conn.commit() - return c.lastrowid - - def repeat_get(self): - c = self.conn.cursor() - sql = 'SELECT id, probability, orders, count FROM repeat LIMIT 1' - c.execute(sql) - return c.fetchone() - - def repeat_increment(self): - self.update('UPDATE repeat SET count = count + 1') - - def repeat_put(self, probability, orders): - self.update( - 'INSERT INTO repeat (probability, orders) VALUES (?, ?);', - [probability, orders] - ) - - def repeat_clear(self): - self.update('DELETE FROM repeat') - - def skip_day_put(self, date): - self.update( - 'INSERT INTO skip_day (date) VALUES (?);', - [date] - ) - - def skip_day_contains(self, date): - c = self.conn.cursor() - sql = 'SELECT * FROM skip_day WHERE date=?;' - c.execute(sql, [date]) - return c.fetchone() is not None - - def order_status_put(self, mastodon_id, created_at, due_at, text): - self.update( - ''' - INSERT INTO order_status - (mastodon_id, created_at, due_at, text) - VALUES (?, ?, ?, ?); - ''', - [ - mastodon_id, - created_at, - due_at, - text - ] - ) - - def order_status_outstanding(self): - c = self.conn.cursor() - sql = ''' - SELECT id, mastodon_id, created_at, due_at, confirmed_at - FROM order_status - WHERE confirmed_at IS NULL AND punishment_id IS NULL - ''' - c.execute(sql) - return c.fetchall() - - def order_status_confirm(self, id, confirmed_at): - self.update( - ''' - UPDATE order_status - SET confirmed_at=? - WHERE id=?; - ''', - [ - confirmed_at, - id - ] - ) - - def punishment_status_put(self, order_status_id, mastodon_id, created_at, text): - punishment_status_id = self.update( - ''' - INSERT INTO punishment_status - (mastodon_id, created_at, text) - VALUES (?, ?, ?); - ''', - [ - mastodon_id, - created_at, - text - ] - ) - self.update( - ''' - UPDATE order_status - SET punishment_id=? - WHERE id=? - ''', - [punishment_status_id, order_status_id] - ) diff --git a/db/__init__.py b/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/db/models.py b/db/models.py new file mode 100644 index 0000000..a6e921c --- /dev/null +++ b/db/models.py @@ -0,0 +1,42 @@ +from peewee import * +from settings import SQLITE_DB + +database = SqliteDatabase(SQLITE_DB) + +class BaseModel(Model): + class Meta: + database = database + +class PunishmentStatus(BaseModel): + confirmed_at = DateTimeField(null=True) # TIMESTAMP + created_at = DateTimeField() # TIMESTAMP + mastodon_id = TextField() + text = TextField() + + class Meta: + table_name = 'punishment_status' + +class OrderStatus(BaseModel): + confirmed_at = DateTimeField(null=True) # TIMESTAMP + created_at = DateTimeField() # TIMESTAMP + due_at = DateTimeField() # TIMESTAMP + mastodon_id = TextField() + punishment = ForeignKeyField(column_name='punishment_id', field='id', model=PunishmentStatus, null=True) + text = TextField() + + class Meta: + table_name = 'order_status' + +class Repeat(BaseModel): + count = IntegerField(default=0) + orders = TextField() + probability = FloatField() + + class Meta: + table_name = 'repeat' + +class SkipDay(BaseModel): + date = DateField(unique=True) + + class Meta: + table_name = 'skip_day' diff --git a/db/queries.py b/db/queries.py new file mode 100644 index 0000000..c549884 --- /dev/null +++ b/db/queries.py @@ -0,0 +1,72 @@ +from .models import database, Repeat, SkipDay, OrderStatus, PunishmentStatus + +def initdb(): + database.connect() + database.create_tables([ + Repeat, + SkipDay, + OrderStatus, + PunishmentStatus + ]) + +def repeat_get(): + try: + return Repeat.get() + except Repeat.DoesNotExist: + return None + +def repeat_increment(): + q = Repeat.update(count=Repeat.count + 1) + return q.execute() + +def repeat_put(probability, orders): + return Repeat.create( + probability=probability, + orders=orders + ) + +def repeat_clear(): + q = Repeat.delete() + q.execute() + +def skip_day_put(date): + return SkipDay.create(date=date) + +def skip_day_contains(date): + q = SkipDay.select().where(SkipDay.date == date) + return len(q) > 0 + +def order_status_put(mastodon_id, created_at, due_at, text): + return OrderStatus.create( + mastodon_id=mastodon_id, + created_at=created_at, + due_at=due_at, + text=text + ) + +def order_status_outstanding(): + return OrderStatus.select().where( + (OrderStatus.confirmed_at.is_null()) & (OrderStatus.punishment_id.is_null()) + ) + +def order_status_confirm(id, confirmed_at): + q = OrderStatus.update( + confirmed_at=confirmed_at + ).where( + OrderStatus.id == id + ) + return q.execute() + +def punishment_status_put(order_status_id, mastodon_id, created_at, text): + punishment_status = PunishmentStatus.create( + mastodon_id=mastodon_id, + created_at=created_at, + text=text + ) + + q = OrderStatus.update( + punishment_id=punishment_status.id + ).where( + OrderStatus.id == order_status_id + ) + q.execute() diff --git a/generate.py b/generate.py index a701514..e32f07e 100644 --- a/generate.py +++ b/generate.py @@ -3,7 +3,7 @@ import json import logging import random -from db import Database +from db.queries import repeat_get, repeat_increment, repeat_clear, repeat_put from settings import ORDERS_YML logger = logging.getLogger(__name__) @@ -45,17 +45,16 @@ def read_config(): def generate_order(): # Do we want to repeat? - db = Database() - repeat = db.repeat_get() + repeat = repeat_get() if repeat is not None: - if repeat['probability'] > random.random(): - db.repeat_increment() + if repeat.probability > random.random(): + repeat_increment() return { - "orders": json.loads(repeat['orders']), - "count": repeat['count'] + "orders": json.loads(repeat.orders), + "count": repeat.count } else: - db.repeat_clear() + repeat_clear() orders_config = read_config() @@ -68,7 +67,7 @@ def generate_order(): # Log the repeat if repeat_p > 0.0: - db.repeat_put(repeat_p, json.dumps(result)) + repeat_put(repeat_p, json.dumps(result)) return { "orders": result diff --git a/main.py b/main.py index e0b97ba..58c841c 100644 --- a/main.py +++ b/main.py @@ -6,6 +6,7 @@ import asyncio from scheduling import OrderScheduler from orders import order_issue, order_check from telegram import handle_commands +from db.queries import initdb logger = logging.getLogger(__name__) @@ -23,6 +24,7 @@ if __name__=='__main__': parser_immediate = subparsers.add_parser('immediate', help='Immediately generate a command') parser_check = subparsers.add_parser('check', help="Checks if any orders are outstanding") + parser_initdb = subparsers.add_parser('initdb', help="Creates the database tables") args = parser.parse_args() @@ -34,6 +36,8 @@ if __name__=='__main__': loop = asyncio.new_event_loop() loop.run_until_complete(order_check()) loop.close() + elif args.command == 'initdb': + initdb() else: loop = asyncio.new_event_loop() s = OrderScheduler(loop) diff --git a/orders.py b/orders.py index bcccc73..5585b90 100644 --- a/orders.py +++ b/orders.py @@ -3,10 +3,10 @@ import datetime from util import make_session from generate import generate_order, generate_punishment -from db import Database +from db.queries import order_status_put, punishment_status_put, order_status_outstanding, order_status_confirm from mastodon import Mastodon from telegram import Telegram -from settings import MASTODON_USERNAME, ORDER_TIMEOUT +from settings import MASTODON_USERNAME, ORDER_TIMEOUT, ENV from util import timezone logger = logging.getLogger(__name__) @@ -17,7 +17,10 @@ async def order_mastodon_post(session, orders_str, repeats, due_at): if repeats > 1: post += f"These are the same orders from the last {repeats} days\n\n" post += "Proof of compliance is due by " + due_at.strftime("%I:%M %p") + "\n\n" - post += "CC - @chicagogear @s10boi" + if ENV == 'dev': + post += "⚠️ DEV" + else: + post += "CC - @chicagogear @s10boi" m = Mastodon(session) return await m.statusPost(post) @@ -29,15 +32,19 @@ async def order_telegram_post(session, orders_str, repeats, due_at, m_url): post += f"These are the same orders from the last {repeats} days\n\n" post += "Proof of compliance is due by " + due_at.strftime("%I:%M %p") + "\n\n" post += m_url + if ENV == 'dev': + post += "\n⚠️ DEV" t = Telegram(session) await t.message_send(post) async def order_telegram_post_none(session): - t_post = "No orders for today" + post = "No orders for today" + if ENV == 'dev': + post += "\n⚠️ DEV" t = Telegram(session) - await t.message_send(t_post) + await t.message_send(post) async def order_issue(): async with make_session() as session: @@ -70,8 +77,7 @@ async def order_issue(): m_status['url'] ) - db = Database() - db.order_status_put( + order_status_put( m_status['id'], created_at, due_at, @@ -83,7 +89,10 @@ async def order_issue(): async def punishment_mastodon_post(session, punishment_str, reply_id=None): post = "@%s has failed to post proof of compliance. Here is the punishment -\n\n" % MASTODON_USERNAME post += punishment_str + "\n\n" - post += "CC - @chicagogear @s10boi" + if ENV == 'dev': + post += "⚠️ DEV" + else: + post += "CC - @chicagogear @s10boi" m = Mastodon(session) return await m.statusPost( @@ -95,18 +104,20 @@ async def punishment_telegram_post(session, punishment_str, m_url): post = "You failed to show proof of compliance. Here is your punishment -\n\n" post += punishment_str + "\n\n" post += m_url + if ENV == 'dev': + post += "\n\n⚠️ DEV" t = Telegram(session) await t.message_send(post) -async def punishment_issue(db, session, outstanding_order): +async def punishment_issue(session, outstanding_order): punishment = generate_punishment() punishment_str = "\n".join(punishment) punishment_status = await punishment_mastodon_post( session, punishment_str, - outstanding_order['mastodon_id'], + outstanding_order.mastodon_id, ) await punishment_telegram_post( @@ -115,8 +126,8 @@ async def punishment_issue(db, session, outstanding_order): punishment_status['url'] ) - db.punishment_status_put( - outstanding_order['id'], + punishment_status_put( + outstanding_order.id, punishment_status['id'], punishment_status['created_at'], punishment_str @@ -124,30 +135,28 @@ async def punishment_issue(db, session, outstanding_order): async def order_check(): async with make_session() as session: - db = Database() - - outstanding_orders = db.order_status_outstanding() + outstanding_orders = order_status_outstanding() for outstanding_order in outstanding_orders: m = Mastodon(session) - context = await m.statusContext(outstanding_order['mastodon_id']) + context = await m.statusContext(outstanding_order.mastodon_id) confirmed_at = None for d in context['descendants']: if ( - d['in_reply_to_id'] == outstanding_order['mastodon_id'] and + d['in_reply_to_id'] == outstanding_order.mastodon_id and d['account']['username'] == MASTODON_USERNAME and len(d['media_attachments']) > 0 ): confirmed_at = d['created_at'] - db.order_status_confirm(outstanding_order['id'], confirmed_at) - logger.info('Confirmed order %s' % (outstanding_order['id'])) + order_status_confirm(outstanding_order.id, confirmed_at) + logger.info('Confirmed order %s' % (outstanding_order.id)) break if confirmed_at is None: - logger.info('Order %s remains unconfirmed' % (outstanding_order['id'])) + logger.info('Order %s remains unconfirmed' % (outstanding_order.id)) - due_at = datetime.datetime.fromisoformat(outstanding_order['due_at']) + due_at = datetime.datetime.fromisoformat(outstanding_order.due_at) if(due_at < datetime.datetime.now(datetime.UTC)): - logger.info('Time to issue a punishment for %s' % outstanding_order['id']) + logger.info('Time to issue a punishment for %s' % outstanding_order.id) - await punishment_issue(db, session, outstanding_order) + await punishment_issue(session, outstanding_order) diff --git a/scheduling.py b/scheduling.py index dda7b02..a3e3ac8 100644 --- a/scheduling.py +++ b/scheduling.py @@ -6,7 +6,7 @@ from scheduler.asyncio import Scheduler from settings import TIMEZONE from orders import order_issue, order_check -from db import Database +from db.queries import order_status_outstanding, skip_day_contains from util import order_time logger = logging.getLogger(__name__) @@ -27,8 +27,7 @@ class OrderScheduler(): self.scheduler.daily(order_time_dt, self.scheduled_order) # Schedule any outstanding orders - db = Database() - outstanding_orders = db.order_status_outstanding() + outstanding_orders = order_status_outstanding() for oo in outstanding_orders: self.scheduler.once( datetime.datetime.fromisoformat(oo['due_at']) + GRACE_PERIOD, @@ -45,10 +44,9 @@ class OrderScheduler(): return # Skip stored dates - d = Database() today = datetime.datetime.now(tz=self.tz).strftime("%Y-%m-%d") logger.info('Today %s' % today) - if (d.skip_day_contains(today)): + if (skip_day_contains(today)): logger.info('Today is a skip day') return diff --git a/settings.py b/settings.py index ad4460d..f3bd9e9 100644 --- a/settings.py +++ b/settings.py @@ -1,6 +1,8 @@ import os import datetime +ENV = os.environ.get('ENV', 'dev') + ORDER_TIME = os.environ.get('ORDER_TIME', '9:00') ORDER_TIMEOUT = datetime.timedelta( hours=os.environ.get('ORDER_TIMEOUT', 3) diff --git a/telegram.py b/telegram.py index c1883d7..b93b51e 100644 --- a/telegram.py +++ b/telegram.py @@ -3,7 +3,7 @@ import logging import asyncio from settings import TELEGRAM_API_TOKEN, TELEGRAM_CHAT_ID -from db import Database +from db.queries import skip_day_put from util import make_session logger = logging.getLogger(__name__) @@ -75,8 +75,7 @@ class SkipDayAddCommand(TelegramCommand): async def exec(self, text, update, session): date_str = text.split(' ')[1] - db = Database() - db.skip_day_put(date_str) + skip_day_put(date_str) yield f"Added skip day {date_str}"