Use peewee ORM

This commit is contained in:
Johnny Gear 2025-11-13 22:03:21 -06:00
parent 38ea6c9bbd
commit 9f5cec705e
12 changed files with 174 additions and 193 deletions

View file

@ -8,6 +8,7 @@ pyyaml = "*"
aiohttp = "*" aiohttp = "*"
scheduler = "*" scheduler = "*"
pytz = "*" pytz = "*"
peewee = "*"
[dev-packages] [dev-packages]

9
Pipfile.lock generated
View file

@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "97e23fc67642a95bffea8782f93d340d26a92a440feae5e4b7d10dedbec5cccb" "sha256": "3189468ccffbb23866656384675ad9f79f061393b93965ce86aa2c54483e4776"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@ -463,6 +463,13 @@
"markers": "python_version >= '3.9'", "markers": "python_version >= '3.9'",
"version": "==6.7.0" "version": "==6.7.0"
}, },
"peewee": {
"hashes": [
"sha256:62c3d93315b1a909360c4b43c3a573b47557a1ec7a4583a71286df2a28d4b72e"
],
"index": "pypi",
"version": "==3.18.3"
},
"propcache": { "propcache": {
"hashes": [ "hashes": [
"sha256:0002004213ee1f36cfb3f9a42b5066100c44276b9b72b4e1504cddd3d692e86e", "sha256:0002004213ee1f36cfb3f9a42b5066100c44276b9b72b4e1504cddd3d692e86e",

152
db.py
View file

@ -1,152 +0,0 @@
import sqlite3
from settings import SQLITE_DB
TABLE_REPEAT = '''
CREATE TABLE IF NOT EXISTS repeat (
id INTEGER PRIMARY KEY AUTOINCREMENT,
probability FLOAT NOT NULL,
orders TEXT NOT NULL,
count INTEGER DEFAULT 0
);
'''
TABLE_SKIP_DAY = '''
CREATE TABLE IF NOT EXISTS skip_day (
id INTEGER PRIMARY KEY AUTOINCREMENT,
date TIMESTAMP UNIQUE NOT NULL
);
'''
TABLE_ORDER_STATUS = '''
CREATE TABLE IF NOT EXISTS order_status (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mastodon_id TEXT NOT NULL,
created_at TIMESTAMP NOT NULL,
due_at TIMESTAMP NOT NULL,
text TEXT NOT NULL,
confirmed_at TIMESTAMP,
punishment_id INTEGER,
FOREIGN KEY(punishment_id) REFERENCES punishment_status(id)
);
'''
TABLE_PUNISHMENT_STATUS = '''
CREATE TABLE IF NOT EXISTS punishment_status (
id INTEGER PRIMARY KEY AUTOINCREMENT,
mastodon_id TEXT NOT NULL,
created_at TIMESTAMP NOT NULL,
text TEXT NOT NULL,
confirmed_at TIMESTAMP
);
'''
class Database:
def __init__(self):
self.conn = sqlite3.connect(SQLITE_DB)
self.conn.row_factory = sqlite3.Row
self.table_init(TABLE_REPEAT)
self.table_init(TABLE_SKIP_DAY)
self.table_init(TABLE_PUNISHMENT_STATUS)
self.table_init(TABLE_ORDER_STATUS)
def table_init(self, table_sql):
c = self.conn.cursor()
c.execute(table_sql)
self.conn.commit()
def update(self, sql, args=[]):
c = self.conn.cursor()
c.execute(sql, args)
self.conn.commit()
return c.lastrowid
def repeat_get(self):
c = self.conn.cursor()
sql = 'SELECT id, probability, orders, count FROM repeat LIMIT 1'
c.execute(sql)
return c.fetchone()
def repeat_increment(self):
self.update('UPDATE repeat SET count = count + 1')
def repeat_put(self, probability, orders):
self.update(
'INSERT INTO repeat (probability, orders) VALUES (?, ?);',
[probability, orders]
)
def repeat_clear(self):
self.update('DELETE FROM repeat')
def skip_day_put(self, date):
self.update(
'INSERT INTO skip_day (date) VALUES (?);',
[date]
)
def skip_day_contains(self, date):
c = self.conn.cursor()
sql = 'SELECT * FROM skip_day WHERE date=?;'
c.execute(sql, [date])
return c.fetchone() is not None
def order_status_put(self, mastodon_id, created_at, due_at, text):
self.update(
'''
INSERT INTO order_status
(mastodon_id, created_at, due_at, text)
VALUES (?, ?, ?, ?);
''',
[
mastodon_id,
created_at,
due_at,
text
]
)
def order_status_outstanding(self):
c = self.conn.cursor()
sql = '''
SELECT id, mastodon_id, created_at, due_at, confirmed_at
FROM order_status
WHERE confirmed_at IS NULL AND punishment_id IS NULL
'''
c.execute(sql)
return c.fetchall()
def order_status_confirm(self, id, confirmed_at):
self.update(
'''
UPDATE order_status
SET confirmed_at=?
WHERE id=?;
''',
[
confirmed_at,
id
]
)
def punishment_status_put(self, order_status_id, mastodon_id, created_at, text):
punishment_status_id = self.update(
'''
INSERT INTO punishment_status
(mastodon_id, created_at, text)
VALUES (?, ?, ?);
''',
[
mastodon_id,
created_at,
text
]
)
self.update(
'''
UPDATE order_status
SET punishment_id=?
WHERE id=?
''',
[punishment_status_id, order_status_id]
)

0
db/__init__.py Normal file
View file

42
db/models.py Normal file
View file

@ -0,0 +1,42 @@
from peewee import *
from settings import SQLITE_DB
database = SqliteDatabase(SQLITE_DB)
class BaseModel(Model):
class Meta:
database = database
class PunishmentStatus(BaseModel):
confirmed_at = DateTimeField(null=True) # TIMESTAMP
created_at = DateTimeField() # TIMESTAMP
mastodon_id = TextField()
text = TextField()
class Meta:
table_name = 'punishment_status'
class OrderStatus(BaseModel):
confirmed_at = DateTimeField(null=True) # TIMESTAMP
created_at = DateTimeField() # TIMESTAMP
due_at = DateTimeField() # TIMESTAMP
mastodon_id = TextField()
punishment = ForeignKeyField(column_name='punishment_id', field='id', model=PunishmentStatus, null=True)
text = TextField()
class Meta:
table_name = 'order_status'
class Repeat(BaseModel):
count = IntegerField(default=0)
orders = TextField()
probability = FloatField()
class Meta:
table_name = 'repeat'
class SkipDay(BaseModel):
date = DateField(unique=True)
class Meta:
table_name = 'skip_day'

72
db/queries.py Normal file
View file

@ -0,0 +1,72 @@
from .models import database, Repeat, SkipDay, OrderStatus, PunishmentStatus
def initdb():
database.connect()
database.create_tables([
Repeat,
SkipDay,
OrderStatus,
PunishmentStatus
])
def repeat_get():
try:
return Repeat.get()
except Repeat.DoesNotExist:
return None
def repeat_increment():
q = Repeat.update(count=Repeat.count + 1)
return q.execute()
def repeat_put(probability, orders):
return Repeat.create(
probability=probability,
orders=orders
)
def repeat_clear():
q = Repeat.delete()
q.execute()
def skip_day_put(date):
return SkipDay.create(date=date)
def skip_day_contains(date):
q = SkipDay.select().where(SkipDay.date == date)
return len(q) > 0
def order_status_put(mastodon_id, created_at, due_at, text):
return OrderStatus.create(
mastodon_id=mastodon_id,
created_at=created_at,
due_at=due_at,
text=text
)
def order_status_outstanding():
return OrderStatus.select().where(
(OrderStatus.confirmed_at.is_null()) & (OrderStatus.punishment_id.is_null())
)
def order_status_confirm(id, confirmed_at):
q = OrderStatus.update(
confirmed_at=confirmed_at
).where(
OrderStatus.id == id
)
return q.execute()
def punishment_status_put(order_status_id, mastodon_id, created_at, text):
punishment_status = PunishmentStatus.create(
mastodon_id=mastodon_id,
created_at=created_at,
text=text
)
q = OrderStatus.update(
punishment_id=punishment_status.id
).where(
OrderStatus.id == order_status_id
)
q.execute()

View file

@ -3,7 +3,7 @@ import json
import logging import logging
import random import random
from db import Database from db.queries import repeat_get, repeat_increment, repeat_clear, repeat_put
from settings import ORDERS_YML from settings import ORDERS_YML
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -45,17 +45,16 @@ def read_config():
def generate_order(): def generate_order():
# Do we want to repeat? # Do we want to repeat?
db = Database() repeat = repeat_get()
repeat = db.repeat_get()
if repeat is not None: if repeat is not None:
if repeat['probability'] > random.random(): if repeat.probability > random.random():
db.repeat_increment() repeat_increment()
return { return {
"orders": json.loads(repeat['orders']), "orders": json.loads(repeat.orders),
"count": repeat['count'] "count": repeat.count
} }
else: else:
db.repeat_clear() repeat_clear()
orders_config = read_config() orders_config = read_config()
@ -68,7 +67,7 @@ def generate_order():
# Log the repeat # Log the repeat
if repeat_p > 0.0: if repeat_p > 0.0:
db.repeat_put(repeat_p, json.dumps(result)) repeat_put(repeat_p, json.dumps(result))
return { return {
"orders": result "orders": result

View file

@ -6,6 +6,7 @@ import asyncio
from scheduling import OrderScheduler from scheduling import OrderScheduler
from orders import order_issue, order_check from orders import order_issue, order_check
from telegram import handle_commands from telegram import handle_commands
from db.queries import initdb
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -23,6 +24,7 @@ if __name__=='__main__':
parser_immediate = subparsers.add_parser('immediate', help='Immediately generate a command') parser_immediate = subparsers.add_parser('immediate', help='Immediately generate a command')
parser_check = subparsers.add_parser('check', help="Checks if any orders are outstanding") parser_check = subparsers.add_parser('check', help="Checks if any orders are outstanding")
parser_initdb = subparsers.add_parser('initdb', help="Creates the database tables")
args = parser.parse_args() args = parser.parse_args()
@ -34,6 +36,8 @@ if __name__=='__main__':
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
loop.run_until_complete(order_check()) loop.run_until_complete(order_check())
loop.close() loop.close()
elif args.command == 'initdb':
initdb()
else: else:
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
s = OrderScheduler(loop) s = OrderScheduler(loop)

View file

@ -3,10 +3,10 @@ import datetime
from util import make_session from util import make_session
from generate import generate_order, generate_punishment from generate import generate_order, generate_punishment
from db import Database from db.queries import order_status_put, punishment_status_put, order_status_outstanding, order_status_confirm
from mastodon import Mastodon from mastodon import Mastodon
from telegram import Telegram from telegram import Telegram
from settings import MASTODON_USERNAME, ORDER_TIMEOUT from settings import MASTODON_USERNAME, ORDER_TIMEOUT, ENV
from util import timezone from util import timezone
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -17,6 +17,9 @@ async def order_mastodon_post(session, orders_str, repeats, due_at):
if repeats > 1: if repeats > 1:
post += f"These are the same orders from the last {repeats} days\n\n" post += f"These are the same orders from the last {repeats} days\n\n"
post += "Proof of compliance is due by " + due_at.strftime("%I:%M %p") + "\n\n" post += "Proof of compliance is due by " + due_at.strftime("%I:%M %p") + "\n\n"
if ENV == 'dev':
post += "⚠️ DEV"
else:
post += "CC - @chicagogear @s10boi" post += "CC - @chicagogear @s10boi"
m = Mastodon(session) m = Mastodon(session)
@ -29,15 +32,19 @@ async def order_telegram_post(session, orders_str, repeats, due_at, m_url):
post += f"These are the same orders from the last {repeats} days\n\n" post += f"These are the same orders from the last {repeats} days\n\n"
post += "Proof of compliance is due by " + due_at.strftime("%I:%M %p") + "\n\n" post += "Proof of compliance is due by " + due_at.strftime("%I:%M %p") + "\n\n"
post += m_url post += m_url
if ENV == 'dev':
post += "\n⚠️ DEV"
t = Telegram(session) t = Telegram(session)
await t.message_send(post) await t.message_send(post)
async def order_telegram_post_none(session): async def order_telegram_post_none(session):
t_post = "No orders for today" post = "No orders for today"
if ENV == 'dev':
post += "\n⚠️ DEV"
t = Telegram(session) t = Telegram(session)
await t.message_send(t_post) await t.message_send(post)
async def order_issue(): async def order_issue():
async with make_session() as session: async with make_session() as session:
@ -70,8 +77,7 @@ async def order_issue():
m_status['url'] m_status['url']
) )
db = Database() order_status_put(
db.order_status_put(
m_status['id'], m_status['id'],
created_at, created_at,
due_at, due_at,
@ -83,6 +89,9 @@ async def order_issue():
async def punishment_mastodon_post(session, punishment_str, reply_id=None): async def punishment_mastodon_post(session, punishment_str, reply_id=None):
post = "@%s has failed to post proof of compliance. Here is the punishment -\n\n" % MASTODON_USERNAME post = "@%s has failed to post proof of compliance. Here is the punishment -\n\n" % MASTODON_USERNAME
post += punishment_str + "\n\n" post += punishment_str + "\n\n"
if ENV == 'dev':
post += "⚠️ DEV"
else:
post += "CC - @chicagogear @s10boi" post += "CC - @chicagogear @s10boi"
m = Mastodon(session) m = Mastodon(session)
@ -95,18 +104,20 @@ async def punishment_telegram_post(session, punishment_str, m_url):
post = "You failed to show proof of compliance. Here is your punishment -\n\n" post = "You failed to show proof of compliance. Here is your punishment -\n\n"
post += punishment_str + "\n\n" post += punishment_str + "\n\n"
post += m_url post += m_url
if ENV == 'dev':
post += "\n\n⚠️ DEV"
t = Telegram(session) t = Telegram(session)
await t.message_send(post) await t.message_send(post)
async def punishment_issue(db, session, outstanding_order): async def punishment_issue(session, outstanding_order):
punishment = generate_punishment() punishment = generate_punishment()
punishment_str = "\n".join(punishment) punishment_str = "\n".join(punishment)
punishment_status = await punishment_mastodon_post( punishment_status = await punishment_mastodon_post(
session, session,
punishment_str, punishment_str,
outstanding_order['mastodon_id'], outstanding_order.mastodon_id,
) )
await punishment_telegram_post( await punishment_telegram_post(
@ -115,8 +126,8 @@ async def punishment_issue(db, session, outstanding_order):
punishment_status['url'] punishment_status['url']
) )
db.punishment_status_put( punishment_status_put(
outstanding_order['id'], outstanding_order.id,
punishment_status['id'], punishment_status['id'],
punishment_status['created_at'], punishment_status['created_at'],
punishment_str punishment_str
@ -124,30 +135,28 @@ async def punishment_issue(db, session, outstanding_order):
async def order_check(): async def order_check():
async with make_session() as session: async with make_session() as session:
db = Database() outstanding_orders = order_status_outstanding()
outstanding_orders = db.order_status_outstanding()
for outstanding_order in outstanding_orders: for outstanding_order in outstanding_orders:
m = Mastodon(session) m = Mastodon(session)
context = await m.statusContext(outstanding_order['mastodon_id']) context = await m.statusContext(outstanding_order.mastodon_id)
confirmed_at = None confirmed_at = None
for d in context['descendants']: for d in context['descendants']:
if ( if (
d['in_reply_to_id'] == outstanding_order['mastodon_id'] and d['in_reply_to_id'] == outstanding_order.mastodon_id and
d['account']['username'] == MASTODON_USERNAME and d['account']['username'] == MASTODON_USERNAME and
len(d['media_attachments']) > 0 len(d['media_attachments']) > 0
): ):
confirmed_at = d['created_at'] confirmed_at = d['created_at']
db.order_status_confirm(outstanding_order['id'], confirmed_at) order_status_confirm(outstanding_order.id, confirmed_at)
logger.info('Confirmed order %s' % (outstanding_order['id'])) logger.info('Confirmed order %s' % (outstanding_order.id))
break break
if confirmed_at is None: if confirmed_at is None:
logger.info('Order %s remains unconfirmed' % (outstanding_order['id'])) logger.info('Order %s remains unconfirmed' % (outstanding_order.id))
due_at = datetime.datetime.fromisoformat(outstanding_order['due_at']) due_at = datetime.datetime.fromisoformat(outstanding_order.due_at)
if(due_at < datetime.datetime.now(datetime.UTC)): if(due_at < datetime.datetime.now(datetime.UTC)):
logger.info('Time to issue a punishment for %s' % outstanding_order['id']) logger.info('Time to issue a punishment for %s' % outstanding_order.id)
await punishment_issue(db, session, outstanding_order) await punishment_issue(session, outstanding_order)

View file

@ -6,7 +6,7 @@ from scheduler.asyncio import Scheduler
from settings import TIMEZONE from settings import TIMEZONE
from orders import order_issue, order_check from orders import order_issue, order_check
from db import Database from db.queries import order_status_outstanding, skip_day_contains
from util import order_time from util import order_time
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,8 +27,7 @@ class OrderScheduler():
self.scheduler.daily(order_time_dt, self.scheduled_order) self.scheduler.daily(order_time_dt, self.scheduled_order)
# Schedule any outstanding orders # Schedule any outstanding orders
db = Database() outstanding_orders = order_status_outstanding()
outstanding_orders = db.order_status_outstanding()
for oo in outstanding_orders: for oo in outstanding_orders:
self.scheduler.once( self.scheduler.once(
datetime.datetime.fromisoformat(oo['due_at']) + GRACE_PERIOD, datetime.datetime.fromisoformat(oo['due_at']) + GRACE_PERIOD,
@ -45,10 +44,9 @@ class OrderScheduler():
return return
# Skip stored dates # Skip stored dates
d = Database()
today = datetime.datetime.now(tz=self.tz).strftime("%Y-%m-%d") today = datetime.datetime.now(tz=self.tz).strftime("%Y-%m-%d")
logger.info('Today %s' % today) logger.info('Today %s' % today)
if (d.skip_day_contains(today)): if (skip_day_contains(today)):
logger.info('Today is a skip day') logger.info('Today is a skip day')
return return

View file

@ -1,6 +1,8 @@
import os import os
import datetime import datetime
ENV = os.environ.get('ENV', 'dev')
ORDER_TIME = os.environ.get('ORDER_TIME', '9:00') ORDER_TIME = os.environ.get('ORDER_TIME', '9:00')
ORDER_TIMEOUT = datetime.timedelta( ORDER_TIMEOUT = datetime.timedelta(
hours=os.environ.get('ORDER_TIMEOUT', 3) hours=os.environ.get('ORDER_TIMEOUT', 3)

View file

@ -3,7 +3,7 @@ import logging
import asyncio import asyncio
from settings import TELEGRAM_API_TOKEN, TELEGRAM_CHAT_ID from settings import TELEGRAM_API_TOKEN, TELEGRAM_CHAT_ID
from db import Database from db.queries import skip_day_put
from util import make_session from util import make_session
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -75,8 +75,7 @@ class SkipDayAddCommand(TelegramCommand):
async def exec(self, text, update, session): async def exec(self, text, update, session):
date_str = text.split(' ')[1] date_str = text.split(' ')[1]
db = Database() skip_day_put(date_str)
db.skip_day_put(date_str)
yield f"Added skip day {date_str}" yield f"Added skip day {date_str}"