mirror of
https://github.com/nihilvux/bancho.py.git
synced 2025-10-07 17:00:18 -07:00
114 lines
3.1 KiB
Python
114 lines
3.1 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TypedDict
|
|
from typing import cast
|
|
|
|
from sqlalchemy import Column
|
|
from sqlalchemy import Integer
|
|
from sqlalchemy import String
|
|
from sqlalchemy import func
|
|
from sqlalchemy import insert
|
|
from sqlalchemy import select
|
|
from sqlalchemy import update
|
|
from sqlalchemy.dialects.mysql import TINYINT
|
|
|
|
import app.state.services
|
|
from app.repositories import Base
|
|
|
|
|
|
class MailTable(Base):
    """SQLAlchemy declarative mapping for the ``mail`` table."""

    __tablename__ = "mail"

    # Auto-incrementing integer primary key.
    id = Column(
        "id",
        Integer,
        primary_key=True,
        autoincrement=True,
        nullable=False,
    )
    # Sender / recipient ids; no foreign-key constraint is declared here.
    from_id = Column("from_id", Integer, nullable=False)
    to_id = Column("to_id", Integer, nullable=False)
    # Message text, declared with a length of 2048.
    # NOTE(review): "utf8" is normally a MySQL character-set name rather than
    # a collation name - confirm this matches the real schema.
    msg = Column(
        "msg",
        String(2048, collation="utf8"),
        nullable=False,
    )
    # Nullable integer timestamp column.
    time = Column("time", Integer, nullable=True)
    # TINYINT(1) flag; the server-side default is "0".
    read = Column(
        "read",
        TINYINT(1),
        server_default="0",
        nullable=False,
    )
|
|
|
|
|
|
# Columns selected whenever a mail row is read back from the database.
# The order here is the column order of the resulting SELECT.
READ_PARAMS = (
    MailTable.id,
    MailTable.from_id,
    MailTable.to_id,
    MailTable.msg,
    MailTable.time,
    MailTable.read,
)
|
|
|
|
|
|
class Mail(TypedDict):
    """Typed dictionary shape for a single mail row.

    The keys mirror the columns listed in ``READ_PARAMS``.
    """

    # Primary key of the row.
    id: int
    # Sender and recipient ids.
    from_id: int
    to_id: int
    # Message text.
    msg: str
    # NOTE(review): the backing ``time`` column is declared nullable, so a
    # fetched row could presumably hold None here - confirm against real data.
    time: int
    # Whether the mail has been marked as read.
    read: bool
|
|
|
|
|
|
class MailWithUsernames(Mail):
    """A ``Mail`` row extended with the names of both users involved."""

    # Name looked up for ``from_id``.
    from_name: str
    # Name looked up for ``to_id``.
    to_name: str
|
|
|
|
|
|
async def create(from_id: int, to_id: int, msg: str) -> Mail:
    """Insert a new mail row and return it as read back from the database.

    The ``time`` column is filled by the database through the SQL
    ``unix_timestamp()`` function rather than by a Python-side clock.
    """
    database = app.state.services.database

    new_row_values = {
        "from_id": from_id,
        "to_id": to_id,
        "msg": msg,
        "time": func.unix_timestamp(),
    }
    # The value returned by execute() is used as the id of the new row.
    inserted_id = await database.execute(
        insert(MailTable).values(**new_row_values),
    )

    # Re-select the row so database-populated columns are included.
    row = await database.fetch_one(
        select(*READ_PARAMS).where(MailTable.id == inserted_id),
    )
    assert row is not None
    return cast(Mail, row)
|
|
|
|
|
|
from app.repositories.users import UsersTable
|
|
|
|
|
|
async def fetch_all_mail_to_user(
    user_id: int,
    read: bool | None = None,
) -> list[MailWithUsernames]:
    """Fetch every mail row addressed to ``user_id``.

    When ``read`` is not None, only rows whose ``read`` column equals it are
    returned; with the default of None, no filter on ``read`` is applied.
    Each row also carries ``from_name`` and ``to_name``, resolved through
    subqueries on the users table.
    """
    # Correlated lookups of the sender's and recipient's names.
    sender_name = select(UsersTable.name).where(UsersTable.id == MailTable.from_id)
    recipient_name = select(UsersTable.name).where(UsersTable.id == MailTable.to_id)

    filters = [MailTable.to_id == user_id]
    if read is not None:
        filters.append(MailTable.read == read)

    query = select(
        *READ_PARAMS,
        sender_name.label("from_name"),
        recipient_name.label("to_name"),
    ).where(*filters)

    rows = await app.state.services.database.fetch_all(query)
    return cast(list[MailWithUsernames], rows)
|
|
|
|
|
|
async def mark_conversation_as_read(to_id: int, from_id: int) -> list[Mail]:
    """Mark any mail in a user's conversation with another user as read.

    Selects the unread mail sent from ``from_id`` to ``to_id``, flags exactly
    those rows as read, and returns them.

    The returned rows are the ones fetched *before* the update, so their
    ``read`` value still reflects the unread state. An empty list is returned
    (and no update is issued) when there is nothing unread.
    """
    database = app.state.services.database

    select_stmt = select(*READ_PARAMS).where(
        MailTable.to_id == to_id,
        MailTable.from_id == from_id,
        MailTable.read == False,  # noqa: E712 - builds a SQL comparison
    )
    mail = await database.fetch_all(select_stmt)
    if not mail:
        return []

    # Update only the rows that were actually selected. Re-applying the
    # to_id/from_id/unread predicate here would also flag mail inserted
    # between the SELECT and the UPDATE, marking it read without it ever
    # being returned to the caller.
    selected_ids = [row["id"] for row in mail]
    update_stmt = (
        update(MailTable)
        .where(MailTable.id.in_(selected_ids))
        .values(read=True)
    )
    await database.execute(update_stmt)
    return cast(list[Mail], mail)
|