|
| 1 | +import datetime |
| 2 | +import logging |
| 3 | +from typing import List, Iterator |
| 4 | + |
| 5 | +from sqlalchemy import (Column, String, DateTime, Date, ForeignKey, Table, MetaData, |
| 6 | + select, func) |
| 7 | +from sqlalchemy.engine.base import Engine |
| 8 | + |
| 9 | +from ovgumensabot.meal import Meal |
| 10 | +from ovgumensabot.menu import Menu |
| 11 | + |
| 12 | + |
| 13 | +class MenuDatabase: |
| 14 | + menus: Table |
| 15 | + meals: Table |
| 16 | + subscriptions: Table |
| 17 | + db: Engine |
| 18 | + |
| 19 | + def __init__(self, db: Engine) -> None: |
| 20 | + self.db = db |
| 21 | + |
| 22 | + meta = MetaData() |
| 23 | + meta.bind = db |
| 24 | + |
| 25 | + self.menus = Table("menus", meta, |
| 26 | + Column("day", Date, primary_key=True), |
| 27 | + Column("last_updated", DateTime, nullable=False)) |
| 28 | + |
| 29 | + self.meals = Table("meals", meta, |
| 30 | + Column("menu_day", Date, ForeignKey("menus.day", ondelete="CASCADE", primary_key=True)), |
| 31 | + Column("price", String(255), nullable=False), |
| 32 | + Column("name", String(255), nullable=False)) |
| 33 | + |
| 34 | + self.subscriptions = Table("subscriptions", meta, |
| 35 | + Column("room_id", String(255), primary_key=True)) |
| 36 | + |
| 37 | + meta.create_all() |
| 38 | + |
| 39 | + def upsert_menu(self, menu: Menu) -> bool: |
| 40 | + logging.getLogger("maubot").info(f"Inserted menu from {menu.day} into database.") |
| 41 | + with self.db.begin() as tx: |
| 42 | + if self.menu_day_exists(menu): |
| 43 | + tx.execute(self.menus.update() |
| 44 | + .where(self.menus.c.day == menu.day).values(day=menu.day, last_updated=menu.last_updated)) |
| 45 | + tx.execute(self.meals.delete().where(self.meals.c.menu_day == menu.day)) |
| 46 | + tx.execute(self.meals.insert(), |
| 47 | + [{"menu_day": menu.day, "price": meal.price, |
| 48 | + "name": meal.name} |
| 49 | + for meal in menu.meals]) |
| 50 | + return False # menu was already existent |
| 51 | + else: |
| 52 | + tx.execute(self.menus.insert() |
| 53 | + .values(day=menu.day, last_updated=menu.last_updated)) |
| 54 | + tx.execute(self.meals.insert(), |
| 55 | + [{"menu_day": menu.day, "price": meal.price, |
| 56 | + "name": meal.name} |
| 57 | + for meal in menu.meals]) |
| 58 | + return True # menu is new |
| 59 | + |
| 60 | + def add_meals_to_menu(self, menu: Menu): |
| 61 | + logging.getLogger("maubot").info(f"Add meals to menu from {menu.day}.") |
| 62 | + if self.menu_day_exists(menu): |
| 63 | + self.db.execute(self.meals.insert(), |
| 64 | + [{"menu_day": menu.day, "price": meal.price, |
| 65 | + "name": meal.name} |
| 66 | + for meal in menu.meals]) |
| 67 | + |
| 68 | + def subscriptions_not_empty(self) -> bool: |
| 69 | + rows = self.db.execute(select([func.count()]).select(self.subscriptions)).scalar() |
| 70 | + return rows and rows > 0 |
| 71 | + |
| 72 | + def subscription_exists(self, room_id: str) -> bool: |
| 73 | + rows = self.db.execute( |
| 74 | + select([func.count()]).select(self.subscriptions).where(self.subscriptions.c.room_id == room_id)).scalar() |
| 75 | + return rows and rows > 0 |
| 76 | + |
| 77 | + def menu_day_exists(self, menu: Menu) -> bool: |
| 78 | + rows = self.db.execute(select([func.count()]).select(self.menus).where(self.menus.c.day == menu.day)).scalar() |
| 79 | + return rows and rows > 0 |
| 80 | + |
| 81 | + def get_menu_on_day(self, day: datetime.date) -> Iterator[Menu]: |
| 82 | + logging.getLogger("maubot").info(f"Search for day {day}") |
| 83 | + menu_rows = self.db.execute(select([self.menus]).where(self.menus.c.day.like(day))) |
| 84 | + return self._rows_to_menus(menu_rows) |
| 85 | + |
| 86 | + def get_latest_menu(self) -> Iterator[Menu]: |
| 87 | + menu_row = self.db.execute(select([self.menus]).order_by(self.menus.c.last_updated.desc()).limit(1)) |
| 88 | + logging.getLogger("maubot").info(f"first menu_row {menu_row}") |
| 89 | + return self._rows_to_menus(menu_row) |
| 90 | + |
| 91 | + def _rows_to_menus(self, menu_rows) -> Iterator[Menu]: |
| 92 | + for menu_row in menu_rows: |
| 93 | + logging.getLogger("maubot").info(f"menu_row {menu_row}") |
| 94 | + meal_rows = self.db.execute(select([self.meals]).where(self.meals.c.menu_day == menu_row[0])) |
| 95 | + meals_of_the_day = [] |
| 96 | + for meal_row in meal_rows: |
| 97 | + meals_of_the_day.append(Meal(menu_day=meal_row[0], price=meal_row[1], name=meal_row[2])) |
| 98 | + yield Menu(day=menu_row[0], last_updated=menu_row[1], meals=meals_of_the_day) |
| 99 | + |
| 100 | + def insert_subscription(self, room_id: str) -> None: |
| 101 | + self.db.execute(self.subscriptions.insert().values(room_id=room_id)) |
| 102 | + |
| 103 | + def get_subscriptions(self) -> List: |
| 104 | + rows = self.db.execute(self.subscriptions.select()) |
| 105 | + for row in rows: |
| 106 | + yield row[0] |
| 107 | + |
| 108 | + def delete_subscription(self, room_id: str) -> None: |
| 109 | + self.db.execute(self.subscriptions.delete().where(self.subscriptions.c.room_id == room_id)) |
| 110 | + |
| 111 | + def get_menu_days(self) -> Iterator[datetime.date]: |
| 112 | + rows = self.db.execute(select([self.menus.c.day])) |
| 113 | + for row in rows: |
| 114 | + yield row[0] |
0 commit comments