Use peewee ORM
This commit is contained in:
parent
38ea6c9bbd
commit
9f5cec705e
12 changed files with 174 additions and 193 deletions
1
Pipfile
1
Pipfile
|
|
@ -8,6 +8,7 @@ pyyaml = "*"
|
|||
aiohttp = "*"
|
||||
scheduler = "*"
|
||||
pytz = "*"
|
||||
peewee = "*"
|
||||
|
||||
[dev-packages]
|
||||
|
||||
|
|
|
|||
9
Pipfile.lock
generated
9
Pipfile.lock
generated
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "97e23fc67642a95bffea8782f93d340d26a92a440feae5e4b7d10dedbec5cccb"
|
||||
"sha256": "3189468ccffbb23866656384675ad9f79f061393b93965ce86aa2c54483e4776"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
|
|
@ -463,6 +463,13 @@
|
|||
"markers": "python_version >= '3.9'",
|
||||
"version": "==6.7.0"
|
||||
},
|
||||
"peewee": {
|
||||
"hashes": [
|
||||
"sha256:62c3d93315b1a909360c4b43c3a573b47557a1ec7a4583a71286df2a28d4b72e"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==3.18.3"
|
||||
},
|
||||
"propcache": {
|
||||
"hashes": [
|
||||
"sha256:0002004213ee1f36cfb3f9a42b5066100c44276b9b72b4e1504cddd3d692e86e",
|
||||
|
|
|
|||
152
db.py
152
db.py
|
|
@ -1,152 +0,0 @@
|
|||
import sqlite3
|
||||
|
||||
from settings import SQLITE_DB
|
||||
|
||||
TABLE_REPEAT = '''
|
||||
CREATE TABLE IF NOT EXISTS repeat (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
probability FLOAT NOT NULL,
|
||||
orders TEXT NOT NULL,
|
||||
count INTEGER DEFAULT 0
|
||||
);
|
||||
'''
|
||||
|
||||
TABLE_SKIP_DAY = '''
|
||||
CREATE TABLE IF NOT EXISTS skip_day (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
date TIMESTAMP UNIQUE NOT NULL
|
||||
);
|
||||
'''
|
||||
|
||||
TABLE_ORDER_STATUS = '''
|
||||
CREATE TABLE IF NOT EXISTS order_status (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mastodon_id TEXT NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL,
|
||||
due_at TIMESTAMP NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
confirmed_at TIMESTAMP,
|
||||
punishment_id INTEGER,
|
||||
FOREIGN KEY(punishment_id) REFERENCES punishment_status(id)
|
||||
);
|
||||
'''
|
||||
|
||||
TABLE_PUNISHMENT_STATUS = '''
|
||||
CREATE TABLE IF NOT EXISTS punishment_status (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
mastodon_id TEXT NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
confirmed_at TIMESTAMP
|
||||
);
|
||||
'''
|
||||
|
||||
class Database:
|
||||
def __init__(self):
|
||||
self.conn = sqlite3.connect(SQLITE_DB)
|
||||
self.conn.row_factory = sqlite3.Row
|
||||
self.table_init(TABLE_REPEAT)
|
||||
self.table_init(TABLE_SKIP_DAY)
|
||||
self.table_init(TABLE_PUNISHMENT_STATUS)
|
||||
self.table_init(TABLE_ORDER_STATUS)
|
||||
|
||||
def table_init(self, table_sql):
|
||||
c = self.conn.cursor()
|
||||
c.execute(table_sql)
|
||||
self.conn.commit()
|
||||
|
||||
def update(self, sql, args=[]):
|
||||
c = self.conn.cursor()
|
||||
c.execute(sql, args)
|
||||
self.conn.commit()
|
||||
return c.lastrowid
|
||||
|
||||
def repeat_get(self):
|
||||
c = self.conn.cursor()
|
||||
sql = 'SELECT id, probability, orders, count FROM repeat LIMIT 1'
|
||||
c.execute(sql)
|
||||
return c.fetchone()
|
||||
|
||||
def repeat_increment(self):
|
||||
self.update('UPDATE repeat SET count = count + 1')
|
||||
|
||||
def repeat_put(self, probability, orders):
|
||||
self.update(
|
||||
'INSERT INTO repeat (probability, orders) VALUES (?, ?);',
|
||||
[probability, orders]
|
||||
)
|
||||
|
||||
def repeat_clear(self):
|
||||
self.update('DELETE FROM repeat')
|
||||
|
||||
def skip_day_put(self, date):
|
||||
self.update(
|
||||
'INSERT INTO skip_day (date) VALUES (?);',
|
||||
[date]
|
||||
)
|
||||
|
||||
def skip_day_contains(self, date):
|
||||
c = self.conn.cursor()
|
||||
sql = 'SELECT * FROM skip_day WHERE date=?;'
|
||||
c.execute(sql, [date])
|
||||
return c.fetchone() is not None
|
||||
|
||||
def order_status_put(self, mastodon_id, created_at, due_at, text):
|
||||
self.update(
|
||||
'''
|
||||
INSERT INTO order_status
|
||||
(mastodon_id, created_at, due_at, text)
|
||||
VALUES (?, ?, ?, ?);
|
||||
''',
|
||||
[
|
||||
mastodon_id,
|
||||
created_at,
|
||||
due_at,
|
||||
text
|
||||
]
|
||||
)
|
||||
|
||||
def order_status_outstanding(self):
|
||||
c = self.conn.cursor()
|
||||
sql = '''
|
||||
SELECT id, mastodon_id, created_at, due_at, confirmed_at
|
||||
FROM order_status
|
||||
WHERE confirmed_at IS NULL AND punishment_id IS NULL
|
||||
'''
|
||||
c.execute(sql)
|
||||
return c.fetchall()
|
||||
|
||||
def order_status_confirm(self, id, confirmed_at):
|
||||
self.update(
|
||||
'''
|
||||
UPDATE order_status
|
||||
SET confirmed_at=?
|
||||
WHERE id=?;
|
||||
''',
|
||||
[
|
||||
confirmed_at,
|
||||
id
|
||||
]
|
||||
)
|
||||
|
||||
def punishment_status_put(self, order_status_id, mastodon_id, created_at, text):
|
||||
punishment_status_id = self.update(
|
||||
'''
|
||||
INSERT INTO punishment_status
|
||||
(mastodon_id, created_at, text)
|
||||
VALUES (?, ?, ?);
|
||||
''',
|
||||
[
|
||||
mastodon_id,
|
||||
created_at,
|
||||
text
|
||||
]
|
||||
)
|
||||
self.update(
|
||||
'''
|
||||
UPDATE order_status
|
||||
SET punishment_id=?
|
||||
WHERE id=?
|
||||
''',
|
||||
[punishment_status_id, order_status_id]
|
||||
)
|
||||
0
db/__init__.py
Normal file
0
db/__init__.py
Normal file
42
db/models.py
Normal file
42
db/models.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
from peewee import *
|
||||
from settings import SQLITE_DB
|
||||
|
||||
database = SqliteDatabase(SQLITE_DB)
|
||||
|
||||
class BaseModel(Model):
|
||||
class Meta:
|
||||
database = database
|
||||
|
||||
class PunishmentStatus(BaseModel):
|
||||
confirmed_at = DateTimeField(null=True) # TIMESTAMP
|
||||
created_at = DateTimeField() # TIMESTAMP
|
||||
mastodon_id = TextField()
|
||||
text = TextField()
|
||||
|
||||
class Meta:
|
||||
table_name = 'punishment_status'
|
||||
|
||||
class OrderStatus(BaseModel):
|
||||
confirmed_at = DateTimeField(null=True) # TIMESTAMP
|
||||
created_at = DateTimeField() # TIMESTAMP
|
||||
due_at = DateTimeField() # TIMESTAMP
|
||||
mastodon_id = TextField()
|
||||
punishment = ForeignKeyField(column_name='punishment_id', field='id', model=PunishmentStatus, null=True)
|
||||
text = TextField()
|
||||
|
||||
class Meta:
|
||||
table_name = 'order_status'
|
||||
|
||||
class Repeat(BaseModel):
|
||||
count = IntegerField(default=0)
|
||||
orders = TextField()
|
||||
probability = FloatField()
|
||||
|
||||
class Meta:
|
||||
table_name = 'repeat'
|
||||
|
||||
class SkipDay(BaseModel):
|
||||
date = DateField(unique=True)
|
||||
|
||||
class Meta:
|
||||
table_name = 'skip_day'
|
||||
72
db/queries.py
Normal file
72
db/queries.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
from .models import database, Repeat, SkipDay, OrderStatus, PunishmentStatus
|
||||
|
||||
def initdb():
|
||||
database.connect()
|
||||
database.create_tables([
|
||||
Repeat,
|
||||
SkipDay,
|
||||
OrderStatus,
|
||||
PunishmentStatus
|
||||
])
|
||||
|
||||
def repeat_get():
|
||||
try:
|
||||
return Repeat.get()
|
||||
except Repeat.DoesNotExist:
|
||||
return None
|
||||
|
||||
def repeat_increment():
|
||||
q = Repeat.update(count=Repeat.count + 1)
|
||||
return q.execute()
|
||||
|
||||
def repeat_put(probability, orders):
|
||||
return Repeat.create(
|
||||
probability=probability,
|
||||
orders=orders
|
||||
)
|
||||
|
||||
def repeat_clear():
|
||||
q = Repeat.delete()
|
||||
q.execute()
|
||||
|
||||
def skip_day_put(date):
|
||||
return SkipDay.create(date=date)
|
||||
|
||||
def skip_day_contains(date):
|
||||
q = SkipDay.select().where(SkipDay.date == date)
|
||||
return len(q) > 0
|
||||
|
||||
def order_status_put(mastodon_id, created_at, due_at, text):
|
||||
return OrderStatus.create(
|
||||
mastodon_id=mastodon_id,
|
||||
created_at=created_at,
|
||||
due_at=due_at,
|
||||
text=text
|
||||
)
|
||||
|
||||
def order_status_outstanding():
|
||||
return OrderStatus.select().where(
|
||||
(OrderStatus.confirmed_at.is_null()) & (OrderStatus.punishment_id.is_null())
|
||||
)
|
||||
|
||||
def order_status_confirm(id, confirmed_at):
|
||||
q = OrderStatus.update(
|
||||
confirmed_at=confirmed_at
|
||||
).where(
|
||||
OrderStatus.id == id
|
||||
)
|
||||
return q.execute()
|
||||
|
||||
def punishment_status_put(order_status_id, mastodon_id, created_at, text):
|
||||
punishment_status = PunishmentStatus.create(
|
||||
mastodon_id=mastodon_id,
|
||||
created_at=created_at,
|
||||
text=text
|
||||
)
|
||||
|
||||
q = OrderStatus.update(
|
||||
punishment_id=punishment_status.id
|
||||
).where(
|
||||
OrderStatus.id == order_status_id
|
||||
)
|
||||
q.execute()
|
||||
17
generate.py
17
generate.py
|
|
@ -3,7 +3,7 @@ import json
|
|||
import logging
|
||||
import random
|
||||
|
||||
from db import Database
|
||||
from db.queries import repeat_get, repeat_increment, repeat_clear, repeat_put
|
||||
from settings import ORDERS_YML
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -45,17 +45,16 @@ def read_config():
|
|||
|
||||
def generate_order():
|
||||
# Do we want to repeat?
|
||||
db = Database()
|
||||
repeat = db.repeat_get()
|
||||
repeat = repeat_get()
|
||||
if repeat is not None:
|
||||
if repeat['probability'] > random.random():
|
||||
db.repeat_increment()
|
||||
if repeat.probability > random.random():
|
||||
repeat_increment()
|
||||
return {
|
||||
"orders": json.loads(repeat['orders']),
|
||||
"count": repeat['count']
|
||||
"orders": json.loads(repeat.orders),
|
||||
"count": repeat.count
|
||||
}
|
||||
else:
|
||||
db.repeat_clear()
|
||||
repeat_clear()
|
||||
|
||||
orders_config = read_config()
|
||||
|
||||
|
|
@ -68,7 +67,7 @@ def generate_order():
|
|||
|
||||
# Log the repeat
|
||||
if repeat_p > 0.0:
|
||||
db.repeat_put(repeat_p, json.dumps(result))
|
||||
repeat_put(repeat_p, json.dumps(result))
|
||||
|
||||
return {
|
||||
"orders": result
|
||||
|
|
|
|||
4
main.py
4
main.py
|
|
@ -6,6 +6,7 @@ import asyncio
|
|||
from scheduling import OrderScheduler
|
||||
from orders import order_issue, order_check
|
||||
from telegram import handle_commands
|
||||
from db.queries import initdb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -23,6 +24,7 @@ if __name__=='__main__':
|
|||
|
||||
parser_immediate = subparsers.add_parser('immediate', help='Immediately generate a command')
|
||||
parser_check = subparsers.add_parser('check', help="Checks if any orders are outstanding")
|
||||
parser_initdb = subparsers.add_parser('initdb', help="Creates the database tables")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -34,6 +36,8 @@ if __name__=='__main__':
|
|||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(order_check())
|
||||
loop.close()
|
||||
elif args.command == 'initdb':
|
||||
initdb()
|
||||
else:
|
||||
loop = asyncio.new_event_loop()
|
||||
s = OrderScheduler(loop)
|
||||
|
|
|
|||
55
orders.py
55
orders.py
|
|
@ -3,10 +3,10 @@ import datetime
|
|||
|
||||
from util import make_session
|
||||
from generate import generate_order, generate_punishment
|
||||
from db import Database
|
||||
from db.queries import order_status_put, punishment_status_put, order_status_outstanding, order_status_confirm
|
||||
from mastodon import Mastodon
|
||||
from telegram import Telegram
|
||||
from settings import MASTODON_USERNAME, ORDER_TIMEOUT
|
||||
from settings import MASTODON_USERNAME, ORDER_TIMEOUT, ENV
|
||||
from util import timezone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -17,7 +17,10 @@ async def order_mastodon_post(session, orders_str, repeats, due_at):
|
|||
if repeats > 1:
|
||||
post += f"These are the same orders from the last {repeats} days\n\n"
|
||||
post += "Proof of compliance is due by " + due_at.strftime("%I:%M %p") + "\n\n"
|
||||
post += "CC - @chicagogear @s10boi"
|
||||
if ENV == 'dev':
|
||||
post += "⚠️ DEV"
|
||||
else:
|
||||
post += "CC - @chicagogear @s10boi"
|
||||
|
||||
m = Mastodon(session)
|
||||
return await m.statusPost(post)
|
||||
|
|
@ -29,15 +32,19 @@ async def order_telegram_post(session, orders_str, repeats, due_at, m_url):
|
|||
post += f"These are the same orders from the last {repeats} days\n\n"
|
||||
post += "Proof of compliance is due by " + due_at.strftime("%I:%M %p") + "\n\n"
|
||||
post += m_url
|
||||
if ENV == 'dev':
|
||||
post += "\n⚠️ DEV"
|
||||
|
||||
t = Telegram(session)
|
||||
await t.message_send(post)
|
||||
|
||||
async def order_telegram_post_none(session):
|
||||
t_post = "No orders for today"
|
||||
post = "No orders for today"
|
||||
if ENV == 'dev':
|
||||
post += "\n⚠️ DEV"
|
||||
|
||||
t = Telegram(session)
|
||||
await t.message_send(t_post)
|
||||
await t.message_send(post)
|
||||
|
||||
async def order_issue():
|
||||
async with make_session() as session:
|
||||
|
|
@ -70,8 +77,7 @@ async def order_issue():
|
|||
m_status['url']
|
||||
)
|
||||
|
||||
db = Database()
|
||||
db.order_status_put(
|
||||
order_status_put(
|
||||
m_status['id'],
|
||||
created_at,
|
||||
due_at,
|
||||
|
|
@ -83,7 +89,10 @@ async def order_issue():
|
|||
async def punishment_mastodon_post(session, punishment_str, reply_id=None):
|
||||
post = "@%s has failed to post proof of compliance. Here is the punishment -\n\n" % MASTODON_USERNAME
|
||||
post += punishment_str + "\n\n"
|
||||
post += "CC - @chicagogear @s10boi"
|
||||
if ENV == 'dev':
|
||||
post += "⚠️ DEV"
|
||||
else:
|
||||
post += "CC - @chicagogear @s10boi"
|
||||
|
||||
m = Mastodon(session)
|
||||
return await m.statusPost(
|
||||
|
|
@ -95,18 +104,20 @@ async def punishment_telegram_post(session, punishment_str, m_url):
|
|||
post = "You failed to show proof of compliance. Here is your punishment -\n\n"
|
||||
post += punishment_str + "\n\n"
|
||||
post += m_url
|
||||
if ENV == 'dev':
|
||||
post += "\n\n⚠️ DEV"
|
||||
|
||||
t = Telegram(session)
|
||||
await t.message_send(post)
|
||||
|
||||
async def punishment_issue(db, session, outstanding_order):
|
||||
async def punishment_issue(session, outstanding_order):
|
||||
punishment = generate_punishment()
|
||||
punishment_str = "\n".join(punishment)
|
||||
|
||||
punishment_status = await punishment_mastodon_post(
|
||||
session,
|
||||
punishment_str,
|
||||
outstanding_order['mastodon_id'],
|
||||
outstanding_order.mastodon_id,
|
||||
)
|
||||
|
||||
await punishment_telegram_post(
|
||||
|
|
@ -115,8 +126,8 @@ async def punishment_issue(db, session, outstanding_order):
|
|||
punishment_status['url']
|
||||
)
|
||||
|
||||
db.punishment_status_put(
|
||||
outstanding_order['id'],
|
||||
punishment_status_put(
|
||||
outstanding_order.id,
|
||||
punishment_status['id'],
|
||||
punishment_status['created_at'],
|
||||
punishment_str
|
||||
|
|
@ -124,30 +135,28 @@ async def punishment_issue(db, session, outstanding_order):
|
|||
|
||||
async def order_check():
|
||||
async with make_session() as session:
|
||||
db = Database()
|
||||
|
||||
outstanding_orders = db.order_status_outstanding()
|
||||
outstanding_orders = order_status_outstanding()
|
||||
for outstanding_order in outstanding_orders:
|
||||
m = Mastodon(session)
|
||||
context = await m.statusContext(outstanding_order['mastodon_id'])
|
||||
context = await m.statusContext(outstanding_order.mastodon_id)
|
||||
|
||||
confirmed_at = None
|
||||
for d in context['descendants']:
|
||||
if (
|
||||
d['in_reply_to_id'] == outstanding_order['mastodon_id'] and
|
||||
d['in_reply_to_id'] == outstanding_order.mastodon_id and
|
||||
d['account']['username'] == MASTODON_USERNAME and
|
||||
len(d['media_attachments']) > 0
|
||||
):
|
||||
confirmed_at = d['created_at']
|
||||
db.order_status_confirm(outstanding_order['id'], confirmed_at)
|
||||
logger.info('Confirmed order %s' % (outstanding_order['id']))
|
||||
order_status_confirm(outstanding_order.id, confirmed_at)
|
||||
logger.info('Confirmed order %s' % (outstanding_order.id))
|
||||
break
|
||||
|
||||
if confirmed_at is None:
|
||||
logger.info('Order %s remains unconfirmed' % (outstanding_order['id']))
|
||||
logger.info('Order %s remains unconfirmed' % (outstanding_order.id))
|
||||
|
||||
due_at = datetime.datetime.fromisoformat(outstanding_order['due_at'])
|
||||
due_at = datetime.datetime.fromisoformat(outstanding_order.due_at)
|
||||
if(due_at < datetime.datetime.now(datetime.UTC)):
|
||||
logger.info('Time to issue a punishment for %s' % outstanding_order['id'])
|
||||
logger.info('Time to issue a punishment for %s' % outstanding_order.id)
|
||||
|
||||
await punishment_issue(db, session, outstanding_order)
|
||||
await punishment_issue(session, outstanding_order)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from scheduler.asyncio import Scheduler
|
|||
|
||||
from settings import TIMEZONE
|
||||
from orders import order_issue, order_check
|
||||
from db import Database
|
||||
from db.queries import order_status_outstanding, skip_day_contains
|
||||
from util import order_time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -27,8 +27,7 @@ class OrderScheduler():
|
|||
self.scheduler.daily(order_time_dt, self.scheduled_order)
|
||||
|
||||
# Schedule any outstanding orders
|
||||
db = Database()
|
||||
outstanding_orders = db.order_status_outstanding()
|
||||
outstanding_orders = order_status_outstanding()
|
||||
for oo in outstanding_orders:
|
||||
self.scheduler.once(
|
||||
datetime.datetime.fromisoformat(oo['due_at']) + GRACE_PERIOD,
|
||||
|
|
@ -45,10 +44,9 @@ class OrderScheduler():
|
|||
return
|
||||
|
||||
# Skip stored dates
|
||||
d = Database()
|
||||
today = datetime.datetime.now(tz=self.tz).strftime("%Y-%m-%d")
|
||||
logger.info('Today %s' % today)
|
||||
if (d.skip_day_contains(today)):
|
||||
if (skip_day_contains(today)):
|
||||
logger.info('Today is a skip day')
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
import datetime
|
||||
|
||||
ENV = os.environ.get('ENV', 'dev')
|
||||
|
||||
ORDER_TIME = os.environ.get('ORDER_TIME', '9:00')
|
||||
ORDER_TIMEOUT = datetime.timedelta(
|
||||
hours=os.environ.get('ORDER_TIMEOUT', 3)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import asyncio
|
||||
|
||||
from settings import TELEGRAM_API_TOKEN, TELEGRAM_CHAT_ID
|
||||
from db import Database
|
||||
from db.queries import skip_day_put
|
||||
from util import make_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -75,8 +75,7 @@ class SkipDayAddCommand(TelegramCommand):
|
|||
async def exec(self, text, update, session):
|
||||
date_str = text.split(' ')[1]
|
||||
|
||||
db = Database()
|
||||
db.skip_day_put(date_str)
|
||||
skip_day_put(date_str)
|
||||
|
||||
yield f"Added skip day {date_str}"
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue