mirror of
https://github.com/nihilvux/bancho.py.git
synced 2025-09-17 02:58:39 -07:00
Add files via upload
This commit is contained in:
1
CODEOWNERS
Normal file
1
CODEOWNERS
Normal file
@@ -0,0 +1 @@
|
||||
* @cmyui @kingdom5500 @NiceAesth @tsunyoku @7mochi
|
22
Dockerfile
Normal file
22
Dockerfile
Normal file
@@ -0,0 +1,22 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
WORKDIR /srv/root
|
||||
|
||||
RUN apt update && apt install --no-install-recommends -y \
|
||||
git curl build-essential=12.9 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY pyproject.toml poetry.lock ./
|
||||
RUN pip install -U pip poetry
|
||||
RUN poetry config virtualenvs.create false
|
||||
RUN poetry install --no-root
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y default-mysql-client redis-tools
|
||||
|
||||
# NOTE: done last to avoid re-run of previous steps
|
||||
COPY . .
|
||||
|
||||
ENTRYPOINT [ "scripts/start_server.sh" ]
|
47
Makefile
Normal file
47
Makefile
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env make
|
||||
|
||||
build:
|
||||
if [ -d ".dbdata" ]; then sudo chmod -R 755 .dbdata; fi
|
||||
docker build -t bancho:latest .
|
||||
|
||||
run:
|
||||
docker compose up bancho mysql redis
|
||||
|
||||
run-bg:
|
||||
docker compose up -d bancho mysql redis
|
||||
|
||||
run-caddy:
|
||||
caddy run --envfile .env --config ext/Caddyfile
|
||||
|
||||
last?=1
|
||||
logs:
|
||||
docker compose logs -f bancho mysql redis --tail ${last}
|
||||
|
||||
shell:
|
||||
poetry shell
|
||||
|
||||
test:
|
||||
docker compose -f docker-compose.test.yml up -d bancho-test mysql-test redis-test
|
||||
docker compose -f docker-compose.test.yml exec -T bancho-test /srv/root/scripts/run-tests.sh
|
||||
|
||||
lint:
|
||||
poetry run pre-commit run --all-files
|
||||
|
||||
type-check:
|
||||
poetry run mypy .
|
||||
|
||||
install:
|
||||
POETRY_VIRTUALENVS_IN_PROJECT=1 poetry install --no-root
|
||||
|
||||
install-dev:
|
||||
POETRY_VIRTUALENVS_IN_PROJECT=1 poetry install --no-root --with dev
|
||||
poetry run pre-commit install
|
||||
|
||||
uninstall:
|
||||
poetry env remove python
|
||||
|
||||
# To bump the version number run `make bump version=<major/minor/patch>`
|
||||
# (DO NOT USE IF YOU DON'T KNOW WHAT YOU'RE DOING)
|
||||
# https://python-poetry.org/docs/cli/#version
|
||||
bump:
|
||||
poetry version $(version)
|
23
README_CN.md
Normal file
23
README_CN.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# bancho.py - 中文文档
|
||||
|
||||
[](https://www.python.org/downloads/)
|
||||
[](https://github.com/ambv/black)
|
||||
[](https://results.pre-commit.ci/latest/github/osuAkatsuki/bancho.py/master)
|
||||
[](https://discord.gg/ShEQgUx)
|
||||
|
||||
The English version: [[English]](https://github.com/osuAkatsuki/bancho.py/blob/master/README.md)
|
||||
|
||||
这是中文翻译哦~由 [hedgehog-qd](https://github.com/hedgehog-qd) 在根据原英语文档部署成功后翻译的。这里
|
||||
我根据我当时遇到的问题补充了一些提示,如有错误请指正,谢谢!
|
||||
|
||||
bancho.py 是一个还在被不断维护的osu!后端项目,不论你的水平如何,都
|
||||
可以去使用他来开一个自己的osu!私服!
|
||||
|
||||
这个项目最初是由 [Akatsuki](https://akatsuki.pw/) 团队开发的,我们的目标是创建一个非常容易
|
||||
维护并且功能很丰富的osu!私服的服务端!
|
||||
|
||||
注意:bancho.py是一个后端!当你跟着下面的步骤部署完成后你可以正常登录
|
||||
并游玩。这个项目自带api,但是没有前端(就是网页),前端的话你也可以去看
|
||||
他们团队开发的前端项目。
|
||||
api文档(英语):<https://github.com/JKBGL/gulag-api-docs>
|
||||
前端(guweb):<https://github.com/Varkaria/guweb>
|
14
README_DE.MD
Normal file
14
README_DE.MD
Normal file
@@ -0,0 +1,14 @@
|
||||
# bancho.py
|
||||
|
||||
[](https://www.python.org/downloads/)
|
||||
[](https://github.com/ambv/black)
|
||||
[](https://results.pre-commit.ci/latest/github/osuAkatsuki/bancho.py/master)
|
||||
[](https://discord.gg/ShEQgUx)
|
||||
|
||||
bancho.py ist eine in Arbeit befindliche osu!-Server-Implementierung für
|
||||
Entwickler aller Erfahrungsstufen, die daran interessiert sind, ihre eigene(n)
|
||||
private(n) osu-Server-Instanz(en) zu hosten
|
||||
|
||||
Das Projekt wird hauptsächlich vom [Akatsuki](https://akatsuki.pw/)-Team entwickelt,
|
||||
und unser Ziel ist es, die am einfachsten zu wartende, zuverlässigste und
|
||||
funktionsreichste osu!-Server-Implementierung auf dem Markt zu schaffen.
|
13
app/__init__.py
Normal file
13
app/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# type: ignore
|
||||
# isort: dont-add-imports
|
||||
|
||||
from . import api
|
||||
from . import bg_loops
|
||||
from . import commands
|
||||
from . import constants
|
||||
from . import discord
|
||||
from . import logging
|
||||
from . import objects
|
||||
from . import packets
|
||||
from . import state
|
||||
from . import utils
|
27
app/_typing.py
Normal file
27
app/_typing.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from ipaddress import IPv6Address
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
IPAddress = IPv4Address | IPv6Address
|
||||
|
||||
|
||||
class _UnsetSentinel:
|
||||
def __repr__(self) -> str:
|
||||
return "Unset"
|
||||
|
||||
def __copy__(self: T) -> T:
|
||||
return self
|
||||
|
||||
def __reduce__(self) -> str:
|
||||
return "Unset"
|
||||
|
||||
def __deepcopy__(self: T, _: Any) -> T:
|
||||
return self
|
||||
|
||||
|
||||
UNSET = _UnsetSentinel()
|
0
app/adapters/__init__.py
Normal file
0
app/adapters/__init__.py
Normal file
164
app/adapters/database.py
Normal file
164
app/adapters/database.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from databases import Database as _Database
|
||||
from databases.core import Transaction
|
||||
from sqlalchemy.dialects.mysql.mysqldb import MySQLDialect_mysqldb
|
||||
from sqlalchemy.sql.compiler import Compiled
|
||||
from sqlalchemy.sql.expression import ClauseElement
|
||||
|
||||
from app import settings
|
||||
from app.logging import log
|
||||
from app.timer import Timer
|
||||
|
||||
|
||||
class MySQLDialect(MySQLDialect_mysqldb):
|
||||
default_paramstyle = "named"
|
||||
|
||||
|
||||
DIALECT = MySQLDialect()
|
||||
|
||||
MySQLRow = dict[str, Any]
|
||||
MySQLParams = dict[str, Any] | None
|
||||
MySQLQuery = ClauseElement | str
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, url: str) -> None:
|
||||
self._database = _Database(url)
|
||||
|
||||
async def connect(self) -> None:
|
||||
await self._database.connect()
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
await self._database.disconnect()
|
||||
|
||||
def _compile(self, clause_element: ClauseElement) -> tuple[str, MySQLParams]:
|
||||
compiled: Compiled = clause_element.compile(
|
||||
dialect=DIALECT,
|
||||
compile_kwargs={"render_postcompile": True},
|
||||
)
|
||||
return str(compiled), compiled.params
|
||||
|
||||
async def fetch_one(
|
||||
self,
|
||||
query: MySQLQuery,
|
||||
params: MySQLParams = None,
|
||||
) -> MySQLRow | None:
|
||||
if isinstance(query, ClauseElement):
|
||||
query, params = self._compile(query)
|
||||
|
||||
with Timer() as timer:
|
||||
row = await self._database.fetch_one(query, params)
|
||||
|
||||
if settings.DEBUG:
|
||||
time_elapsed = timer.elapsed()
|
||||
log(
|
||||
f"Executed SQL query: {query} {params} in {time_elapsed * 1000:.2f} msec.",
|
||||
extra={
|
||||
"query": query,
|
||||
"params": params,
|
||||
"time_elapsed": time_elapsed,
|
||||
},
|
||||
)
|
||||
|
||||
return dict(row._mapping) if row is not None else None
|
||||
|
||||
async def fetch_all(
|
||||
self,
|
||||
query: MySQLQuery,
|
||||
params: MySQLParams = None,
|
||||
) -> list[MySQLRow]:
|
||||
if isinstance(query, ClauseElement):
|
||||
query, params = self._compile(query)
|
||||
|
||||
with Timer() as timer:
|
||||
rows = await self._database.fetch_all(query, params)
|
||||
|
||||
if settings.DEBUG:
|
||||
time_elapsed = timer.elapsed()
|
||||
log(
|
||||
f"Executed SQL query: {query} {params} in {time_elapsed * 1000:.2f} msec.",
|
||||
extra={
|
||||
"query": query,
|
||||
"params": params,
|
||||
"time_elapsed": time_elapsed,
|
||||
},
|
||||
)
|
||||
|
||||
return [dict(row._mapping) for row in rows]
|
||||
|
||||
async def fetch_val(
|
||||
self,
|
||||
query: MySQLQuery,
|
||||
params: MySQLParams = None,
|
||||
column: Any = 0,
|
||||
) -> Any:
|
||||
if isinstance(query, ClauseElement):
|
||||
query, params = self._compile(query)
|
||||
|
||||
with Timer() as timer:
|
||||
val = await self._database.fetch_val(query, params, column)
|
||||
|
||||
if settings.DEBUG:
|
||||
time_elapsed = timer.elapsed()
|
||||
log(
|
||||
f"Executed SQL query: {query} {params} in {time_elapsed * 1000:.2f} msec.",
|
||||
extra={
|
||||
"query": query,
|
||||
"params": params,
|
||||
"time_elapsed": time_elapsed,
|
||||
},
|
||||
)
|
||||
|
||||
return val
|
||||
|
||||
async def execute(self, query: MySQLQuery, params: MySQLParams = None) -> int:
|
||||
if isinstance(query, ClauseElement):
|
||||
query, params = self._compile(query)
|
||||
|
||||
with Timer() as timer:
|
||||
rec_id = await self._database.execute(query, params)
|
||||
|
||||
if settings.DEBUG:
|
||||
time_elapsed = timer.elapsed()
|
||||
log(
|
||||
f"Executed SQL query: {query} {params} in {time_elapsed * 1000:.2f} msec.",
|
||||
extra={
|
||||
"query": query,
|
||||
"params": params,
|
||||
"time_elapsed": time_elapsed,
|
||||
},
|
||||
)
|
||||
|
||||
return cast(int, rec_id)
|
||||
|
||||
# NOTE: this accepts str since current execute_many uses are not using alchemy.
|
||||
# alchemy does execute_many in a single query so this method will be unneeded once raw SQL is not in use.
|
||||
async def execute_many(self, query: str, params: list[MySQLParams]) -> None:
|
||||
if isinstance(query, ClauseElement):
|
||||
query, _ = self._compile(query)
|
||||
|
||||
with Timer() as timer:
|
||||
await self._database.execute_many(query, params)
|
||||
|
||||
if settings.DEBUG:
|
||||
time_elapsed = timer.elapsed()
|
||||
log(
|
||||
f"Executed SQL query: {query} {params} in {time_elapsed * 1000:.2f} msec.",
|
||||
extra={
|
||||
"query": query,
|
||||
"params": params,
|
||||
"time_elapsed": time_elapsed,
|
||||
},
|
||||
)
|
||||
|
||||
def transaction(
|
||||
self,
|
||||
*,
|
||||
force_rollback: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Transaction:
|
||||
return self._database.transaction(force_rollback=force_rollback, **kwargs)
|
16
app/api/__init__.py
Normal file
16
app/api/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# type: ignore
|
||||
# isort: dont-add-imports
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .v1 import apiv1_router
|
||||
from .v2 import apiv2_router
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
api_router.include_router(apiv1_router)
|
||||
api_router.include_router(apiv2_router)
|
||||
|
||||
from . import domains
|
||||
from . import init_api
|
||||
from . import middlewares
|
5
app/api/domains/__init__.py
Normal file
5
app/api/domains/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# isort: dont-add-imports
|
||||
|
||||
from . import cho
|
||||
from . import map
|
||||
from . import osu
|
2227
app/api/domains/cho.py
Normal file
2227
app/api/domains/cho.py
Normal file
File diff suppressed because it is too large
Load Diff
22
app/api/domains/map.py
Normal file
22
app/api/domains/map.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""bmap: static beatmap info (thumbnails, previews, etc.)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import status
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
# import app.settings
|
||||
|
||||
router = APIRouter(tags=["Beatmaps"])
|
||||
|
||||
|
||||
# forward any unmatched request to osu!
|
||||
# eventually if we do bmap submission, we'll need this.
|
||||
@router.get("/{file_path:path}")
|
||||
async def everything(request: Request) -> RedirectResponse:
|
||||
return RedirectResponse(
|
||||
url=f"https://b.ppy.sh{request['path']}",
|
||||
status_code=status.HTTP_301_MOVED_PERMANENTLY,
|
||||
)
|
1786
app/api/domains/osu.py
Normal file
1786
app/api/domains/osu.py
Normal file
File diff suppressed because it is too large
Load Diff
196
app/api/init_api.py
Normal file
196
app/api/init_api.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# #!/usr/bin/env python3.11
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import pprint
|
||||
import sys
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
import starlette.routing
|
||||
from fastapi import FastAPI
|
||||
from fastapi import status
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import ORJSONResponse
|
||||
from fastapi.responses import Response
|
||||
from starlette.middleware.base import RequestResponseEndpoint
|
||||
from starlette.requests import ClientDisconnect
|
||||
|
||||
import app.bg_loops
|
||||
import app.settings
|
||||
import app.state
|
||||
import app.utils
|
||||
from app.api import api_router # type: ignore[attr-defined]
|
||||
from app.api import domains
|
||||
from app.api import middlewares
|
||||
from app.logging import Ansi
|
||||
from app.logging import log
|
||||
from app.objects import collections
|
||||
|
||||
|
||||
class BanchoAPI(FastAPI):
|
||||
def openapi(self) -> dict[str, Any]:
|
||||
if not self.openapi_schema:
|
||||
routes = self.routes
|
||||
starlette_hosts = [
|
||||
host
|
||||
for host in super().routes
|
||||
if isinstance(host, starlette.routing.Host)
|
||||
]
|
||||
|
||||
# XXX:HACK fastapi will not show documentation for routes
|
||||
# added through use sub applications using the Host class
|
||||
# (e.g. app.host('other.domain', app2))
|
||||
for host in starlette_hosts:
|
||||
for route in host.routes:
|
||||
if route not in routes:
|
||||
routes.append(route)
|
||||
|
||||
self.openapi_schema = get_openapi(
|
||||
title=self.title,
|
||||
version=self.version,
|
||||
openapi_version=self.openapi_version,
|
||||
description=self.description,
|
||||
terms_of_service=self.terms_of_service,
|
||||
contact=self.contact,
|
||||
license_info=self.license_info,
|
||||
routes=routes,
|
||||
tags=self.openapi_tags,
|
||||
servers=self.servers,
|
||||
)
|
||||
|
||||
return self.openapi_schema
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(asgi_app: BanchoAPI) -> AsyncIterator[None]:
|
||||
if isinstance(sys.stdout, io.TextIOWrapper):
|
||||
sys.stdout.reconfigure(encoding="utf-8")
|
||||
|
||||
app.utils.ensure_persistent_volumes_are_available()
|
||||
|
||||
app.state.loop = asyncio.get_running_loop()
|
||||
|
||||
if app.utils.is_running_as_admin():
|
||||
log(
|
||||
"Running the server with root privileges is not recommended.",
|
||||
Ansi.LYELLOW,
|
||||
)
|
||||
|
||||
await app.state.services.database.connect()
|
||||
await app.state.services.redis.initialize()
|
||||
|
||||
if app.state.services.datadog is not None:
|
||||
app.state.services.datadog.start(
|
||||
flush_in_thread=True,
|
||||
flush_interval=15,
|
||||
)
|
||||
app.state.services.datadog.gauge("bancho.online_players", 0)
|
||||
|
||||
app.state.services.ip_resolver = app.state.services.IPResolver()
|
||||
|
||||
await app.state.services.run_sql_migrations()
|
||||
|
||||
await collections.initialize_ram_caches()
|
||||
|
||||
await app.bg_loops.initialize_housekeeping_tasks()
|
||||
|
||||
log("Startup process complete.", Ansi.LGREEN)
|
||||
log(
|
||||
f"Listening @ {app.settings.APP_HOST}:{app.settings.APP_PORT}",
|
||||
Ansi.LMAGENTA,
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
# we want to attempt to gracefully finish any ongoing connections
|
||||
# and shut down any of the housekeeping tasks running in the background.
|
||||
await app.state.sessions.cancel_housekeeping_tasks()
|
||||
|
||||
# shutdown services
|
||||
|
||||
await app.state.services.http_client.aclose()
|
||||
await app.state.services.database.disconnect()
|
||||
await app.state.services.redis.aclose()
|
||||
|
||||
if app.state.services.datadog is not None:
|
||||
app.state.services.datadog.stop()
|
||||
app.state.services.datadog.flush()
|
||||
|
||||
|
||||
def init_exception_handlers(asgi_app: BanchoAPI) -> None:
|
||||
@asgi_app.exception_handler(RequestValidationError)
|
||||
async def handle_validation_error(
|
||||
request: Request,
|
||||
exc: RequestValidationError,
|
||||
) -> Response:
|
||||
"""Wrapper around 422 validation errors to print out info for devs."""
|
||||
log(f"Validation error on {request.url}", Ansi.LRED)
|
||||
pprint.pprint(exc.errors())
|
||||
|
||||
return ORJSONResponse(
|
||||
content={"detail": jsonable_encoder(exc.errors())},
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
)
|
||||
|
||||
|
||||
def init_middlewares(asgi_app: BanchoAPI) -> None:
|
||||
"""Initialize our app's middleware stack."""
|
||||
asgi_app.add_middleware(middlewares.MetricsMiddleware)
|
||||
|
||||
@asgi_app.middleware("http")
|
||||
async def http_middleware(
|
||||
request: Request,
|
||||
call_next: RequestResponseEndpoint,
|
||||
) -> Response:
|
||||
# if an osu! client is waiting on leaderboard data
|
||||
# and switches to another leaderboard, it will cancel
|
||||
# the previous request midway, resulting in a large
|
||||
# error in the console. this is to catch that :)
|
||||
|
||||
try:
|
||||
return await call_next(request)
|
||||
except ClientDisconnect:
|
||||
# client disconnected from the server
|
||||
# while we were reading the body.
|
||||
return Response("Client disconnected while reading request.")
|
||||
except RuntimeError as exc:
|
||||
if exc.args[0] == "No response returned.":
|
||||
# client disconnected from the server
|
||||
# while we were sending the response.
|
||||
return Response("Client returned empty response.")
|
||||
|
||||
# unrelated issue, raise normally
|
||||
raise exc
|
||||
|
||||
|
||||
def init_routes(asgi_app: BanchoAPI) -> None:
|
||||
"""Initialize our app's route endpoints."""
|
||||
for domain in ("ppy.sh", app.settings.DOMAIN):
|
||||
for subdomain in ("c", "ce", "c4", "c5", "c6"):
|
||||
asgi_app.host(f"{subdomain}.{domain}", domains.cho.router)
|
||||
|
||||
asgi_app.host(f"osu.{domain}", domains.osu.router)
|
||||
asgi_app.host(f"b.{domain}", domains.map.router)
|
||||
|
||||
# bancho.py's developer-facing api
|
||||
asgi_app.host(f"api.{domain}", api_router)
|
||||
|
||||
|
||||
def init_api() -> BanchoAPI:
|
||||
"""Create & initialize our app."""
|
||||
asgi_app = BanchoAPI(lifespan=lifespan)
|
||||
|
||||
init_middlewares(asgi_app)
|
||||
init_exception_handlers(asgi_app)
|
||||
init_routes(asgi_app)
|
||||
|
||||
return asgi_app
|
||||
|
||||
|
||||
asgi_app = init_api()
|
37
app/api/middlewares.py
Normal file
37
app/api/middlewares.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.base import RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from app.logging import Ansi
|
||||
from app.logging import log
|
||||
from app.logging import magnitude_fmt_time
|
||||
|
||||
|
||||
class MetricsMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: RequestResponseEndpoint,
|
||||
) -> Response:
|
||||
start_time = time.perf_counter_ns()
|
||||
response = await call_next(request)
|
||||
end_time = time.perf_counter_ns()
|
||||
|
||||
time_elapsed = end_time - start_time
|
||||
|
||||
col = Ansi.LGREEN if response.status_code < 400 else Ansi.LRED
|
||||
|
||||
url = f"{request.headers['host']}{request['path']}"
|
||||
|
||||
log(
|
||||
f"[{request.method}] {response.status_code} {url}{Ansi.RESET!r} | {Ansi.LBLUE!r}Request took: {magnitude_fmt_time(time_elapsed)}",
|
||||
col,
|
||||
)
|
||||
|
||||
response.headers["process-time"] = str(round(time_elapsed) / 1e6)
|
||||
return response
|
10
app/api/v1/__init__.py
Normal file
10
app/api/v1/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# type: ignore
|
||||
# isort: dont-add-imports
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .api import router
|
||||
|
||||
apiv1_router = APIRouter(tags=["API v1"], prefix="/v1")
|
||||
|
||||
apiv1_router.include_router(router)
|
1080
app/api/v1/api.py
Normal file
1080
app/api/v1/api.py
Normal file
File diff suppressed because it is too large
Load Diff
15
app/api/v2/__init__.py
Normal file
15
app/api/v2/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# isort: dont-add-imports
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from . import clans
|
||||
from . import maps
|
||||
from . import players
|
||||
from . import scores
|
||||
|
||||
apiv2_router = APIRouter(tags=["API v2"], prefix="/v2")
|
||||
|
||||
apiv2_router.include_router(clans.router)
|
||||
apiv2_router.include_router(maps.router)
|
||||
apiv2_router.include_router(players.router)
|
||||
apiv2_router.include_router(scores.router)
|
50
app/api/v2/clans.py
Normal file
50
app/api/v2/clans.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""bancho.py's v2 apis for interacting with clans"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import status
|
||||
from fastapi.param_functions import Query
|
||||
|
||||
from app.api.v2.common import responses
|
||||
from app.api.v2.common.responses import Failure
|
||||
from app.api.v2.common.responses import Success
|
||||
from app.api.v2.models.clans import Clan
|
||||
from app.repositories import clans as clans_repo
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/clans")
|
||||
async def get_clans(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=100),
|
||||
) -> Success[list[Clan]] | Failure:
|
||||
clans = await clans_repo.fetch_many(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
total_clans = await clans_repo.fetch_count()
|
||||
|
||||
response = [Clan.from_mapping(rec) for rec in clans]
|
||||
return responses.success(
|
||||
content=response,
|
||||
meta={
|
||||
"total": total_clans,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/clans/{clan_id}")
|
||||
async def get_clan(clan_id: int) -> Success[Clan] | Failure:
|
||||
data = await clans_repo.fetch_one(id=clan_id)
|
||||
if data is None:
|
||||
return responses.failure(
|
||||
message="Clan not found.",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
response = Clan.from_mapping(data)
|
||||
return responses.success(response)
|
29
app/api/v2/common/json.py
Normal file
29
app/api/v2/common/json.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def _default_processor(data: Any) -> Any:
|
||||
if isinstance(data, BaseModel):
|
||||
return _default_processor(data.dict())
|
||||
elif isinstance(data, dict):
|
||||
return {k: _default_processor(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [_default_processor(v) for v in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def dumps(data: Any) -> bytes:
|
||||
return orjson.dumps(data, default=_default_processor)
|
||||
|
||||
|
||||
class ORJSONResponse(JSONResponse):
|
||||
media_type = "application/json"
|
||||
|
||||
def render(self, content: Any) -> bytes:
|
||||
return dumps(content)
|
47
app/api/v2/common/responses.py
Normal file
47
app/api/v2/common/responses.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import Literal
|
||||
from typing import TypeVar
|
||||
from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.v2.common import json
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Success(BaseModel, Generic[T]):
|
||||
status: Literal["success"]
|
||||
data: T
|
||||
meta: dict[str, Any]
|
||||
|
||||
|
||||
def success(
|
||||
content: T,
|
||||
status_code: int = 200,
|
||||
headers: dict[str, Any] | None = None,
|
||||
meta: dict[str, Any] | None = None,
|
||||
) -> Success[T]:
|
||||
if meta is None:
|
||||
meta = {}
|
||||
data = {"status": "success", "data": content, "meta": meta}
|
||||
# XXX:HACK to make typing work
|
||||
return cast(Success[T], json.ORJSONResponse(data, status_code, headers))
|
||||
|
||||
|
||||
class Failure(BaseModel):
|
||||
status: Literal["error"]
|
||||
error: str
|
||||
|
||||
|
||||
def failure(
|
||||
message: str,
|
||||
status_code: int = 400,
|
||||
headers: dict[str, Any] | None = None,
|
||||
) -> Failure:
|
||||
data = {"status": "error", "error": message}
|
||||
# XXX:HACK to make typing work
|
||||
return cast(Failure, json.ORJSONResponse(data, status_code, headers))
|
76
app/api/v2/maps.py
Normal file
76
app/api/v2/maps.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""bancho.py's v2 apis for interacting with maps"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import status
|
||||
from fastapi.param_functions import Query
|
||||
|
||||
from app.api.v2.common import responses
|
||||
from app.api.v2.common.responses import Failure
|
||||
from app.api.v2.common.responses import Success
|
||||
from app.api.v2.models.maps import Map
|
||||
from app.repositories import maps as maps_repo
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/maps")
|
||||
async def get_maps(
|
||||
set_id: int | None = None,
|
||||
server: str | None = None,
|
||||
status: int | None = None,
|
||||
artist: str | None = None,
|
||||
creator: str | None = None,
|
||||
filename: str | None = None,
|
||||
mode: int | None = None,
|
||||
frozen: bool | None = None,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=100),
|
||||
) -> Success[list[Map]] | Failure:
|
||||
maps = await maps_repo.fetch_many(
|
||||
server=server,
|
||||
set_id=set_id,
|
||||
status=status,
|
||||
artist=artist,
|
||||
creator=creator,
|
||||
filename=filename,
|
||||
mode=mode,
|
||||
frozen=frozen,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
total_maps = await maps_repo.fetch_count(
|
||||
server=server,
|
||||
set_id=set_id,
|
||||
status=status,
|
||||
artist=artist,
|
||||
creator=creator,
|
||||
filename=filename,
|
||||
mode=mode,
|
||||
frozen=frozen,
|
||||
)
|
||||
|
||||
response = [Map.from_mapping(rec) for rec in maps]
|
||||
|
||||
return responses.success(
|
||||
content=response,
|
||||
meta={
|
||||
"total": total_maps,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/maps/{map_id}")
|
||||
async def get_map(map_id: int) -> Success[Map] | Failure:
|
||||
data = await maps_repo.fetch_one(id=map_id)
|
||||
if data is None:
|
||||
return responses.failure(
|
||||
message="Map not found.",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
response = Map.from_mapping(data)
|
||||
return responses.success(response)
|
18
app/api/v2/models/__init__.py
Normal file
18
app/api/v2/models/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# isort: dont-add-imports
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel as _pydantic_BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
T = TypeVar("T", bound="BaseModel")
|
||||
|
||||
|
||||
class BaseModel(_pydantic_BaseModel):
|
||||
model_config = ConfigDict(str_strip_whitespace=True)
|
||||
|
||||
@classmethod
|
||||
def from_mapping(cls: type[T], mapping: Mapping[str, Any]) -> T:
|
||||
return cls(**{k: mapping[k] for k in cls.model_fields})
|
18
app/api/v2/models/clans.py
Normal file
18
app/api/v2/models/clans.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from . import BaseModel
|
||||
|
||||
# input models
|
||||
|
||||
|
||||
# output models
|
||||
|
||||
|
||||
class Clan(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
tag: str
|
||||
owner: int
|
||||
created_at: datetime
|
36
app/api/v2/models/maps.py
Normal file
36
app/api/v2/models/maps.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from . import BaseModel
|
||||
|
||||
# input models
|
||||
|
||||
|
||||
# output models
|
||||
|
||||
|
||||
class Map(BaseModel):
|
||||
id: int
|
||||
server: str
|
||||
set_id: int
|
||||
status: int
|
||||
md5: str
|
||||
artist: str
|
||||
title: str
|
||||
version: str
|
||||
creator: str
|
||||
filename: str
|
||||
last_update: datetime
|
||||
total_length: int
|
||||
max_combo: int
|
||||
frozen: bool
|
||||
plays: int
|
||||
passes: int
|
||||
mode: int
|
||||
bpm: float
|
||||
cs: float
|
||||
ar: float
|
||||
od: float
|
||||
hp: float
|
||||
diff: float
|
60
app/api/v2/models/players.py
Normal file
60
app/api/v2/models/players.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from . import BaseModel
|
||||
|
||||
# input models
|
||||
|
||||
|
||||
# output models
|
||||
|
||||
|
||||
class Player(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
safe_name: str
|
||||
|
||||
priv: int
|
||||
country: str
|
||||
silence_end: int
|
||||
donor_end: int
|
||||
creation_time: int
|
||||
latest_activity: int
|
||||
|
||||
clan_id: int
|
||||
clan_priv: int
|
||||
|
||||
preferred_mode: int
|
||||
play_style: int
|
||||
|
||||
custom_badge_name: str | None
|
||||
custom_badge_icon: str | None
|
||||
|
||||
userpage_content: str | None
|
||||
|
||||
|
||||
class PlayerStatus(BaseModel):
|
||||
login_time: int
|
||||
action: int
|
||||
info_text: str
|
||||
mode: int
|
||||
mods: int
|
||||
beatmap_id: int
|
||||
|
||||
|
||||
class PlayerStats(BaseModel):
|
||||
id: int
|
||||
mode: int
|
||||
tscore: int
|
||||
rscore: int
|
||||
pp: float
|
||||
plays: int
|
||||
playtime: int
|
||||
acc: float
|
||||
max_combo: int
|
||||
total_hits: int
|
||||
replay_views: int
|
||||
xh_count: int
|
||||
x_count: int
|
||||
sh_count: int
|
||||
s_count: int
|
||||
a_count: int
|
36
app/api/v2/models/scores.py
Normal file
36
app/api/v2/models/scores.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from . import BaseModel
|
||||
|
||||
# input models
|
||||
|
||||
|
||||
# output models
|
||||
|
||||
|
||||
class Score(BaseModel):
|
||||
id: int
|
||||
map_md5: str
|
||||
userid: int
|
||||
|
||||
score: int
|
||||
pp: float
|
||||
acc: float
|
||||
max_combo: int
|
||||
mods: int
|
||||
|
||||
n300: int
|
||||
n100: int
|
||||
n50: int
|
||||
nmiss: int
|
||||
nkatu: int
|
||||
|
||||
grade: str
|
||||
status: int
|
||||
mode: int
|
||||
|
||||
play_time: datetime
|
||||
time_elapsed: int
|
||||
perfect: bool
|
137
app/api/v2/players.py
Normal file
137
app/api/v2/players.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""bancho.py's v2 apis for interacting with players"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import status
|
||||
from fastapi.param_functions import Query
|
||||
|
||||
import app.state.sessions
|
||||
from app.api.v2.common import responses
|
||||
from app.api.v2.common.responses import Failure
|
||||
from app.api.v2.common.responses import Success
|
||||
from app.api.v2.models.players import Player
|
||||
from app.api.v2.models.players import PlayerStats
|
||||
from app.api.v2.models.players import PlayerStatus
|
||||
from app.repositories import stats as stats_repo
|
||||
from app.repositories import users as users_repo
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/players")
|
||||
async def get_players(
|
||||
priv: int | None = None,
|
||||
country: str | None = None,
|
||||
clan_id: int | None = None,
|
||||
clan_priv: int | None = None,
|
||||
preferred_mode: int | None = None,
|
||||
play_style: int | None = None,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=100),
|
||||
) -> Success[list[Player]] | Failure:
|
||||
players = await users_repo.fetch_many(
|
||||
priv=priv,
|
||||
country=country,
|
||||
clan_id=clan_id,
|
||||
clan_priv=clan_priv,
|
||||
preferred_mode=preferred_mode,
|
||||
play_style=play_style,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
total_players = await users_repo.fetch_count(
|
||||
priv=priv,
|
||||
country=country,
|
||||
clan_id=clan_id,
|
||||
clan_priv=clan_priv,
|
||||
preferred_mode=preferred_mode,
|
||||
play_style=play_style,
|
||||
)
|
||||
|
||||
response = [Player.from_mapping(rec) for rec in players]
|
||||
|
||||
return responses.success(
|
||||
content=response,
|
||||
meta={
|
||||
"total": total_players,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/players/{player_id}")
|
||||
async def get_player(player_id: int) -> Success[Player] | Failure:
|
||||
data = await users_repo.fetch_one(id=player_id)
|
||||
if data is None:
|
||||
return responses.failure(
|
||||
message="Player not found.",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
response = Player.from_mapping(data)
|
||||
return responses.success(response)
|
||||
|
||||
|
||||
@router.get("/players/{player_id}/status")
|
||||
async def get_player_status(player_id: int) -> Success[PlayerStatus] | Failure:
|
||||
player = app.state.sessions.players.get(id=player_id)
|
||||
|
||||
if not player:
|
||||
return responses.failure(
|
||||
message="Player status not found.",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
response = PlayerStatus(
|
||||
login_time=int(player.login_time),
|
||||
action=int(player.status.action),
|
||||
info_text=player.status.info_text,
|
||||
mode=int(player.status.mode),
|
||||
mods=int(player.status.mods),
|
||||
beatmap_id=player.status.map_id,
|
||||
)
|
||||
return responses.success(response)
|
||||
|
||||
|
||||
@router.get("/players/{player_id}/stats/{mode}")
|
||||
async def get_player_mode_stats(
|
||||
player_id: int,
|
||||
mode: int,
|
||||
) -> Success[PlayerStats] | Failure:
|
||||
data = await stats_repo.fetch_one(player_id, mode)
|
||||
if data is None:
|
||||
return responses.failure(
|
||||
message="Player stats not found.",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
response = PlayerStats.from_mapping(data)
|
||||
return responses.success(response)
|
||||
|
||||
|
||||
@router.get("/players/{player_id}/stats")
|
||||
async def get_player_stats(
|
||||
player_id: int,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=100),
|
||||
) -> Success[list[PlayerStats]] | Failure:
|
||||
data = await stats_repo.fetch_many(
|
||||
player_id=player_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
total_stats = await stats_repo.fetch_count(
|
||||
player_id=player_id,
|
||||
)
|
||||
|
||||
response = [PlayerStats.from_mapping(rec) for rec in data]
|
||||
return responses.success(
|
||||
response,
|
||||
meta={
|
||||
"total": total_stats,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
)
|
67
app/api/v2/scores.py
Normal file
67
app/api/v2/scores.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""bancho.py's v2 apis for interacting with scores"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import status
|
||||
from fastapi.param_functions import Query
|
||||
|
||||
from app.api.v2.common import responses
|
||||
from app.api.v2.common.responses import Failure
|
||||
from app.api.v2.common.responses import Success
|
||||
from app.api.v2.models.scores import Score
|
||||
from app.repositories import scores as scores_repo
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/scores")
|
||||
async def get_all_scores(
|
||||
map_md5: str | None = None,
|
||||
mods: int | None = None,
|
||||
status: int | None = None,
|
||||
mode: int | None = None,
|
||||
user_id: int | None = None,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=100),
|
||||
) -> Success[list[Score]] | Failure:
|
||||
scores = await scores_repo.fetch_many(
|
||||
map_md5=map_md5,
|
||||
mods=mods,
|
||||
status=status,
|
||||
mode=mode,
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
total_scores = await scores_repo.fetch_count(
|
||||
map_md5=map_md5,
|
||||
mods=mods,
|
||||
status=status,
|
||||
mode=mode,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
response = [Score.from_mapping(rec) for rec in scores]
|
||||
|
||||
return responses.success(
|
||||
content=response,
|
||||
meta={
|
||||
"total": total_scores,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/scores/{score_id}")
|
||||
async def get_score(score_id: int) -> Success[Score] | Failure:
|
||||
data = await scores_repo.fetch_one(id=score_id)
|
||||
if data is None:
|
||||
return responses.failure(
|
||||
message="Score not found.",
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
response = Score.from_mapping(data)
|
||||
return responses.success(response)
|
89
app/bg_loops.py
Normal file
89
app/bg_loops.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import app.packets
|
||||
import app.settings
|
||||
import app.state
|
||||
from app.constants.privileges import Privileges
|
||||
from app.logging import Ansi
|
||||
from app.logging import log
|
||||
|
||||
OSU_CLIENT_MIN_PING_INTERVAL = 300000 // 1000 # defined by osu!
|
||||
|
||||
|
||||
async def initialize_housekeeping_tasks() -> None:
|
||||
"""Create tasks for each housekeeping tasks."""
|
||||
log("Initializing housekeeping tasks.", Ansi.LCYAN)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
app.state.sessions.housekeeping_tasks.update(
|
||||
{
|
||||
loop.create_task(task)
|
||||
for task in (
|
||||
_remove_expired_donation_privileges(interval=30 * 60),
|
||||
_update_bot_status(interval=5 * 60),
|
||||
_disconnect_ghosts(interval=OSU_CLIENT_MIN_PING_INTERVAL // 3),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _remove_expired_donation_privileges(interval: int) -> None:
|
||||
"""Remove donation privileges from users with expired sessions."""
|
||||
while True:
|
||||
if app.settings.DEBUG:
|
||||
log("Removing expired donation privileges.", Ansi.LMAGENTA)
|
||||
|
||||
expired_donors = await app.state.services.database.fetch_all(
|
||||
"SELECT id FROM users "
|
||||
"WHERE donor_end <= UNIX_TIMESTAMP() "
|
||||
"AND priv & :donor_priv",
|
||||
{"donor_priv": Privileges.DONATOR.value},
|
||||
)
|
||||
|
||||
for expired_donor in expired_donors:
|
||||
player = await app.state.sessions.players.from_cache_or_sql(
|
||||
id=expired_donor["id"],
|
||||
)
|
||||
|
||||
assert player is not None
|
||||
|
||||
# TODO: perhaps make a `revoke_donor` method?
|
||||
await player.remove_privs(Privileges.DONATOR)
|
||||
player.donor_end = 0
|
||||
await app.state.services.database.execute(
|
||||
"UPDATE users SET donor_end = 0 WHERE id = :id",
|
||||
{"id": player.id},
|
||||
)
|
||||
|
||||
if player.is_online:
|
||||
player.enqueue(
|
||||
app.packets.notification("Your supporter status has expired."),
|
||||
)
|
||||
|
||||
log(f"{player}'s supporter status has expired.", Ansi.LMAGENTA)
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
async def _disconnect_ghosts(interval: int) -> None:
|
||||
"""Actively disconnect users above the
|
||||
disconnection time threshold on the osu! server."""
|
||||
while True:
|
||||
await asyncio.sleep(interval)
|
||||
current_time = time.time()
|
||||
|
||||
for player in app.state.sessions.players:
|
||||
if current_time - player.last_recv_time > OSU_CLIENT_MIN_PING_INTERVAL:
|
||||
log(f"Auto-dced {player}.", Ansi.LMAGENTA)
|
||||
player.logout()
|
||||
|
||||
|
||||
async def _update_bot_status(interval: int) -> None:
|
||||
"""Re roll the bot status, every `interval`."""
|
||||
while True:
|
||||
await asyncio.sleep(interval)
|
||||
app.packets.bot_stats.cache_clear()
|
2533
app/commands.py
Normal file
2533
app/commands.py
Normal file
File diff suppressed because it is too large
Load Diff
8
app/constants/__init__.py
Normal file
8
app/constants/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# type: ignore
|
||||
# isort: dont-add-imports
|
||||
|
||||
from . import clientflags
|
||||
from . import gamemodes
|
||||
from . import mods
|
||||
from . import privileges
|
||||
from . import regexes
|
68
app/constants/clientflags.py
Normal file
68
app/constants/clientflags.py
Normal file
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntFlag
|
||||
from enum import unique
|
||||
|
||||
from app.utils import escape_enum
|
||||
from app.utils import pymysql_encode
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class ClientFlags(IntFlag):
|
||||
"""osu! anticheat <= 2016 (unsure of age)"""
|
||||
|
||||
# NOTE: many of these flags are quite outdated and/or
|
||||
# broken and are even known to false positive quite often.
|
||||
# they can be helpful; just take them with a grain of salt.
|
||||
|
||||
CLEAN = 0 # no flags sent
|
||||
|
||||
# flags for timing errors or desync.
|
||||
SPEED_HACK_DETECTED = 1 << 1
|
||||
|
||||
# this is to be ignored by server implementations. osu! team trolling hard
|
||||
INCORRECT_MOD_VALUE = 1 << 2
|
||||
|
||||
MULTIPLE_OSU_CLIENTS = 1 << 3
|
||||
CHECKSUM_FAILURE = 1 << 4
|
||||
FLASHLIGHT_CHECKSUM_INCORRECT = 1 << 5
|
||||
|
||||
# these are only used on the osu!bancho official server.
|
||||
OSU_EXECUTABLE_CHECKSUM = 1 << 6
|
||||
MISSING_PROCESSES_IN_LIST = 1 << 7 # also deprecated as of 2018
|
||||
|
||||
# flags for either:
|
||||
# 1. pixels that should be outside the visible radius
|
||||
# (and thus black) being brighter than they should be.
|
||||
# 2. from an internal alpha value being incorrect.
|
||||
FLASHLIGHT_IMAGE_HACK = 1 << 8
|
||||
|
||||
SPINNER_HACK = 1 << 9
|
||||
TRANSPARENT_WINDOW = 1 << 10
|
||||
|
||||
# (mania) flags for consistently low press intervals.
|
||||
FAST_PRESS = 1 << 11
|
||||
|
||||
# from my experience, pretty decent
|
||||
# for detecting autobotted scores.
|
||||
RAW_MOUSE_DISCREPANCY = 1 << 12
|
||||
RAW_KEYBOARD_DISCREPANCY = 1 << 13
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class LastFMFlags(IntFlag):
|
||||
"""osu! anticheat 2019"""
|
||||
|
||||
# XXX: the aqn flags were fixed within hours of the osu!
|
||||
# update, and vanilla hq is not so widely used anymore.
|
||||
RUN_WITH_LD_FLAG = 1 << 14
|
||||
CONSOLE_OPEN = 1 << 15
|
||||
EXTRA_THREADS = 1 << 16
|
||||
HQ_ASSEMBLY = 1 << 17
|
||||
HQ_FILE = 1 << 18
|
||||
REGISTRY_EDITS = 1 << 19
|
||||
SDL2_LIBRARY = 1 << 20
|
||||
OPENSSL_LIBRARY = 1 << 21
|
||||
AQN_MENU_SAMPLE = 1 << 22
|
75
app/constants/gamemodes.py
Normal file
75
app/constants/gamemodes.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from enum import IntEnum
|
||||
from enum import unique
|
||||
|
||||
from app.constants.mods import Mods
|
||||
from app.utils import escape_enum
|
||||
from app.utils import pymysql_encode
|
||||
|
||||
GAMEMODE_REPR_LIST = (
|
||||
"vn!std",
|
||||
"vn!taiko",
|
||||
"vn!catch",
|
||||
"vn!mania",
|
||||
"rx!std",
|
||||
"rx!taiko",
|
||||
"rx!catch",
|
||||
"rx!mania", # unused
|
||||
"ap!std",
|
||||
"ap!taiko", # unused
|
||||
"ap!catch", # unused
|
||||
"ap!mania", # unused
|
||||
)
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class GameMode(IntEnum):
|
||||
VANILLA_OSU = 0
|
||||
VANILLA_TAIKO = 1
|
||||
VANILLA_CATCH = 2
|
||||
VANILLA_MANIA = 3
|
||||
|
||||
RELAX_OSU = 4
|
||||
RELAX_TAIKO = 5
|
||||
RELAX_CATCH = 6
|
||||
RELAX_MANIA = 7 # unused
|
||||
|
||||
AUTOPILOT_OSU = 8
|
||||
AUTOPILOT_TAIKO = 9 # unused
|
||||
AUTOPILOT_CATCH = 10 # unused
|
||||
AUTOPILOT_MANIA = 11 # unused
|
||||
|
||||
@classmethod
|
||||
def from_params(cls, mode_vn: int, mods: Mods) -> GameMode:
|
||||
mode = mode_vn
|
||||
|
||||
if mods & Mods.AUTOPILOT:
|
||||
mode += 8
|
||||
elif mods & Mods.RELAX:
|
||||
mode += 4
|
||||
|
||||
return cls(mode)
|
||||
|
||||
@classmethod
|
||||
@functools.cache
|
||||
def valid_gamemodes(cls) -> list[GameMode]:
|
||||
ret = []
|
||||
for mode in cls:
|
||||
if mode not in (
|
||||
cls.RELAX_MANIA,
|
||||
cls.AUTOPILOT_TAIKO,
|
||||
cls.AUTOPILOT_CATCH,
|
||||
cls.AUTOPILOT_MANIA,
|
||||
):
|
||||
ret.append(mode)
|
||||
return ret
|
||||
|
||||
@property
|
||||
def as_vanilla(self) -> int:
|
||||
return self.value % 4
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return GAMEMODE_REPR_LIST[self.value]
|
296
app/constants/mods.py
Normal file
296
app/constants/mods.py
Normal file
@@ -0,0 +1,296 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from enum import IntFlag
|
||||
from enum import unique
|
||||
|
||||
from app.utils import escape_enum
|
||||
from app.utils import pymysql_encode
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class Mods(IntFlag):
|
||||
NOMOD = 0
|
||||
NOFAIL = 1 << 0
|
||||
EASY = 1 << 1
|
||||
TOUCHSCREEN = 1 << 2 # old: 'NOVIDEO'
|
||||
HIDDEN = 1 << 3
|
||||
HARDROCK = 1 << 4
|
||||
SUDDENDEATH = 1 << 5
|
||||
DOUBLETIME = 1 << 6
|
||||
RELAX = 1 << 7
|
||||
HALFTIME = 1 << 8
|
||||
NIGHTCORE = 1 << 9
|
||||
FLASHLIGHT = 1 << 10
|
||||
AUTOPLAY = 1 << 11
|
||||
SPUNOUT = 1 << 12
|
||||
AUTOPILOT = 1 << 13
|
||||
PERFECT = 1 << 14
|
||||
KEY4 = 1 << 15
|
||||
KEY5 = 1 << 16
|
||||
KEY6 = 1 << 17
|
||||
KEY7 = 1 << 18
|
||||
KEY8 = 1 << 19
|
||||
FADEIN = 1 << 20
|
||||
RANDOM = 1 << 21
|
||||
CINEMA = 1 << 22
|
||||
TARGET = 1 << 23
|
||||
KEY9 = 1 << 24
|
||||
KEYCOOP = 1 << 25
|
||||
KEY1 = 1 << 26
|
||||
KEY3 = 1 << 27
|
||||
KEY2 = 1 << 28
|
||||
SCOREV2 = 1 << 29
|
||||
MIRROR = 1 << 30
|
||||
|
||||
@functools.cache
|
||||
def __repr__(self) -> str:
|
||||
if self.value == Mods.NOMOD:
|
||||
return "NM"
|
||||
|
||||
mod_str = []
|
||||
_dict = mod2modstr_dict # global
|
||||
|
||||
for mod in Mods:
|
||||
if self.value & mod:
|
||||
mod_str.append(_dict[mod])
|
||||
|
||||
return "".join(mod_str)
|
||||
|
||||
def filter_invalid_combos(self, mode_vn: int) -> Mods:
|
||||
"""Remove any invalid mod combinations."""
|
||||
|
||||
# 1. mode-inspecific mod conflictions
|
||||
_dtnc = self & (Mods.DOUBLETIME | Mods.NIGHTCORE)
|
||||
if _dtnc == (Mods.DOUBLETIME | Mods.NIGHTCORE):
|
||||
self &= ~Mods.DOUBLETIME # DTNC
|
||||
elif _dtnc and self & Mods.HALFTIME:
|
||||
self &= ~Mods.HALFTIME # (DT|NC)HT
|
||||
|
||||
if self & Mods.EASY and self & Mods.HARDROCK:
|
||||
self &= ~Mods.HARDROCK # EZHR
|
||||
|
||||
if self & (Mods.NOFAIL | Mods.RELAX | Mods.AUTOPILOT):
|
||||
if self & Mods.SUDDENDEATH:
|
||||
self &= ~Mods.SUDDENDEATH # (NF|RX|AP)SD
|
||||
if self & Mods.PERFECT:
|
||||
self &= ~Mods.PERFECT # (NF|RX|AP)PF
|
||||
|
||||
if self & (Mods.RELAX | Mods.AUTOPILOT):
|
||||
if self & Mods.NOFAIL:
|
||||
self &= ~Mods.NOFAIL # (RX|AP)NF
|
||||
|
||||
if self & Mods.PERFECT and self & Mods.SUDDENDEATH:
|
||||
self &= ~Mods.SUDDENDEATH # PFSD
|
||||
|
||||
# 2. remove mode-unique mods from incorrect gamemodes
|
||||
if mode_vn != 0: # osu! specific
|
||||
self &= ~OSU_SPECIFIC_MODS
|
||||
|
||||
# ctb & taiko have no unique mods
|
||||
|
||||
if mode_vn != 3: # mania specific
|
||||
self &= ~MANIA_SPECIFIC_MODS
|
||||
|
||||
# 3. mode-specific mod conflictions
|
||||
if mode_vn == 0:
|
||||
if self & Mods.AUTOPILOT:
|
||||
if self & (Mods.SPUNOUT | Mods.RELAX):
|
||||
self &= ~Mods.AUTOPILOT # (SO|RX)AP
|
||||
|
||||
if mode_vn == 3:
|
||||
self &= ~Mods.RELAX # rx is std/taiko/ctb common
|
||||
if self & Mods.HIDDEN and self & Mods.FADEIN:
|
||||
self &= ~Mods.FADEIN # HDFI
|
||||
|
||||
# 4 remove multiple keymods
|
||||
keymods_used = self & KEY_MODS
|
||||
|
||||
if bin(keymods_used).count("1") > 1:
|
||||
# keep only the first
|
||||
first_keymod = None
|
||||
for mod in KEY_MODS:
|
||||
if keymods_used & mod:
|
||||
first_keymod = mod
|
||||
break
|
||||
|
||||
assert first_keymod is not None
|
||||
|
||||
# remove all but the first keymod.
|
||||
self &= ~(keymods_used & ~first_keymod)
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=64)
|
||||
def from_modstr(cls, s: str) -> Mods:
|
||||
# from fmt: `HDDTRX`
|
||||
mods = cls.NOMOD
|
||||
_dict = modstr2mod_dict # global
|
||||
|
||||
# split into 2 character chunks
|
||||
mod_strs = [s[idx : idx + 2].upper() for idx in range(0, len(s), 2)]
|
||||
|
||||
# find matching mods
|
||||
for mod in mod_strs:
|
||||
if mod not in _dict:
|
||||
continue
|
||||
|
||||
mods |= _dict[mod]
|
||||
|
||||
return mods
|
||||
|
||||
@classmethod
|
||||
@functools.lru_cache(maxsize=64)
|
||||
def from_np(cls, s: str, mode_vn: int) -> Mods:
|
||||
mods = cls.NOMOD
|
||||
_dict = npstr2mod_dict # global
|
||||
|
||||
for mod in s.split(" "):
|
||||
if mod not in _dict:
|
||||
continue
|
||||
|
||||
mods |= _dict[mod]
|
||||
|
||||
# NOTE: for fetching from /np, we automatically
|
||||
# call cls.filter_invalid_combos as we assume
|
||||
# the input string is from user input.
|
||||
return mods.filter_invalid_combos(mode_vn)
|
||||
|
||||
|
||||
modstr2mod_dict = {
|
||||
"NF": Mods.NOFAIL,
|
||||
"EZ": Mods.EASY,
|
||||
"TD": Mods.TOUCHSCREEN,
|
||||
"HD": Mods.HIDDEN,
|
||||
"HR": Mods.HARDROCK,
|
||||
"SD": Mods.SUDDENDEATH,
|
||||
"DT": Mods.DOUBLETIME,
|
||||
"RX": Mods.RELAX,
|
||||
"HT": Mods.HALFTIME,
|
||||
"NC": Mods.NIGHTCORE,
|
||||
"FL": Mods.FLASHLIGHT,
|
||||
"AU": Mods.AUTOPLAY,
|
||||
"SO": Mods.SPUNOUT,
|
||||
"AP": Mods.AUTOPILOT,
|
||||
"PF": Mods.PERFECT,
|
||||
"FI": Mods.FADEIN,
|
||||
"RN": Mods.RANDOM,
|
||||
"CN": Mods.CINEMA,
|
||||
"TP": Mods.TARGET,
|
||||
"V2": Mods.SCOREV2,
|
||||
"MR": Mods.MIRROR,
|
||||
"1K": Mods.KEY1,
|
||||
"2K": Mods.KEY2,
|
||||
"3K": Mods.KEY3,
|
||||
"4K": Mods.KEY4,
|
||||
"5K": Mods.KEY5,
|
||||
"6K": Mods.KEY6,
|
||||
"7K": Mods.KEY7,
|
||||
"8K": Mods.KEY8,
|
||||
"9K": Mods.KEY9,
|
||||
"CO": Mods.KEYCOOP,
|
||||
}
|
||||
|
||||
npstr2mod_dict = {
|
||||
"-NoFail": Mods.NOFAIL,
|
||||
"-Easy": Mods.EASY,
|
||||
"+Hidden": Mods.HIDDEN,
|
||||
"+HardRock": Mods.HARDROCK,
|
||||
"+SuddenDeath": Mods.SUDDENDEATH,
|
||||
"+DoubleTime": Mods.DOUBLETIME,
|
||||
"~Relax~": Mods.RELAX,
|
||||
"-HalfTime": Mods.HALFTIME,
|
||||
"+Nightcore": Mods.NIGHTCORE,
|
||||
"+Flashlight": Mods.FLASHLIGHT,
|
||||
"|Autoplay|": Mods.AUTOPLAY,
|
||||
"-SpunOut": Mods.SPUNOUT,
|
||||
"~Autopilot~": Mods.AUTOPILOT,
|
||||
"+Perfect": Mods.PERFECT,
|
||||
"|Cinema|": Mods.CINEMA,
|
||||
"~Target~": Mods.TARGET,
|
||||
# perhaps could modify regex
|
||||
# to only allow these once,
|
||||
# and only at the end of str?
|
||||
"|1K|": Mods.KEY1,
|
||||
"|2K|": Mods.KEY2,
|
||||
"|3K|": Mods.KEY3,
|
||||
"|4K|": Mods.KEY4,
|
||||
"|5K|": Mods.KEY5,
|
||||
"|6K|": Mods.KEY6,
|
||||
"|7K|": Mods.KEY7,
|
||||
"|8K|": Mods.KEY8,
|
||||
"|9K|": Mods.KEY9,
|
||||
# XXX: kinda mood that there's no way
|
||||
# to tell K1-K4 co-op from /np, but
|
||||
# scores won't submit or anything, so
|
||||
# it's not ultimately a problem.
|
||||
"|10K|": Mods.KEY5 | Mods.KEYCOOP,
|
||||
"|12K|": Mods.KEY6 | Mods.KEYCOOP,
|
||||
"|14K|": Mods.KEY7 | Mods.KEYCOOP,
|
||||
"|16K|": Mods.KEY8 | Mods.KEYCOOP,
|
||||
"|18K|": Mods.KEY9 | Mods.KEYCOOP,
|
||||
}
|
||||
|
||||
mod2modstr_dict = {
|
||||
Mods.NOFAIL: "NF",
|
||||
Mods.EASY: "EZ",
|
||||
Mods.TOUCHSCREEN: "TD",
|
||||
Mods.HIDDEN: "HD",
|
||||
Mods.HARDROCK: "HR",
|
||||
Mods.SUDDENDEATH: "SD",
|
||||
Mods.DOUBLETIME: "DT",
|
||||
Mods.RELAX: "RX",
|
||||
Mods.HALFTIME: "HT",
|
||||
Mods.NIGHTCORE: "NC",
|
||||
Mods.FLASHLIGHT: "FL",
|
||||
Mods.AUTOPLAY: "AU",
|
||||
Mods.SPUNOUT: "SO",
|
||||
Mods.AUTOPILOT: "AP",
|
||||
Mods.PERFECT: "PF",
|
||||
Mods.FADEIN: "FI",
|
||||
Mods.RANDOM: "RN",
|
||||
Mods.CINEMA: "CN",
|
||||
Mods.TARGET: "TP",
|
||||
Mods.SCOREV2: "V2",
|
||||
Mods.MIRROR: "MR",
|
||||
Mods.KEY1: "1K",
|
||||
Mods.KEY2: "2K",
|
||||
Mods.KEY3: "3K",
|
||||
Mods.KEY4: "4K",
|
||||
Mods.KEY5: "5K",
|
||||
Mods.KEY6: "6K",
|
||||
Mods.KEY7: "7K",
|
||||
Mods.KEY8: "8K",
|
||||
Mods.KEY9: "9K",
|
||||
Mods.KEYCOOP: "CO",
|
||||
}
|
||||
|
||||
KEY_MODS = (
|
||||
Mods.KEY1
|
||||
| Mods.KEY2
|
||||
| Mods.KEY3
|
||||
| Mods.KEY4
|
||||
| Mods.KEY5
|
||||
| Mods.KEY6
|
||||
| Mods.KEY7
|
||||
| Mods.KEY8
|
||||
| Mods.KEY9
|
||||
)
|
||||
|
||||
# FREE_MOD_ALLOWED = (
|
||||
# Mods.NOFAIL | Mods.EASY | Mods.HIDDEN | Mods.HARDROCK |
|
||||
# Mods.SUDDENDEATH | Mods.FLASHLIGHT | Mods.FADEIN |
|
||||
# Mods.RELAX | Mods.AUTOPILOT | Mods.SPUNOUT | KEY_MODS
|
||||
# )
|
||||
|
||||
SCORE_INCREASE_MODS = (
|
||||
Mods.HIDDEN | Mods.HARDROCK | Mods.FADEIN | Mods.DOUBLETIME | Mods.FLASHLIGHT
|
||||
)
|
||||
|
||||
SPEED_CHANGING_MODS = Mods.DOUBLETIME | Mods.NIGHTCORE | Mods.HALFTIME
|
||||
|
||||
OSU_SPECIFIC_MODS = Mods.AUTOPILOT | Mods.SPUNOUT | Mods.TARGET
|
||||
# taiko & catch have no specific mods
|
||||
MANIA_SPECIFIC_MODS = Mods.MIRROR | Mods.RANDOM | Mods.FADEIN | KEY_MODS
|
61
app/constants/privileges.py
Normal file
61
app/constants/privileges.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from enum import IntFlag
|
||||
from enum import unique
|
||||
|
||||
from app.utils import escape_enum
|
||||
from app.utils import pymysql_encode
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class Privileges(IntFlag):
|
||||
"""Server side user privileges."""
|
||||
|
||||
# privileges intended for all normal players.
|
||||
UNRESTRICTED = 1 << 0 # is an unbanned player.
|
||||
VERIFIED = 1 << 1 # has logged in to the server in-game.
|
||||
|
||||
# has bypass to low-ceiling anticheat measures (trusted).
|
||||
WHITELISTED = 1 << 2
|
||||
|
||||
# donation tiers, receives some extra benefits.
|
||||
SUPPORTER = 1 << 4
|
||||
PREMIUM = 1 << 5
|
||||
|
||||
# notable users, receives some extra benefits.
|
||||
ALUMNI = 1 << 7
|
||||
|
||||
# staff permissions, able to manage server app.state.
|
||||
TOURNEY_MANAGER = 1 << 10 # able to manage match state without host.
|
||||
NOMINATOR = 1 << 11 # able to manage maps ranked status.
|
||||
MODERATOR = 1 << 12 # able to manage users (level 1).
|
||||
ADMINISTRATOR = 1 << 13 # able to manage users (level 2).
|
||||
DEVELOPER = 1 << 14 # able to manage full server app.state.
|
||||
|
||||
DONATOR = SUPPORTER | PREMIUM
|
||||
STAFF = MODERATOR | ADMINISTRATOR | DEVELOPER
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class ClientPrivileges(IntFlag):
|
||||
"""Client side user privileges."""
|
||||
|
||||
PLAYER = 1 << 0
|
||||
MODERATOR = 1 << 1
|
||||
SUPPORTER = 1 << 2
|
||||
OWNER = 1 << 3
|
||||
DEVELOPER = 1 << 4
|
||||
TOURNAMENT = 1 << 5 # NOTE: not used in communications with osu! client
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class ClanPrivileges(IntEnum):
|
||||
"""A class to represent a clan members privs."""
|
||||
|
||||
Member = 1
|
||||
Officer = 2
|
||||
Owner = 3
|
23
app/constants/regexes.py
Normal file
23
app/constants/regexes.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
OSU_VERSION = re.compile(
|
||||
r"^b(?P<date>\d{8})(?:\.(?P<revision>\d))?"
|
||||
r"(?P<stream>beta|cuttingedge|dev|tourney)?$",
|
||||
)
|
||||
|
||||
USERNAME = re.compile(r"^[\w \[\]-]{2,15}$")
|
||||
EMAIL = re.compile(r"^[^@\s]{1,200}@[^@\s\.]{1,30}(?:\.[^@\.\s]{2,24})+$")
|
||||
|
||||
TOURNEY_MATCHNAME = re.compile(
|
||||
r"^(?P<name>[a-zA-Z0-9_ ]+): "
|
||||
r"\((?P<T1>[a-zA-Z0-9_ ]+)\)"
|
||||
r" vs\.? "
|
||||
r"\((?P<T2>[a-zA-Z0-9_ ]+)\)$",
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
|
||||
MAPPOOL_PICK = re.compile(r"^([a-zA-Z]+)([0-9]+)$")
|
||||
|
||||
BEST_OF = re.compile(r"^(?:bo)?(\d{1,2})$")
|
173
app/discord.py
Normal file
173
app/discord.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Functionality related to Discord interactivity."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from tenacity import retry
|
||||
from tenacity import stop_after_attempt
|
||||
from tenacity import wait_exponential
|
||||
|
||||
from app.state import services
|
||||
|
||||
|
||||
class Footer:
|
||||
def __init__(self, text: str, **kwargs: Any) -> None:
|
||||
self.text = text
|
||||
self.icon_url = kwargs.get("icon_url")
|
||||
self.proxy_icon_url = kwargs.get("proxy_icon_url")
|
||||
|
||||
|
||||
class Image:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.url = kwargs.get("url")
|
||||
self.proxy_url = kwargs.get("proxy_url")
|
||||
self.height = kwargs.get("height")
|
||||
self.width = kwargs.get("width")
|
||||
|
||||
|
||||
class Thumbnail:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.url = kwargs.get("url")
|
||||
self.proxy_url = kwargs.get("proxy_url")
|
||||
self.height = kwargs.get("height")
|
||||
self.width = kwargs.get("width")
|
||||
|
||||
|
||||
class Video:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.url = kwargs.get("url")
|
||||
self.height = kwargs.get("height")
|
||||
self.width = kwargs.get("width")
|
||||
|
||||
|
||||
class Provider:
|
||||
def __init__(self, **kwargs: str) -> None:
|
||||
self.url = kwargs.get("url")
|
||||
self.name = kwargs.get("name")
|
||||
|
||||
|
||||
class Author:
|
||||
def __init__(self, **kwargs: str) -> None:
|
||||
self.name = kwargs.get("name")
|
||||
self.url = kwargs.get("url")
|
||||
self.icon_url = kwargs.get("icon_url")
|
||||
self.proxy_icon_url = kwargs.get("proxy_icon_url")
|
||||
|
||||
|
||||
class Field:
|
||||
def __init__(self, name: str, value: str, inline: bool = False) -> None:
|
||||
self.name = name
|
||||
self.value = value
|
||||
self.inline = inline
|
||||
|
||||
|
||||
class Embed:
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.title = kwargs.get("title")
|
||||
self.type = kwargs.get("type")
|
||||
self.description = kwargs.get("description")
|
||||
self.url = kwargs.get("url")
|
||||
self.timestamp = kwargs.get("timestamp") # datetime
|
||||
self.color = kwargs.get("color", 0x000000)
|
||||
|
||||
self.footer: Footer | None = kwargs.get("footer")
|
||||
self.image: Image | None = kwargs.get("image")
|
||||
self.thumbnail: Thumbnail | None = kwargs.get("thumbnail")
|
||||
self.video: Video | None = kwargs.get("video")
|
||||
self.provider: Provider | None = kwargs.get("provider")
|
||||
self.author: Author | None = kwargs.get("author")
|
||||
|
||||
self.fields: list[Field] = kwargs.get("fields", [])
|
||||
|
||||
def set_footer(self, **kwargs: Any) -> None:
|
||||
self.footer = Footer(**kwargs)
|
||||
|
||||
def set_image(self, **kwargs: Any) -> None:
|
||||
self.image = Image(**kwargs)
|
||||
|
||||
def set_thumbnail(self, **kwargs: Any) -> None:
|
||||
self.thumbnail = Thumbnail(**kwargs)
|
||||
|
||||
def set_video(self, **kwargs: Any) -> None:
|
||||
self.video = Video(**kwargs)
|
||||
|
||||
def set_provider(self, **kwargs: Any) -> None:
|
||||
self.provider = Provider(**kwargs)
|
||||
|
||||
def set_author(self, **kwargs: Any) -> None:
|
||||
self.author = Author(**kwargs)
|
||||
|
||||
def add_field(self, name: str, value: str, inline: bool = False) -> None:
|
||||
self.fields.append(Field(name, value, inline))
|
||||
|
||||
|
||||
class Webhook:
|
||||
"""A class to represent a single-use Discord webhook."""
|
||||
|
||||
def __init__(self, url: str, **kwargs: Any) -> None:
|
||||
self.url = url
|
||||
self.content = kwargs.get("content")
|
||||
self.username = kwargs.get("username")
|
||||
self.avatar_url = kwargs.get("avatar_url")
|
||||
self.tts = kwargs.get("tts")
|
||||
self.file = kwargs.get("file")
|
||||
self.embeds = kwargs.get("embeds", [])
|
||||
|
||||
def add_embed(self, embed: Embed) -> None:
|
||||
self.embeds.append(embed)
|
||||
|
||||
@property
|
||||
def json(self) -> Any:
|
||||
if not any([self.content, self.file, self.embeds]):
|
||||
raise Exception(
|
||||
"Webhook must contain at least one " "of (content, file, embeds).",
|
||||
)
|
||||
|
||||
if self.content and len(self.content) > 2000:
|
||||
raise Exception("Webhook content must be under " "2000 characters.")
|
||||
|
||||
payload: dict[str, Any] = {"embeds": []}
|
||||
|
||||
for key in ("content", "username", "avatar_url", "tts", "file"):
|
||||
val = getattr(self, key)
|
||||
if val is not None:
|
||||
payload[key] = val
|
||||
|
||||
for embed in self.embeds:
|
||||
embed_payload = {}
|
||||
|
||||
# simple params
|
||||
for key in ("title", "type", "description", "url", "timestamp", "color"):
|
||||
val = getattr(embed, key)
|
||||
if val is not None:
|
||||
embed_payload[key] = val
|
||||
|
||||
# class params, must turn into dict
|
||||
for key in ("footer", "image", "thumbnail", "video", "provider", "author"):
|
||||
val = getattr(embed, key)
|
||||
if val is not None:
|
||||
embed_payload[key] = val.__dict__
|
||||
|
||||
if embed.fields:
|
||||
embed_payload["fields"] = [f.__dict__ for f in embed.fields]
|
||||
|
||||
payload["embeds"].append(embed_payload)
|
||||
|
||||
return payload
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(10),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
)
|
||||
async def post(self) -> None:
|
||||
"""Post the webhook in JSON format."""
|
||||
# TODO: if `self.file is not None`, then we should
|
||||
# use multipart/form-data instead of json payload.
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = await services.http_client.post(
|
||||
self.url,
|
||||
json=self.json,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
59
app/encryption.py
Normal file
59
app/encryption.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from base64 import b64decode
|
||||
from base64 import b64encode
|
||||
|
||||
from py3rijndael import Pkcs7Padding
|
||||
from py3rijndael import RijndaelCbc
|
||||
|
||||
|
||||
def encrypt_score_aes_data(
|
||||
# to encode
|
||||
score_data: list[str],
|
||||
client_hash: str,
|
||||
# used for encoding
|
||||
iv_b64: bytes,
|
||||
osu_version: str,
|
||||
) -> tuple[bytes, bytes]:
|
||||
"""Encrypt the score data to base64."""
|
||||
# TODO: perhaps this should return TypedDict?
|
||||
|
||||
# attempt to encrypt score data
|
||||
aes = RijndaelCbc(
|
||||
key=f"osu!-scoreburgr---------{osu_version}".encode(),
|
||||
iv=b64decode(iv_b64),
|
||||
padding=Pkcs7Padding(32),
|
||||
block_size=32,
|
||||
)
|
||||
|
||||
score_data_joined = ":".join(score_data)
|
||||
score_data_b64 = b64encode(aes.encrypt(score_data_joined.encode()))
|
||||
client_hash_b64 = b64encode(aes.encrypt(client_hash.encode()))
|
||||
|
||||
return score_data_b64, client_hash_b64
|
||||
|
||||
|
||||
def decrypt_score_aes_data(
|
||||
# to decode
|
||||
score_data_b64: bytes,
|
||||
client_hash_b64: bytes,
|
||||
# used for decoding
|
||||
iv_b64: bytes,
|
||||
osu_version: str,
|
||||
) -> tuple[list[str], str]:
|
||||
"""Decrypt the base64'ed score data."""
|
||||
# TODO: perhaps this should return TypedDict?
|
||||
|
||||
# attempt to decrypt score data
|
||||
aes = RijndaelCbc(
|
||||
key=f"osu!-scoreburgr---------{osu_version}".encode(),
|
||||
iv=b64decode(iv_b64),
|
||||
padding=Pkcs7Padding(32),
|
||||
block_size=32,
|
||||
)
|
||||
|
||||
score_data = aes.decrypt(b64decode(score_data_b64)).decode().split(":")
|
||||
client_hash_decoded = aes.decrypt(b64decode(client_hash_b64)).decode()
|
||||
|
||||
# score data is delimited by colons (:).
|
||||
return score_data, client_hash_decoded
|
102
app/logging.py
Normal file
102
app/logging.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import logging.config
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from enum import IntEnum
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import yaml
|
||||
|
||||
from app import settings
|
||||
|
||||
|
||||
def configure_logging() -> None:
|
||||
with open("logging.yaml") as f:
|
||||
config = yaml.safe_load(f.read())
|
||||
logging.config.dictConfig(config)
|
||||
|
||||
|
||||
class Ansi(IntEnum):
|
||||
# Default colours
|
||||
BLACK = 30
|
||||
RED = 31
|
||||
GREEN = 32
|
||||
YELLOW = 33
|
||||
BLUE = 34
|
||||
MAGENTA = 35
|
||||
CYAN = 36
|
||||
WHITE = 37
|
||||
|
||||
# Light colours
|
||||
GRAY = 90
|
||||
LRED = 91
|
||||
LGREEN = 92
|
||||
LYELLOW = 93
|
||||
LBLUE = 94
|
||||
LMAGENTA = 95
|
||||
LCYAN = 96
|
||||
LWHITE = 97
|
||||
|
||||
RESET = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"\x1b[{self.value}m"
|
||||
|
||||
|
||||
def get_timestamp(full: bool = False, tz: ZoneInfo | None = None) -> str:
|
||||
fmt = "%d/%m/%Y %I:%M:%S%p" if full else "%I:%M:%S%p"
|
||||
return f"{datetime.datetime.now(tz=tz):{fmt}}"
|
||||
|
||||
|
||||
ANSI_ESCAPE_REGEX = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
|
||||
|
||||
|
||||
def escape_ansi(line: str) -> str:
|
||||
return ANSI_ESCAPE_REGEX.sub("", line)
|
||||
|
||||
|
||||
ROOT_LOGGER = logging.getLogger()
|
||||
|
||||
|
||||
def log(
|
||||
msg: str,
|
||||
start_color: Ansi | None = None,
|
||||
extra: Mapping[str, object] | None = None,
|
||||
) -> None:
|
||||
"""\
|
||||
A thin wrapper around the stdlib logging module to handle mostly
|
||||
backwards-compatibility for colours during our migration to the
|
||||
standard library logging module.
|
||||
"""
|
||||
|
||||
# TODO: decouple colors from the base logging function; move it to
|
||||
# be a formatter-specific concern such that we can log without color.
|
||||
if start_color is Ansi.LYELLOW:
|
||||
log_level = logging.WARNING
|
||||
elif start_color is Ansi.LRED:
|
||||
log_level = logging.ERROR
|
||||
else:
|
||||
log_level = logging.INFO
|
||||
|
||||
if settings.LOG_WITH_COLORS:
|
||||
color_prefix = f"{start_color!r}" if start_color is not None else ""
|
||||
color_suffix = f"{Ansi.RESET!r}" if start_color is not None else ""
|
||||
else:
|
||||
msg = escape_ansi(msg)
|
||||
color_prefix = color_suffix = ""
|
||||
|
||||
ROOT_LOGGER.log(log_level, f"{color_prefix}{msg}{color_suffix}", extra=extra)
|
||||
|
||||
|
||||
TIME_ORDER_SUFFIXES = ["nsec", "μsec", "msec", "sec"]
|
||||
|
||||
|
||||
def magnitude_fmt_time(nanosec: int | float) -> str:
|
||||
suffix = None
|
||||
for suffix in TIME_ORDER_SUFFIXES:
|
||||
if nanosec < 1000:
|
||||
break
|
||||
nanosec /= 1000
|
||||
return f"{nanosec:.2f} {suffix}"
|
11
app/objects/__init__.py
Normal file
11
app/objects/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# type: ignore
|
||||
# isort: dont-add-imports
|
||||
|
||||
from . import achievement
|
||||
from . import beatmap
|
||||
from . import channel
|
||||
from . import collections
|
||||
from . import match
|
||||
from . import models
|
||||
from . import player
|
||||
from . import score
|
29
app/objects/achievement.py
Normal file
29
app/objects/achievement.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.objects.score import Score
|
||||
|
||||
|
||||
class Achievement:
|
||||
"""A class to represent a single osu! achievement."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
file: str,
|
||||
name: str,
|
||||
desc: str,
|
||||
cond: Callable[[Score, int], bool], # (score, mode) -> unlocked
|
||||
) -> None:
|
||||
self.id = id
|
||||
self.file = file
|
||||
self.name = name
|
||||
self.desc = desc
|
||||
|
||||
self.cond = cond
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.file}+{self.name}+{self.desc}"
|
996
app/objects/beatmap.py
Normal file
996
app/objects/beatmap.py
Normal file
@@ -0,0 +1,996 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from enum import IntEnum
|
||||
from enum import unique
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TypedDict
|
||||
|
||||
import httpx
|
||||
from tenacity import retry
|
||||
from tenacity.stop import stop_after_attempt
|
||||
|
||||
import app.settings
|
||||
import app.state
|
||||
import app.utils
|
||||
from app.constants.gamemodes import GameMode
|
||||
from app.logging import Ansi
|
||||
from app.logging import log
|
||||
from app.repositories import maps as maps_repo
|
||||
from app.utils import escape_enum
|
||||
from app.utils import pymysql_encode
|
||||
|
||||
# from dataclasses import dataclass
|
||||
|
||||
BEATMAPS_PATH = Path.cwd() / ".data/osu"
|
||||
|
||||
DEFAULT_LAST_UPDATE = datetime(1970, 1, 1)
|
||||
|
||||
IGNORED_BEATMAP_CHARS = dict.fromkeys(map(ord, r':\/*<>?"|'), None)
|
||||
|
||||
|
||||
class BeatmapApiResponse(TypedDict):
|
||||
data: list[dict[str, Any]] | None
|
||||
status_code: int
|
||||
|
||||
|
||||
@retry(reraise=True, stop=stop_after_attempt(3))
|
||||
async def api_get_beatmaps(**params: Any) -> BeatmapApiResponse:
|
||||
"""\
|
||||
Fetch data from the osu!api with a beatmap's md5.
|
||||
|
||||
Optionally use osu.direct's API if the user has not provided an osu! api key.
|
||||
"""
|
||||
if app.settings.DEBUG:
|
||||
log(f"Doing api (getbeatmaps) request {params}", Ansi.LMAGENTA)
|
||||
|
||||
if app.settings.OSU_API_KEY:
|
||||
# https://github.com/ppy/osu-api/wiki#apiget_beatmaps
|
||||
url = "https://old.ppy.sh/api/get_beatmaps"
|
||||
params["k"] = str(app.settings.OSU_API_KEY)
|
||||
else:
|
||||
# https://osu.direct/doc
|
||||
url = "https://osu.direct/api/get_beatmaps"
|
||||
|
||||
response = await app.state.services.http_client.get(url, params=params)
|
||||
response_data = response.json()
|
||||
if response.status_code == 200 and response_data: # (data may be [])
|
||||
return {"data": response_data, "status_code": response.status_code}
|
||||
|
||||
return {"data": None, "status_code": response.status_code}
|
||||
|
||||
|
||||
@retry(reraise=True, stop=stop_after_attempt(3))
|
||||
async def api_get_osu_file(beatmap_id: int) -> bytes:
|
||||
url = f"https://old.ppy.sh/osu/{beatmap_id}"
|
||||
response = await app.state.services.http_client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.read()
|
||||
|
||||
|
||||
def disk_has_expected_osu_file(
|
||||
beatmap_id: int,
|
||||
expected_md5: str | None = None,
|
||||
) -> bool:
|
||||
osu_file_path = BEATMAPS_PATH / f"{beatmap_id}.osu"
|
||||
file_exists = osu_file_path.exists()
|
||||
if file_exists and expected_md5 is not None:
|
||||
osu_file_md5 = hashlib.md5(osu_file_path.read_bytes()).hexdigest()
|
||||
return osu_file_md5 == expected_md5
|
||||
return file_exists
|
||||
|
||||
|
||||
def write_osu_file_to_disk(beatmap_id: int, data: bytes) -> None:
|
||||
osu_file_path = BEATMAPS_PATH / f"{beatmap_id}.osu"
|
||||
osu_file_path.write_bytes(data)
|
||||
|
||||
|
||||
async def ensure_osu_file_is_available(
|
||||
beatmap_id: int,
|
||||
expected_md5: str | None = None,
|
||||
) -> bool:
|
||||
"""\
|
||||
Download the .osu file for a beatmap if it's not already present.
|
||||
|
||||
If `expected_md5` is provided, the file will be downloaded if it
|
||||
does not match the expected md5 hash -- this is typically used for
|
||||
ensuring a file is the latest expected version.
|
||||
|
||||
Returns whether the file is available for use.
|
||||
"""
|
||||
if disk_has_expected_osu_file(beatmap_id, expected_md5):
|
||||
return True
|
||||
|
||||
try:
|
||||
latest_osu_file = await api_get_osu_file(beatmap_id)
|
||||
except httpx.HTTPStatusError:
|
||||
return False
|
||||
except Exception:
|
||||
log(f"Failed to fetch osu file for {beatmap_id}", Ansi.LRED)
|
||||
return False
|
||||
|
||||
write_osu_file_to_disk(beatmap_id, latest_osu_file)
|
||||
return True
|
||||
|
||||
|
||||
# for some ungodly reason, different values are used to
|
||||
# represent different ranked statuses all throughout osu!
|
||||
# This drives me and probably everyone else pretty insane,
|
||||
# but we have nothing to do but deal with it B).
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class RankedStatus(IntEnum):
|
||||
"""Server side osu! beatmap ranked statuses.
|
||||
Same as used in osu!'s /web/getscores.php.
|
||||
"""
|
||||
|
||||
NotSubmitted = -1
|
||||
Pending = 0
|
||||
UpdateAvailable = 1
|
||||
Ranked = 2
|
||||
Approved = 3
|
||||
Qualified = 4
|
||||
Loved = 5
|
||||
|
||||
def __str__(self) -> str:
|
||||
return {
|
||||
self.NotSubmitted: "Unsubmitted",
|
||||
self.Pending: "Unranked",
|
||||
self.UpdateAvailable: "Outdated",
|
||||
self.Ranked: "Ranked",
|
||||
self.Approved: "Approved",
|
||||
self.Qualified: "Qualified",
|
||||
self.Loved: "Loved",
|
||||
}[self]
|
||||
|
||||
@functools.cached_property
|
||||
def osu_api(self) -> int:
|
||||
"""Convert the value to osu!api status."""
|
||||
# XXX: only the ones that exist are mapped.
|
||||
return {
|
||||
self.Pending: 0,
|
||||
self.Ranked: 1,
|
||||
self.Approved: 2,
|
||||
self.Qualified: 3,
|
||||
self.Loved: 4,
|
||||
}[self]
|
||||
|
||||
@classmethod
|
||||
@functools.cache
|
||||
def from_osuapi(cls, osuapi_status: int) -> RankedStatus:
|
||||
"""Convert from osu!api status."""
|
||||
mapping: Mapping[int, RankedStatus] = defaultdict(
|
||||
lambda: cls.UpdateAvailable,
|
||||
{
|
||||
-2: cls.Pending, # graveyard
|
||||
-1: cls.Pending, # wip
|
||||
0: cls.Pending,
|
||||
1: cls.Ranked,
|
||||
2: cls.Approved,
|
||||
3: cls.Qualified,
|
||||
4: cls.Loved,
|
||||
},
|
||||
)
|
||||
return mapping[osuapi_status]
|
||||
|
||||
@classmethod
|
||||
@functools.cache
|
||||
def from_osudirect(cls, osudirect_status: int) -> RankedStatus:
|
||||
"""Convert from osu!direct status."""
|
||||
mapping: Mapping[int, RankedStatus] = defaultdict(
|
||||
lambda: cls.UpdateAvailable,
|
||||
{
|
||||
0: cls.Ranked,
|
||||
2: cls.Pending,
|
||||
3: cls.Qualified,
|
||||
# 4: all ranked statuses lol
|
||||
5: cls.Pending, # graveyard
|
||||
7: cls.Ranked, # played before
|
||||
8: cls.Loved,
|
||||
},
|
||||
)
|
||||
return mapping[osudirect_status]
|
||||
|
||||
@classmethod
|
||||
@functools.cache
|
||||
def from_str(cls, status_str: str) -> RankedStatus:
|
||||
"""Convert from string value.""" # could perhaps have `'unranked': cls.Pending`?
|
||||
mapping: Mapping[str, RankedStatus] = defaultdict(
|
||||
lambda: cls.UpdateAvailable,
|
||||
{
|
||||
"pending": cls.Pending,
|
||||
"ranked": cls.Ranked,
|
||||
"approved": cls.Approved,
|
||||
"qualified": cls.Qualified,
|
||||
"loved": cls.Loved,
|
||||
},
|
||||
)
|
||||
return mapping[status_str]
|
||||
|
||||
|
||||
# @dataclass
|
||||
# class BeatmapInfoRequest:
|
||||
# filenames: Sequence[str]
|
||||
# ids: Sequence[int]
|
||||
|
||||
# @dataclass
|
||||
# class BeatmapInfo:
|
||||
# id: int # i16
|
||||
# map_id: int # i32
|
||||
# set_id: int # i32
|
||||
# thread_id: int # i32
|
||||
# status: int # u8
|
||||
# osu_rank: int # u8
|
||||
# fruits_rank: int # u8
|
||||
# taiko_rank: int # u8
|
||||
# mania_rank: int # u8
|
||||
# map_md5: str
|
||||
|
||||
|
||||
class Beatmap:
|
||||
"""A class representing an osu! beatmap.
|
||||
|
||||
This class provides a high level api which should always be the
|
||||
preferred method of fetching beatmaps due to its housekeeping.
|
||||
It will perform caching & invalidation, handle map updates while
|
||||
minimizing osu!api requests, and always use the most efficient
|
||||
method available to fetch the beatmap's information, while
|
||||
maintaining a low overhead.
|
||||
|
||||
The only methods you should need are:
|
||||
await Beatmap.from_md5(md5: str, set_id: int = -1) -> Beatmap | None
|
||||
await Beatmap.from_bid(bid: int) -> Beatmap | None
|
||||
|
||||
Properties:
|
||||
Beatmap.full -> str # Artist - Title [Version]
|
||||
Beatmap.url -> str # https://osu.cmyui.xyz/b/321
|
||||
Beatmap.embed -> str # [{url} {full}]
|
||||
|
||||
Beatmap.has_leaderboard -> bool
|
||||
Beatmap.awards_ranked_pp -> bool
|
||||
Beatmap.as_dict -> dict[str, object]
|
||||
|
||||
Lower level API:
|
||||
Beatmap._from_md5_cache(md5: str, check_updates: bool = True) -> Beatmap | None
|
||||
Beatmap._from_bid_cache(bid: int, check_updates: bool = True) -> Beatmap | None
|
||||
|
||||
Beatmap._from_md5_sql(md5: str) -> Beatmap | None
|
||||
Beatmap._from_bid_sql(bid: int) -> Beatmap | None
|
||||
|
||||
Beatmap._parse_from_osuapi_resp(osuapi_resp: dict[str, object]) -> None
|
||||
|
||||
Note that the BeatmapSet class also provides a similar API.
|
||||
|
||||
Possibly confusing attributes
|
||||
-----------
|
||||
frozen: `bool`
|
||||
Whether the beatmap's status is to be kept when a newer
|
||||
version is found in the osu!api.
|
||||
# XXX: This is set when a map's status is manually changed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
map_set: BeatmapSet,
|
||||
md5: str = "",
|
||||
id: int = 0,
|
||||
set_id: int = 0,
|
||||
artist: str = "",
|
||||
title: str = "",
|
||||
version: str = "",
|
||||
creator: str = "",
|
||||
last_update: datetime = DEFAULT_LAST_UPDATE,
|
||||
total_length: int = 0,
|
||||
max_combo: int = 0,
|
||||
status: RankedStatus = RankedStatus.Pending,
|
||||
frozen: bool = False,
|
||||
plays: int = 0,
|
||||
passes: int = 0,
|
||||
mode: GameMode = GameMode.VANILLA_OSU,
|
||||
bpm: float = 0.0,
|
||||
cs: float = 0.0,
|
||||
od: float = 0.0,
|
||||
ar: float = 0.0,
|
||||
hp: float = 0.0,
|
||||
diff: float = 0.0,
|
||||
filename: str = "",
|
||||
) -> None:
|
||||
self.set = map_set
|
||||
|
||||
self.md5 = md5
|
||||
self.id = id
|
||||
self.set_id = set_id
|
||||
self.artist = artist
|
||||
self.title = title
|
||||
self.version = version
|
||||
self.creator = creator
|
||||
self.last_update = last_update
|
||||
self.total_length = total_length
|
||||
self.max_combo = max_combo
|
||||
self.status = status
|
||||
self.frozen = frozen
|
||||
self.plays = plays
|
||||
self.passes = passes
|
||||
self.mode = mode
|
||||
self.bpm = bpm
|
||||
self.cs = cs
|
||||
self.od = od
|
||||
self.ar = ar
|
||||
self.hp = hp
|
||||
self.diff = diff
|
||||
self.filename = filename
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.full_name
|
||||
|
||||
@property
|
||||
def full_name(self) -> str:
|
||||
"""The full osu! formatted name `self`."""
|
||||
return f"{self.artist} - {self.title} [{self.version}]"
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
"""The osu! beatmap url for `self`."""
|
||||
return f"https://osu.{app.settings.DOMAIN}/b/{self.id}"
|
||||
|
||||
@property
|
||||
def embed(self) -> str:
|
||||
"""An osu! chat embed to `self`'s osu! beatmap page."""
|
||||
return f"[{self.url} {self.full_name}]"
|
||||
|
||||
@property
|
||||
def has_leaderboard(self) -> bool:
|
||||
"""Return whether the map has a ranked leaderboard."""
|
||||
return self.status in (
|
||||
RankedStatus.Ranked,
|
||||
RankedStatus.Approved,
|
||||
RankedStatus.Loved,
|
||||
)
|
||||
|
||||
@property
|
||||
def awards_ranked_pp(self) -> bool:
|
||||
"""Return whether the map's status awards ranked pp for scores."""
|
||||
return self.status in (RankedStatus.Ranked, RankedStatus.Approved)
|
||||
|
||||
@property # perhaps worth caching some of?
|
||||
def as_dict(self) -> dict[str, object]:
|
||||
return {
|
||||
"md5": self.md5,
|
||||
"id": self.id,
|
||||
"set_id": self.set_id,
|
||||
"artist": self.artist,
|
||||
"title": self.title,
|
||||
"version": self.version,
|
||||
"creator": self.creator,
|
||||
"last_update": self.last_update,
|
||||
"total_length": self.total_length,
|
||||
"max_combo": self.max_combo,
|
||||
"status": self.status,
|
||||
"plays": self.plays,
|
||||
"passes": self.passes,
|
||||
"mode": self.mode,
|
||||
"bpm": self.bpm,
|
||||
"cs": self.cs,
|
||||
"od": self.od,
|
||||
"ar": self.ar,
|
||||
"hp": self.hp,
|
||||
"diff": self.diff,
|
||||
}
|
||||
|
||||
""" High level API """
|
||||
# There are three levels of storage used for beatmaps,
|
||||
# the cache (ram), the db (disk), and the osu!api (web).
|
||||
# Going down this list gets exponentially slower, so
|
||||
# we always prioritze what's fastest when possible.
|
||||
# These methods will keep beatmaps reasonably up to
|
||||
# date and use the fastest storage available, while
|
||||
# populating the higher levels of the cache with new maps.
|
||||
|
||||
@classmethod
|
||||
async def from_md5(cls, md5: str, set_id: int = -1) -> Beatmap | None:
|
||||
"""Fetch a map from the cache, database, or osuapi by md5."""
|
||||
bmap = await cls._from_md5_cache(md5)
|
||||
|
||||
if not bmap:
|
||||
# map not found in cache
|
||||
|
||||
# to be efficient, we want to cache the whole set
|
||||
# at once rather than caching the individual map
|
||||
|
||||
if set_id <= 0:
|
||||
# set id not provided - fetch it from the map md5
|
||||
rec = await maps_repo.fetch_one(md5=md5)
|
||||
|
||||
if rec is not None:
|
||||
# set found in db
|
||||
set_id = rec["set_id"]
|
||||
else:
|
||||
# set not found in db, try api
|
||||
api_data = await api_get_beatmaps(h=md5)
|
||||
|
||||
if api_data["data"] is None:
|
||||
return None
|
||||
|
||||
api_response = api_data["data"]
|
||||
set_id = int(api_response[0]["beatmapset_id"])
|
||||
|
||||
# fetch (and cache) beatmap set
|
||||
beatmap_set = await BeatmapSet.from_bsid(set_id)
|
||||
|
||||
if beatmap_set is not None:
|
||||
# the beatmap set has been cached - fetch beatmap from cache
|
||||
bmap = await cls._from_md5_cache(md5)
|
||||
|
||||
# XXX:HACK in this case, BeatmapSet.from_bsid will have
|
||||
# ensured the map is up to date, so we can just return it
|
||||
return bmap
|
||||
|
||||
if bmap is not None:
|
||||
if bmap.set._cache_expired():
|
||||
await bmap.set._update_if_available()
|
||||
|
||||
return bmap
|
||||
|
||||
@classmethod
|
||||
async def from_bid(cls, bid: int) -> Beatmap | None:
|
||||
"""Fetch a map from the cache, database, or osuapi by id."""
|
||||
bmap = await cls._from_bid_cache(bid)
|
||||
|
||||
if not bmap:
|
||||
# map not found in cache
|
||||
|
||||
# to be efficient, we want to cache the whole set
|
||||
# at once rather than caching the individual map
|
||||
|
||||
rec = await maps_repo.fetch_one(id=bid)
|
||||
|
||||
if rec is not None:
|
||||
# set found in db
|
||||
set_id = rec["set_id"]
|
||||
else:
|
||||
# set not found in db, try getting via api
|
||||
api_data = await api_get_beatmaps(b=bid)
|
||||
|
||||
if api_data["data"] is None:
|
||||
return None
|
||||
|
||||
api_response = api_data["data"]
|
||||
set_id = int(api_response[0]["beatmapset_id"])
|
||||
|
||||
# fetch (and cache) beatmap set
|
||||
beatmap_set = await BeatmapSet.from_bsid(set_id)
|
||||
|
||||
if beatmap_set is not None:
|
||||
# the beatmap set has been cached - fetch beatmap from cache
|
||||
bmap = await cls._from_bid_cache(bid)
|
||||
|
||||
# XXX:HACK in this case, BeatmapSet.from_bsid will have
|
||||
# ensured the map is up to date, so we can just return it
|
||||
return bmap
|
||||
|
||||
if bmap is not None:
|
||||
if bmap.set._cache_expired():
|
||||
await bmap.set._update_if_available()
|
||||
|
||||
return bmap
|
||||
|
||||
""" Lower level API """
|
||||
# These functions are meant for internal use under
|
||||
# all normal circumstances and should only be used
|
||||
# if you're really modifying bancho.py by adding new
|
||||
# features, or perhaps optimizing parts of the code.
|
||||
|
||||
def _parse_from_osuapi_resp(self, osuapi_resp: dict[str, Any]) -> None:
|
||||
"""Change internal data with the data in osu!api format."""
|
||||
# NOTE: `self` is not guaranteed to have any attributes
|
||||
# initialized when this is called.
|
||||
self.md5 = osuapi_resp["file_md5"]
|
||||
# self.id = int(osuapi_resp['beatmap_id'])
|
||||
self.set_id = int(osuapi_resp["beatmapset_id"])
|
||||
|
||||
self.artist, self.title, self.version, self.creator = (
|
||||
osuapi_resp["artist"],
|
||||
osuapi_resp["title"],
|
||||
osuapi_resp["version"],
|
||||
osuapi_resp["creator"],
|
||||
)
|
||||
|
||||
self.filename = (
|
||||
("{artist} - {title} ({creator}) [{version}].osu")
|
||||
.format(**osuapi_resp)
|
||||
.translate(IGNORED_BEATMAP_CHARS)
|
||||
)
|
||||
|
||||
# quite a bit faster than using dt.strptime.
|
||||
_last_update = osuapi_resp["last_update"]
|
||||
self.last_update = datetime(
|
||||
year=int(_last_update[0:4]),
|
||||
month=int(_last_update[5:7]),
|
||||
day=int(_last_update[8:10]),
|
||||
hour=int(_last_update[11:13]),
|
||||
minute=int(_last_update[14:16]),
|
||||
second=int(_last_update[17:19]),
|
||||
)
|
||||
|
||||
self.total_length = int(osuapi_resp["total_length"])
|
||||
|
||||
if osuapi_resp["max_combo"] is not None:
|
||||
self.max_combo = int(osuapi_resp["max_combo"])
|
||||
else:
|
||||
self.max_combo = 0
|
||||
|
||||
# if a map is 'frozen', we keep its status
|
||||
# even after an update from the osu!api.
|
||||
if not getattr(self, "frozen", False):
|
||||
osuapi_status = int(osuapi_resp["approved"])
|
||||
self.status = RankedStatus.from_osuapi(osuapi_status)
|
||||
|
||||
self.mode = GameMode(int(osuapi_resp["mode"]))
|
||||
|
||||
if osuapi_resp["bpm"] is not None:
|
||||
self.bpm = float(osuapi_resp["bpm"])
|
||||
else:
|
||||
self.bpm = 0.0
|
||||
|
||||
self.cs = float(osuapi_resp["diff_size"])
|
||||
self.od = float(osuapi_resp["diff_overall"])
|
||||
self.ar = float(osuapi_resp["diff_approach"])
|
||||
self.hp = float(osuapi_resp["diff_drain"])
|
||||
|
||||
self.diff = float(osuapi_resp["difficultyrating"])
|
||||
|
||||
@staticmethod
|
||||
async def _from_md5_cache(md5: str) -> Beatmap | None:
|
||||
"""Fetch a map from the cache by md5."""
|
||||
return app.state.cache.beatmap.get(md5, None)
|
||||
|
||||
@staticmethod
|
||||
async def _from_bid_cache(bid: int) -> Beatmap | None:
|
||||
"""Fetch a map from the cache by id."""
|
||||
return app.state.cache.beatmap.get(bid, None)
|
||||
|
||||
|
||||
class BeatmapSet:
|
||||
"""A class to represent an osu! beatmap set.
|
||||
|
||||
Like the Beatmap class, this class provides a high level api
|
||||
which should always be the preferred method of fetching beatmaps
|
||||
due to its housekeeping. It will perform caching & invalidation,
|
||||
handle map updates while minimizing osu!api requests, and always
|
||||
use the most efficient method available to fetch the beatmap's
|
||||
information, while maintaining a low overhead.
|
||||
|
||||
The only methods you should need are:
|
||||
await BeatmapSet.from_bsid(bsid: int) -> BeatmapSet | None
|
||||
|
||||
BeatmapSet.all_officially_ranked_or_approved() -> bool
|
||||
BeatmapSet.all_officially_loved() -> bool
|
||||
|
||||
Properties:
|
||||
BeatmapSet.url -> str # https://osu.cmyui.xyz/s/123
|
||||
|
||||
Lower level API:
|
||||
await BeatmapSet._from_bsid_cache(bsid: int) -> BeatmapSet | None
|
||||
await BeatmapSet._from_bsid_sql(bsid: int) -> BeatmapSet | None
|
||||
await BeatmapSet._from_bsid_osuapi(bsid: int) -> BeatmapSet | None
|
||||
|
||||
BeatmapSet._cache_expired() -> bool
|
||||
await BeatmapSet._update_if_available() -> None
|
||||
await BeatmapSet._save_to_sql() -> None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
last_osuapi_check: datetime,
|
||||
maps: list[Beatmap] | None = None,
|
||||
) -> None:
|
||||
self.id = id
|
||||
|
||||
self.maps = maps or []
|
||||
self.last_osuapi_check = last_osuapi_check
|
||||
|
||||
def __repr__(self) -> str:
|
||||
map_names = []
|
||||
for bmap in self.maps:
|
||||
name = f"{bmap.artist} - {bmap.title}"
|
||||
if name not in map_names:
|
||||
map_names.append(name)
|
||||
return ", ".join(map_names)
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
"""The online url for this beatmap set."""
|
||||
return f"https://osu.{app.settings.DOMAIN}/s/{self.id}"
|
||||
|
||||
def any_beatmaps_have_official_leaderboards(self) -> bool:
|
||||
"""Whether all the maps in the set have leaderboards on official servers."""
|
||||
leaderboard_having_statuses = (
|
||||
RankedStatus.Loved,
|
||||
RankedStatus.Ranked,
|
||||
RankedStatus.Approved,
|
||||
)
|
||||
return any(bmap.status in leaderboard_having_statuses for bmap in self.maps)
|
||||
|
||||
def _cache_expired(self) -> bool:
|
||||
"""Whether the cached version of the set is
|
||||
expired and needs an update from the osu!api."""
|
||||
current_datetime = datetime.now()
|
||||
|
||||
if not self.maps:
|
||||
return True
|
||||
|
||||
# the delta between cache invalidations will increase depending
|
||||
# on how long it's been since the map was last updated on osu!
|
||||
last_map_update = max(bmap.last_update for bmap in self.maps)
|
||||
update_delta = current_datetime - last_map_update
|
||||
|
||||
# with a minimum of 2 hours, add 5 hours per year since its update.
|
||||
# the formula for this is subject to adjustment in the future.
|
||||
check_delta = timedelta(hours=2 + ((5 / 365) * update_delta.days))
|
||||
|
||||
# it's much less likely that a beatmapset who has beatmaps with
|
||||
# leaderboards on official servers will be updated.
|
||||
if self.any_beatmaps_have_official_leaderboards():
|
||||
check_delta *= 4
|
||||
|
||||
# we'll cache for an absolute maximum of 1 day.
|
||||
check_delta = min(check_delta, timedelta(days=1))
|
||||
|
||||
return current_datetime > (self.last_osuapi_check + check_delta)
|
||||
|
||||
async def _update_if_available(self) -> None:
|
||||
"""Fetch the newest data from the api, check for differences
|
||||
and propogate any update into our cache & database."""
|
||||
|
||||
try:
|
||||
api_data = await api_get_beatmaps(s=self.id)
|
||||
except (httpx.TransportError, httpx.DecodingError):
|
||||
# NOTE: TransportError is directly caused by the API being unavailable
|
||||
|
||||
# NOTE: DecodingError is caused by the API returning HTML and
|
||||
# normally happens when CF protection is enabled while
|
||||
# osu! recovers from a DDOS attack
|
||||
|
||||
# we do not want to delete the beatmap in this case, so we simply return
|
||||
# but do not set the last check, as we would like to retry these ASAP
|
||||
|
||||
return
|
||||
|
||||
if api_data["data"] is not None:
|
||||
api_response = api_data["data"]
|
||||
|
||||
old_maps = {bmap.id: bmap for bmap in self.maps}
|
||||
new_maps = {int(api_map["beatmap_id"]): api_map for api_map in api_response}
|
||||
|
||||
self.last_osuapi_check = datetime.now()
|
||||
|
||||
# delete maps from old_maps where old.id not in new_maps
|
||||
# update maps from old_maps where old.md5 != new.md5
|
||||
# add maps to old_maps where new.id not in old_maps
|
||||
|
||||
updated_maps: list[Beatmap] = []
|
||||
map_md5s_to_delete: set[str] = set()
|
||||
|
||||
# temp value for building the new beatmap
|
||||
bmap: Beatmap
|
||||
|
||||
# find maps in our current state that've been deleted, or need updates
|
||||
for old_id, old_map in old_maps.items():
|
||||
if old_id not in new_maps:
|
||||
# delete map from old_maps
|
||||
map_md5s_to_delete.add(old_map.md5)
|
||||
else:
|
||||
new_map = new_maps[old_id]
|
||||
new_ranked_status = RankedStatus.from_osuapi(
|
||||
int(new_map["approved"]),
|
||||
)
|
||||
if (
|
||||
old_map.md5 != new_map["file_md5"]
|
||||
or old_map.status != new_ranked_status
|
||||
):
|
||||
# update map from old_maps
|
||||
bmap = old_maps[old_id]
|
||||
bmap._parse_from_osuapi_resp(new_map)
|
||||
updated_maps.append(bmap)
|
||||
else:
|
||||
# map is the same, make no changes
|
||||
updated_maps.append(old_map) # (this line is _maybe_ needed?)
|
||||
|
||||
# find maps that aren't in our current state, and add them
|
||||
for new_id, new_map in new_maps.items():
|
||||
if new_id not in old_maps:
|
||||
# new map we don't have locally, add it
|
||||
bmap = Beatmap.__new__(Beatmap)
|
||||
bmap.id = new_id
|
||||
|
||||
bmap._parse_from_osuapi_resp(new_map)
|
||||
|
||||
# (some implementation-specific stuff not given by api)
|
||||
bmap.frozen = False
|
||||
bmap.passes = 0
|
||||
bmap.plays = 0
|
||||
|
||||
bmap.set = self
|
||||
updated_maps.append(bmap)
|
||||
|
||||
# save changes to cache
|
||||
self.maps = updated_maps
|
||||
|
||||
# save changes to sql
|
||||
|
||||
if map_md5s_to_delete:
|
||||
# delete maps
|
||||
await app.state.services.database.execute(
|
||||
"DELETE FROM maps WHERE md5 IN :map_md5s",
|
||||
{"map_md5s": map_md5s_to_delete},
|
||||
)
|
||||
|
||||
# delete scores on the maps
|
||||
await app.state.services.database.execute(
|
||||
"DELETE FROM scores WHERE map_md5 IN :map_md5s",
|
||||
{"map_md5s": map_md5s_to_delete},
|
||||
)
|
||||
|
||||
# update last_osuapi_check
|
||||
await app.state.services.database.execute(
|
||||
"REPLACE INTO mapsets "
|
||||
"(id, server, last_osuapi_check) "
|
||||
"VALUES (:id, :server, :last_osuapi_check)",
|
||||
{
|
||||
"id": self.id,
|
||||
"server": "osu!",
|
||||
"last_osuapi_check": self.last_osuapi_check,
|
||||
},
|
||||
)
|
||||
|
||||
# update maps in sql
|
||||
await self._save_to_sql()
|
||||
elif api_data["status_code"] in (404, 200):
|
||||
# NOTE: 200 can return an empty array of beatmaps,
|
||||
# so we still delete in this case if the beatmap data is None
|
||||
|
||||
# TODO: a couple of open questions here:
|
||||
# - should we delete the beatmap from the database if it's not in the osu!api?
|
||||
# - are 404 and 200 the only cases where we should delete the beatmap?
|
||||
if self.maps:
|
||||
map_md5s_to_delete = {bmap.md5 for bmap in self.maps}
|
||||
|
||||
# delete maps
|
||||
await app.state.services.database.execute(
|
||||
"DELETE FROM maps WHERE md5 IN :map_md5s",
|
||||
{"map_md5s": map_md5s_to_delete},
|
||||
)
|
||||
|
||||
# delete scores on the maps
|
||||
await app.state.services.database.execute(
|
||||
"DELETE FROM scores WHERE map_md5 IN :map_md5s",
|
||||
{"map_md5s": map_md5s_to_delete},
|
||||
)
|
||||
|
||||
# delete set
|
||||
await app.state.services.database.execute(
|
||||
"DELETE FROM mapsets WHERE id = :set_id",
|
||||
{"set_id": self.id},
|
||||
)
|
||||
|
||||
async def _save_to_sql(self) -> None:
|
||||
"""Save the object's attributes into the database."""
|
||||
await app.state.services.database.execute_many(
|
||||
"REPLACE INTO maps ("
|
||||
"md5, id, server, set_id, "
|
||||
"artist, title, version, creator, "
|
||||
"filename, last_update, total_length, "
|
||||
"max_combo, status, frozen, "
|
||||
"plays, passes, mode, bpm, "
|
||||
"cs, od, ar, hp, diff"
|
||||
") VALUES ("
|
||||
":md5, :id, :server, :set_id, "
|
||||
":artist, :title, :version, :creator, "
|
||||
":filename, :last_update, :total_length, "
|
||||
":max_combo, :status, :frozen, "
|
||||
":plays, :passes, :mode, :bpm, "
|
||||
":cs, :od, :ar, :hp, :diff"
|
||||
")",
|
||||
[
|
||||
{
|
||||
"md5": bmap.md5,
|
||||
"id": bmap.id,
|
||||
"server": "osu!",
|
||||
"set_id": bmap.set_id,
|
||||
"artist": bmap.artist,
|
||||
"title": bmap.title,
|
||||
"version": bmap.version,
|
||||
"creator": bmap.creator,
|
||||
"filename": bmap.filename,
|
||||
"last_update": bmap.last_update,
|
||||
"total_length": bmap.total_length,
|
||||
"max_combo": bmap.max_combo,
|
||||
"status": bmap.status,
|
||||
"frozen": bmap.frozen,
|
||||
"plays": bmap.plays,
|
||||
"passes": bmap.passes,
|
||||
"mode": bmap.mode,
|
||||
"bpm": bmap.bpm,
|
||||
"cs": bmap.cs,
|
||||
"od": bmap.od,
|
||||
"ar": bmap.ar,
|
||||
"hp": bmap.hp,
|
||||
"diff": bmap.diff,
|
||||
}
|
||||
for bmap in self.maps
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _from_bsid_cache(bsid: int) -> BeatmapSet | None:
|
||||
"""Fetch a mapset from the cache by set id."""
|
||||
return app.state.cache.beatmapset.get(bsid, None)
|
||||
|
||||
@classmethod
|
||||
async def _from_bsid_sql(cls, bsid: int) -> BeatmapSet | None:
|
||||
"""Fetch a mapset from the database by set id."""
|
||||
last_osuapi_check = await app.state.services.database.fetch_val(
|
||||
"SELECT last_osuapi_check FROM mapsets WHERE id = :set_id",
|
||||
{"set_id": bsid},
|
||||
column=0, # last_osuapi_check
|
||||
)
|
||||
|
||||
if last_osuapi_check is None:
|
||||
return None
|
||||
|
||||
bmap_set = cls(id=bsid, last_osuapi_check=last_osuapi_check)
|
||||
|
||||
for row in await maps_repo.fetch_many(set_id=bsid):
|
||||
bmap = Beatmap(
|
||||
md5=row["md5"],
|
||||
id=row["id"],
|
||||
set_id=row["set_id"],
|
||||
artist=row["artist"],
|
||||
title=row["title"],
|
||||
version=row["version"],
|
||||
creator=row["creator"],
|
||||
last_update=row["last_update"],
|
||||
total_length=row["total_length"],
|
||||
max_combo=row["max_combo"],
|
||||
status=RankedStatus(row["status"]),
|
||||
frozen=row["frozen"],
|
||||
plays=row["plays"],
|
||||
passes=row["passes"],
|
||||
mode=GameMode(row["mode"]),
|
||||
bpm=row["bpm"],
|
||||
cs=row["cs"],
|
||||
od=row["od"],
|
||||
ar=row["ar"],
|
||||
hp=row["hp"],
|
||||
diff=row["diff"],
|
||||
filename=row["filename"],
|
||||
map_set=bmap_set,
|
||||
)
|
||||
|
||||
# XXX: tempfix for bancho.py <v3.4.1,
|
||||
# where filenames weren't stored.
|
||||
if not bmap.filename:
|
||||
bmap.filename = (
|
||||
("{artist} - {title} ({creator}) [{version}].osu")
|
||||
.format(
|
||||
artist=row["artist"],
|
||||
title=row["title"],
|
||||
creator=row["creator"],
|
||||
version=row["version"],
|
||||
)
|
||||
.translate(IGNORED_BEATMAP_CHARS)
|
||||
)
|
||||
await maps_repo.partial_update(bmap.id, filename=bmap.filename)
|
||||
|
||||
bmap_set.maps.append(bmap)
|
||||
|
||||
return bmap_set
|
||||
|
||||
@classmethod
|
||||
async def _from_bsid_osuapi(cls, bsid: int) -> BeatmapSet | None:
|
||||
"""Fetch a mapset from the osu!api by set id."""
|
||||
api_data = await api_get_beatmaps(s=bsid)
|
||||
if api_data["data"] is not None:
|
||||
api_response = api_data["data"]
|
||||
|
||||
self = cls(id=bsid, last_osuapi_check=datetime.now())
|
||||
|
||||
# XXX: pre-mapset bancho.py support
|
||||
# select all current beatmaps
|
||||
# that're frozen in the db
|
||||
res = await app.state.services.database.fetch_all(
|
||||
"SELECT id, status FROM maps WHERE set_id = :set_id AND frozen = 1",
|
||||
{"set_id": bsid},
|
||||
)
|
||||
|
||||
current_maps = {row["id"]: row["status"] for row in res}
|
||||
|
||||
for api_bmap in api_response:
|
||||
# newer version available for this map
|
||||
bmap: Beatmap = Beatmap.__new__(Beatmap)
|
||||
bmap.id = int(api_bmap["beatmap_id"])
|
||||
|
||||
if bmap.id in current_maps:
|
||||
# map is currently frozen, keep it's status.
|
||||
bmap.status = RankedStatus(current_maps[bmap.id])
|
||||
bmap.frozen = True
|
||||
else:
|
||||
bmap.frozen = False
|
||||
|
||||
bmap._parse_from_osuapi_resp(api_bmap)
|
||||
|
||||
# (some implementation-specific stuff not given by api)
|
||||
bmap.passes = 0
|
||||
bmap.plays = 0
|
||||
|
||||
bmap.set = self
|
||||
self.maps.append(bmap)
|
||||
|
||||
await app.state.services.database.execute(
|
||||
"REPLACE INTO mapsets "
|
||||
"(id, server, last_osuapi_check) "
|
||||
"VALUES (:id, :server, :last_osuapi_check)",
|
||||
{
|
||||
"id": self.id,
|
||||
"server": "osu!",
|
||||
"last_osuapi_check": self.last_osuapi_check,
|
||||
},
|
||||
)
|
||||
|
||||
await self._save_to_sql()
|
||||
return self
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def from_bsid(cls, bsid: int) -> BeatmapSet | None:
|
||||
"""Cache all maps in a set from the osuapi, optionally
|
||||
returning beatmaps by their md5 or id."""
|
||||
bmap_set = await cls._from_bsid_cache(bsid)
|
||||
did_api_request = False
|
||||
|
||||
if not bmap_set:
|
||||
bmap_set = await cls._from_bsid_sql(bsid)
|
||||
|
||||
if not bmap_set:
|
||||
bmap_set = await cls._from_bsid_osuapi(bsid)
|
||||
|
||||
if not bmap_set:
|
||||
return None
|
||||
|
||||
did_api_request = True
|
||||
|
||||
# TODO: this can be done less often for certain types of maps,
|
||||
# such as ones that're ranked on bancho and won't be updated,
|
||||
# and perhaps ones that haven't been updated in a long time.
|
||||
if not did_api_request and bmap_set._cache_expired():
|
||||
await bmap_set._update_if_available()
|
||||
|
||||
# cache the beatmap set, and beatmaps
|
||||
# to be efficient in future requests
|
||||
cache_beatmap_set(bmap_set)
|
||||
|
||||
return bmap_set
|
||||
|
||||
|
||||
def cache_beatmap(beatmap: Beatmap) -> None:
|
||||
"""Add the beatmap to the cache."""
|
||||
app.state.cache.beatmap[beatmap.md5] = beatmap
|
||||
app.state.cache.beatmap[beatmap.id] = beatmap
|
||||
|
||||
|
||||
def cache_beatmap_set(beatmap_set: BeatmapSet) -> None:
|
||||
"""Add the beatmap set, and each beatmap to the cache."""
|
||||
app.state.cache.beatmapset[beatmap_set.id] = beatmap_set
|
||||
|
||||
for beatmap in beatmap_set.maps:
|
||||
cache_beatmap(beatmap)
|
138
app/objects/channel.py
Normal file
138
app/objects/channel.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import app.packets
|
||||
import app.state
|
||||
from app.constants.privileges import Privileges
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.objects.player import Player
|
||||
|
||||
|
||||
class Channel:
|
||||
"""An osu! chat channel.
|
||||
|
||||
Possibly confusing attributes
|
||||
-----------
|
||||
_name: `str`
|
||||
A name string of the channel.
|
||||
The cls.`name` property wraps handling for '#multiplayer' and
|
||||
'#spectator' when communicating with the osu! client; only use
|
||||
this attr when you need the channel's true name; otherwise you
|
||||
should use the `name` property described below.
|
||||
|
||||
instance: `bool`
|
||||
Instanced channels are deleted when all players have left;
|
||||
this is useful for things like multiplayer, spectator, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
topic: str,
|
||||
read_priv: Privileges = Privileges.UNRESTRICTED,
|
||||
write_priv: Privileges = Privileges.UNRESTRICTED,
|
||||
auto_join: bool = True,
|
||||
instance: bool = False,
|
||||
) -> None:
|
||||
# TODO: think of better names than `_name` and `name`
|
||||
self._name = name # 'real' name ('#{multi/spec}_{id}')
|
||||
|
||||
if self._name.startswith("#spec_"):
|
||||
self.name = "#spectator"
|
||||
elif self._name.startswith("#multi_"):
|
||||
self.name = "#multiplayer"
|
||||
else:
|
||||
self.name = self._name
|
||||
|
||||
self.topic = topic
|
||||
self.read_priv = read_priv
|
||||
self.write_priv = write_priv
|
||||
self.auto_join = auto_join
|
||||
self.instance = instance
|
||||
|
||||
self.players: list[Player] = []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self._name}>"
|
||||
|
||||
def __contains__(self, player: Player) -> bool:
|
||||
return player in self.players
|
||||
|
||||
# XXX: should this be cached differently?
|
||||
|
||||
def can_read(self, priv: Privileges) -> bool:
|
||||
if not self.read_priv:
|
||||
return True
|
||||
|
||||
return priv & self.read_priv != 0
|
||||
|
||||
def can_write(self, priv: Privileges) -> bool:
|
||||
if not self.write_priv:
|
||||
return True
|
||||
|
||||
return priv & self.write_priv != 0
|
||||
|
||||
def send(self, msg: str, sender: Player, to_self: bool = False) -> None:
|
||||
"""Enqueue `msg` to all appropriate clients from `sender`."""
|
||||
data = app.packets.send_message(
|
||||
sender=sender.name,
|
||||
msg=msg,
|
||||
recipient=self.name,
|
||||
sender_id=sender.id,
|
||||
)
|
||||
|
||||
for player in self.players:
|
||||
if sender.id not in player.blocks and (to_self or player.id != sender.id):
|
||||
player.enqueue(data)
|
||||
|
||||
def send_bot(self, msg: str) -> None:
|
||||
"""Enqueue `msg` to all connected clients from bot."""
|
||||
bot = app.state.sessions.bot
|
||||
|
||||
msg_len = len(msg)
|
||||
|
||||
if msg_len >= 31979: # TODO ??????????
|
||||
msg = f"message would have crashed games ({msg_len} chars)"
|
||||
|
||||
self.enqueue(
|
||||
app.packets.send_message(
|
||||
sender=bot.name,
|
||||
msg=msg,
|
||||
recipient=self.name,
|
||||
sender_id=bot.id,
|
||||
),
|
||||
)
|
||||
|
||||
def send_selective(
|
||||
self,
|
||||
msg: str,
|
||||
sender: Player,
|
||||
recipients: set[Player],
|
||||
) -> None:
|
||||
"""Enqueue `sender`'s `msg` to `recipients`."""
|
||||
for player in recipients:
|
||||
if player in self:
|
||||
player.send(msg, sender=sender, chan=self)
|
||||
|
||||
def append(self, player: Player) -> None:
|
||||
"""Add `player` to the channel's players."""
|
||||
self.players.append(player)
|
||||
|
||||
def remove(self, player: Player) -> None:
|
||||
"""Remove `player` from the channel's players."""
|
||||
self.players.remove(player)
|
||||
|
||||
if not self.players and self.instance:
|
||||
# if it's an instance channel and this
|
||||
# is the last member leaving, just remove
|
||||
# the channel from the global list.
|
||||
app.state.sessions.channels.remove(self)
|
||||
|
||||
def enqueue(self, data: bytes, immune: Sequence[int] = []) -> None:
|
||||
"""Enqueue `data` to all connected clients not in `immune`."""
|
||||
for player in self.players:
|
||||
if player.id not in immune:
|
||||
player.enqueue(data)
|
314
app/objects/collections.py
Normal file
314
app/objects/collections.py
Normal file
@@ -0,0 +1,314 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import databases.core
|
||||
|
||||
import app.settings
|
||||
import app.state
|
||||
import app.utils
|
||||
from app.constants.privileges import ClanPrivileges
|
||||
from app.constants.privileges import Privileges
|
||||
from app.logging import Ansi
|
||||
from app.logging import log
|
||||
from app.objects.channel import Channel
|
||||
from app.objects.match import Match
|
||||
from app.objects.player import Player
|
||||
from app.repositories import channels as channels_repo
|
||||
from app.repositories import clans as clans_repo
|
||||
from app.repositories import users as users_repo
|
||||
from app.utils import make_safe_name
|
||||
|
||||
|
||||
class Channels(list[Channel]):
|
||||
"""The currently active chat channels on the server."""
|
||||
|
||||
def __iter__(self) -> Iterator[Channel]:
|
||||
return super().__iter__()
|
||||
|
||||
def __contains__(self, o: object) -> bool:
|
||||
"""Check whether internal list contains `o`."""
|
||||
# Allow string to be passed to compare vs. name.
|
||||
if isinstance(o, str):
|
||||
return o in (chan.name for chan in self)
|
||||
else:
|
||||
return super().__contains__(o)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# XXX: we use the "real" name, aka
|
||||
# #multi_1 instead of #multiplayer
|
||||
# #spect_1 instead of #spectator.
|
||||
return f'[{", ".join(c._name for c in self)}]'
|
||||
|
||||
def get_by_name(self, name: str) -> Channel | None:
|
||||
"""Get a channel from the list by `name`."""
|
||||
for channel in self:
|
||||
if channel._name == name:
|
||||
return channel
|
||||
|
||||
return None
|
||||
|
||||
def append(self, channel: Channel) -> None:
|
||||
"""Append `channel` to the list."""
|
||||
super().append(channel)
|
||||
|
||||
if app.settings.DEBUG:
|
||||
log(f"{channel} added to channels list.")
|
||||
|
||||
def extend(self, channels: Iterable[Channel]) -> None:
|
||||
"""Extend the list with `channels`."""
|
||||
super().extend(channels)
|
||||
|
||||
if app.settings.DEBUG:
|
||||
log(f"{channels} added to channels list.")
|
||||
|
||||
def remove(self, channel: Channel) -> None:
|
||||
"""Remove `channel` from the list."""
|
||||
super().remove(channel)
|
||||
|
||||
if app.settings.DEBUG:
|
||||
log(f"{channel} removed from channels list.")
|
||||
|
||||
async def prepare(self) -> None:
|
||||
"""Fetch data from sql & return; preparing to run the server."""
|
||||
log("Fetching channels from sql.", Ansi.LCYAN)
|
||||
for row in await channels_repo.fetch_many():
|
||||
self.append(
|
||||
Channel(
|
||||
name=row["name"],
|
||||
topic=row["topic"],
|
||||
read_priv=Privileges(row["read_priv"]),
|
||||
write_priv=Privileges(row["write_priv"]),
|
||||
auto_join=row["auto_join"] == 1,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Matches(list[Match | None]):
|
||||
"""The currently active multiplayer matches on the server."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
MAX_MATCHES = 64 # TODO: refactor this out of existence
|
||||
super().__init__([None] * MAX_MATCHES)
|
||||
|
||||
def __iter__(self) -> Iterator[Match | None]:
|
||||
return super().__iter__()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'[{", ".join(match.name for match in self if match)}]'
|
||||
|
||||
def get_free(self) -> int | None:
|
||||
"""Return the first free match id from `self`."""
|
||||
for idx, match in enumerate(self):
|
||||
if match is None:
|
||||
return idx
|
||||
|
||||
return None
|
||||
|
||||
def remove(self, match: Match | None) -> None:
|
||||
"""Remove `match` from the list."""
|
||||
for i, _m in enumerate(self):
|
||||
if match is _m:
|
||||
self[i] = None
|
||||
break
|
||||
|
||||
if app.settings.DEBUG:
|
||||
log(f"{match} removed from matches list.")
|
||||
|
||||
|
||||
class Players(list[Player]):
|
||||
"""The currently active players on the server."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __iter__(self) -> Iterator[Player]:
|
||||
return super().__iter__()
|
||||
|
||||
def __contains__(self, player: object) -> bool:
|
||||
# allow us to either pass in the player
|
||||
# obj, or the player name as a string.
|
||||
if isinstance(player, str):
|
||||
return player in (player.name for player in self)
|
||||
else:
|
||||
return super().__contains__(player)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'[{", ".join(map(repr, self))}]'
|
||||
|
||||
@property
|
||||
def ids(self) -> set[int]:
|
||||
"""Return a set of the current ids in the list."""
|
||||
return {p.id for p in self}
|
||||
|
||||
@property
|
||||
def staff(self) -> set[Player]:
|
||||
"""Return a set of the current staff online."""
|
||||
return {p for p in self if p.priv & Privileges.STAFF}
|
||||
|
||||
@property
|
||||
def restricted(self) -> set[Player]:
|
||||
"""Return a set of the current restricted players."""
|
||||
return {p for p in self if not p.priv & Privileges.UNRESTRICTED}
|
||||
|
||||
@property
|
||||
def unrestricted(self) -> set[Player]:
|
||||
"""Return a set of the current unrestricted players."""
|
||||
return {p for p in self if p.priv & Privileges.UNRESTRICTED}
|
||||
|
||||
def enqueue(self, data: bytes, immune: Sequence[Player] = []) -> None:
|
||||
"""Enqueue `data` to all players, except for those in `immune`."""
|
||||
for player in self:
|
||||
if player not in immune:
|
||||
player.enqueue(data)
|
||||
|
||||
def get(
|
||||
self,
|
||||
token: str | None = None,
|
||||
id: int | None = None,
|
||||
name: str | None = None,
|
||||
) -> Player | None:
|
||||
"""Get a player by token, id, or name from cache."""
|
||||
for player in self:
|
||||
if token is not None:
|
||||
if player.token == token:
|
||||
return player
|
||||
elif id is not None:
|
||||
if player.id == id:
|
||||
return player
|
||||
elif name is not None:
|
||||
if player.safe_name == make_safe_name(name):
|
||||
return player
|
||||
|
||||
return None
|
||||
|
||||
async def get_sql(
|
||||
self,
|
||||
id: int | None = None,
|
||||
name: str | None = None,
|
||||
) -> Player | None:
|
||||
"""Get a player by token, id, or name from sql."""
|
||||
# try to get from sql.
|
||||
player = await users_repo.fetch_one(
|
||||
id=id,
|
||||
name=name,
|
||||
fetch_all_fields=True,
|
||||
)
|
||||
if player is None:
|
||||
return None
|
||||
|
||||
clan_id: int | None = None
|
||||
clan_priv: ClanPrivileges | None = None
|
||||
if player["clan_id"] != 0:
|
||||
clan_id = player["clan_id"]
|
||||
clan_priv = ClanPrivileges(player["clan_priv"])
|
||||
|
||||
return Player(
|
||||
id=player["id"],
|
||||
name=player["name"],
|
||||
priv=Privileges(player["priv"]),
|
||||
pw_bcrypt=player["pw_bcrypt"].encode(),
|
||||
token=Player.generate_token(),
|
||||
clan_id=clan_id,
|
||||
clan_priv=clan_priv,
|
||||
geoloc={
|
||||
"latitude": 0.0,
|
||||
"longitude": 0.0,
|
||||
"country": {
|
||||
"acronym": player["country"],
|
||||
"numeric": app.state.services.country_codes[player["country"]],
|
||||
},
|
||||
},
|
||||
silence_end=player["silence_end"],
|
||||
donor_end=player["donor_end"],
|
||||
api_key=player["api_key"],
|
||||
)
|
||||
|
||||
async def from_cache_or_sql(
|
||||
self,
|
||||
id: int | None = None,
|
||||
name: str | None = None,
|
||||
) -> Player | None:
|
||||
"""Try to get player from cache, or sql as fallback."""
|
||||
player = self.get(id=id, name=name)
|
||||
if player is not None:
|
||||
return player
|
||||
player = await self.get_sql(id=id, name=name)
|
||||
if player is not None:
|
||||
return player
|
||||
|
||||
return None
|
||||
|
||||
async def from_login(
|
||||
self,
|
||||
name: str,
|
||||
pw_md5: str,
|
||||
sql: bool = False,
|
||||
) -> Player | None:
|
||||
"""Return a player with a given name & pw_md5, from cache or sql."""
|
||||
player = self.get(name=name)
|
||||
if not player:
|
||||
if not sql:
|
||||
return None
|
||||
|
||||
player = await self.get_sql(name=name)
|
||||
if not player:
|
||||
return None
|
||||
|
||||
assert player.pw_bcrypt is not None
|
||||
|
||||
if app.state.cache.bcrypt[player.pw_bcrypt] == pw_md5.encode():
|
||||
return player
|
||||
|
||||
return None
|
||||
|
||||
def append(self, player: Player) -> None:
|
||||
"""Append `p` to the list."""
|
||||
if player in self:
|
||||
if app.settings.DEBUG:
|
||||
log(f"{player} double-added to global player list?")
|
||||
return
|
||||
|
||||
super().append(player)
|
||||
|
||||
def remove(self, player: Player) -> None:
|
||||
"""Remove `p` from the list."""
|
||||
if player not in self:
|
||||
if app.settings.DEBUG:
|
||||
log(f"{player} removed from player list when not online?")
|
||||
return
|
||||
|
||||
super().remove(player)
|
||||
|
||||
|
||||
async def initialize_ram_caches() -> None:
|
||||
"""Setup & cache the global collections before listening for connections."""
|
||||
# fetch channels, clans and pools from db
|
||||
await app.state.sessions.channels.prepare()
|
||||
|
||||
bot = await users_repo.fetch_one(id=1)
|
||||
if bot is None:
|
||||
raise RuntimeError("Bot account not found in database.")
|
||||
|
||||
# create bot & add it to online players
|
||||
app.state.sessions.bot = Player(
|
||||
id=1,
|
||||
name=bot["name"],
|
||||
priv=Privileges.UNRESTRICTED,
|
||||
pw_bcrypt=None,
|
||||
token=Player.generate_token(),
|
||||
login_time=float(0x7FFFFFFF), # (never auto-dc)
|
||||
is_bot_client=True,
|
||||
)
|
||||
app.state.sessions.players.append(app.state.sessions.bot)
|
||||
|
||||
# static api keys
|
||||
app.state.sessions.api_keys = {
|
||||
row["api_key"]: row["id"]
|
||||
for row in await app.state.services.database.fetch_all(
|
||||
"SELECT id, api_key FROM users WHERE api_key IS NOT NULL",
|
||||
)
|
||||
}
|
552
app/objects/match.py
Normal file
552
app/objects/match.py
Normal file
@@ -0,0 +1,552 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime as datetime
|
||||
from datetime import timedelta as timedelta
|
||||
from enum import IntEnum
|
||||
from enum import unique
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypedDict
|
||||
|
||||
import app.packets
|
||||
import app.settings
|
||||
import app.state
|
||||
from app.constants import regexes
|
||||
from app.constants.gamemodes import GameMode
|
||||
from app.constants.mods import Mods
|
||||
from app.objects.beatmap import Beatmap
|
||||
from app.repositories.tourney_pools import TourneyPool
|
||||
from app.utils import escape_enum
|
||||
from app.utils import pymysql_encode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from asyncio import TimerHandle
|
||||
|
||||
from app.objects.channel import Channel
|
||||
from app.objects.player import Player
|
||||
|
||||
|
||||
MAX_MATCH_NAME_LENGTH = 50
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class SlotStatus(IntEnum):
|
||||
open = 1
|
||||
locked = 2
|
||||
not_ready = 4
|
||||
ready = 8
|
||||
no_map = 16
|
||||
playing = 32
|
||||
complete = 64
|
||||
quit = 128
|
||||
|
||||
# has_player = not_ready | ready | no_map | playing | complete
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class MatchTeams(IntEnum):
|
||||
neutral = 0
|
||||
blue = 1
|
||||
red = 2
|
||||
|
||||
|
||||
"""
|
||||
# implemented by osu! and send between client/server,
|
||||
# quite frequently even, but seems useless??
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class MatchTypes(IntEnum):
|
||||
standard = 0
|
||||
powerplay = 1 # literally no idea what this is for
|
||||
"""
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class MatchWinConditions(IntEnum):
|
||||
score = 0
|
||||
accuracy = 1
|
||||
combo = 2
|
||||
scorev2 = 3
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class MatchTeamTypes(IntEnum):
|
||||
head_to_head = 0
|
||||
tag_coop = 1
|
||||
team_vs = 2
|
||||
tag_team_vs = 3
|
||||
|
||||
|
||||
class Slot:
|
||||
"""An individual player slot in an osu! multiplayer match."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.player: Player | None = None
|
||||
self.status = SlotStatus.open
|
||||
self.team = MatchTeams.neutral
|
||||
self.mods = Mods.NOMOD
|
||||
self.loaded = False
|
||||
self.skipped = False
|
||||
|
||||
def empty(self) -> bool:
|
||||
return self.player is None
|
||||
|
||||
def copy_from(self, other: Slot) -> None:
|
||||
self.player = other.player
|
||||
self.status = other.status
|
||||
self.team = other.team
|
||||
self.mods = other.mods
|
||||
|
||||
def reset(self, new_status: SlotStatus = SlotStatus.open) -> None:
|
||||
self.player = None
|
||||
self.status = new_status
|
||||
self.team = MatchTeams.neutral
|
||||
self.mods = Mods.NOMOD
|
||||
self.loaded = False
|
||||
self.skipped = False
|
||||
|
||||
|
||||
class StartingTimers(TypedDict):
|
||||
start: TimerHandle
|
||||
alerts: list[TimerHandle]
|
||||
time: float
|
||||
|
||||
|
||||
class Match:
|
||||
"""\
|
||||
An osu! multiplayer match.
|
||||
|
||||
Possibly confusing attributes
|
||||
-----------
|
||||
_refs: set[`Player`]
|
||||
A set of players who have access to mp commands in the match.
|
||||
These can be used with the !mp <addref/rmref/listref> commands.
|
||||
|
||||
slots: list[`Slot`]
|
||||
A list of 16 `Slot` objects representing the match's slots.
|
||||
|
||||
starting: dict[str, `TimerHandle`] | None
|
||||
Used when the match is started with !mp start <seconds>.
|
||||
It stores both the starting timer, and the chat alert timers.
|
||||
|
||||
seed: `int`
|
||||
The seed used for osu!mania's random mod.
|
||||
|
||||
use_pp_scoring: `bool`
|
||||
Whether pp should be used as a win condition override during scrims.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: int,
|
||||
name: str,
|
||||
password: str,
|
||||
has_public_history: bool,
|
||||
map_name: str,
|
||||
map_id: int,
|
||||
map_md5: str,
|
||||
host_id: int,
|
||||
mode: GameMode,
|
||||
mods: Mods,
|
||||
win_condition: MatchWinConditions,
|
||||
team_type: MatchTeamTypes,
|
||||
freemods: bool,
|
||||
seed: int,
|
||||
chat_channel: Channel,
|
||||
) -> None:
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.passwd = password
|
||||
self.has_public_history = has_public_history
|
||||
|
||||
self.host_id = host_id
|
||||
self._refs: set[Player] = set()
|
||||
|
||||
self.map_id = map_id
|
||||
self.map_md5 = map_md5
|
||||
self.map_name = map_name
|
||||
self.prev_map_id = 0 # previously chosen map
|
||||
|
||||
self.mods = mods
|
||||
self.mode = mode
|
||||
self.freemods = freemods
|
||||
|
||||
self.chat = chat_channel
|
||||
self.slots = [Slot() for _ in range(16)]
|
||||
|
||||
# self.type = MatchTypes.standard
|
||||
self.team_type = team_type
|
||||
self.win_condition = win_condition
|
||||
|
||||
self.in_progress = False
|
||||
self.starting: StartingTimers | None = None
|
||||
self.seed = seed # used for mania random mod
|
||||
|
||||
self.tourney_pool: TourneyPool | None = None
|
||||
|
||||
# scrimmage stuff
|
||||
self.is_scrimming = False
|
||||
self.match_points: dict[MatchTeams | Player, int] = defaultdict(int)
|
||||
self.bans: set[tuple[Mods, int]] = set()
|
||||
self.winners: list[Player | MatchTeams | None] = [] # none for tie
|
||||
self.winning_pts = 0
|
||||
self.use_pp_scoring = False # only for scrims
|
||||
|
||||
self.tourney_clients: set[int] = set() # player ids
|
||||
|
||||
@property
|
||||
def host(self) -> Player:
|
||||
player = app.state.sessions.players.get(id=self.host_id)
|
||||
if player is None:
|
||||
raise ValueError(
|
||||
f"Host with id {self.host_id} not found for match {self!r}",
|
||||
)
|
||||
return player
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
"""The match's invitation url."""
|
||||
return f"osump://{self.id}/{self.passwd}"
|
||||
|
||||
@property
|
||||
def map_url(self) -> str:
|
||||
"""The osu! beatmap url for `self`'s map."""
|
||||
return f"https://osu.{app.settings.DOMAIN}/b/{self.map_id}"
|
||||
|
||||
@property
|
||||
def embed(self) -> str:
|
||||
"""An osu! chat embed for `self`."""
|
||||
return f"[{self.url} {self.name}]"
|
||||
|
||||
@property
|
||||
def map_embed(self) -> str:
|
||||
"""An osu! chat embed for `self`'s map."""
|
||||
return f"[{self.map_url} {self.map_name}]"
|
||||
|
||||
@property
|
||||
def refs(self) -> set[Player]:
|
||||
"""Return all players with referee permissions."""
|
||||
refs = self._refs
|
||||
|
||||
if self.host is not None:
|
||||
refs.add(self.host)
|
||||
|
||||
return refs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<{self.name} ({self.id})>"
|
||||
|
||||
def get_slot(self, player: Player) -> Slot | None:
|
||||
"""Return the slot containing a given player."""
|
||||
for s in self.slots:
|
||||
if player is s.player:
|
||||
return s
|
||||
|
||||
return None
|
||||
|
||||
def get_slot_id(self, player: Player) -> int | None:
|
||||
"""Return the slot index containing a given player."""
|
||||
for idx, s in enumerate(self.slots):
|
||||
if player is s.player:
|
||||
return idx
|
||||
|
||||
return None
|
||||
|
||||
def get_free(self) -> int | None:
|
||||
"""Return the first unoccupied slot in multi, if any."""
|
||||
for idx, s in enumerate(self.slots):
|
||||
if s.status == SlotStatus.open:
|
||||
return idx
|
||||
|
||||
return None
|
||||
|
||||
def get_host_slot(self) -> Slot | None:
|
||||
"""Return the slot containing the host."""
|
||||
for s in self.slots:
|
||||
if s.player is not None and s.player is self.host:
|
||||
return s
|
||||
|
||||
return None
|
||||
|
||||
def copy(self, m: Match) -> None:
|
||||
"""Fully copy the data of another match obj."""
|
||||
self.map_id = m.map_id
|
||||
self.map_md5 = m.map_md5
|
||||
self.map_name = m.map_name
|
||||
self.freemods = m.freemods
|
||||
self.mode = m.mode
|
||||
self.team_type = m.team_type
|
||||
self.win_condition = m.win_condition
|
||||
self.mods = m.mods
|
||||
self.name = m.name
|
||||
|
||||
def enqueue(
|
||||
self,
|
||||
data: bytes,
|
||||
lobby: bool = True,
|
||||
immune: Sequence[int] = [],
|
||||
) -> None:
|
||||
"""Add data to be sent to all clients in the match."""
|
||||
self.chat.enqueue(data, immune)
|
||||
|
||||
lchan = app.state.sessions.channels.get_by_name("#lobby")
|
||||
if lobby and lchan and lchan.players:
|
||||
lchan.enqueue(data)
|
||||
|
||||
def enqueue_state(self, lobby: bool = True) -> None:
|
||||
"""Enqueue `self`'s state to players in the match & lobby."""
|
||||
# TODO: hmm this is pretty bad, writes twice
|
||||
|
||||
# send password only to users currently in the match.
|
||||
self.chat.enqueue(app.packets.update_match(self, send_pw=True))
|
||||
|
||||
lchan = app.state.sessions.channels.get_by_name("#lobby")
|
||||
if lobby and lchan and lchan.players:
|
||||
lchan.enqueue(app.packets.update_match(self, send_pw=False))
|
||||
|
||||
def unready_players(self, expected: SlotStatus = SlotStatus.ready) -> None:
|
||||
"""Unready any players in the `expected` state."""
|
||||
for s in self.slots:
|
||||
if s.status == expected:
|
||||
s.status = SlotStatus.not_ready
|
||||
|
||||
def reset_players_loaded_status(self) -> None:
|
||||
"""Reset all players' loaded status."""
|
||||
for s in self.slots:
|
||||
s.loaded = False
|
||||
s.skipped = False
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the match for all ready players with the map."""
|
||||
no_map: list[int] = []
|
||||
|
||||
for s in self.slots:
|
||||
# start each player who has the map.
|
||||
if s.player is not None:
|
||||
if s.status != SlotStatus.no_map:
|
||||
s.status = SlotStatus.playing
|
||||
else:
|
||||
no_map.append(s.player.id)
|
||||
|
||||
self.in_progress = True
|
||||
self.enqueue(app.packets.match_start(self), immune=no_map, lobby=False)
|
||||
self.enqueue_state()
|
||||
|
||||
def reset_scrim(self) -> None:
|
||||
"""Reset the current scrim's winning points & bans."""
|
||||
self.match_points.clear()
|
||||
self.winners.clear()
|
||||
self.bans.clear()
|
||||
|
||||
async def await_submissions(
|
||||
self,
|
||||
was_playing: Sequence[Slot],
|
||||
) -> tuple[dict[MatchTeams | Player, int], Sequence[Player]]:
|
||||
"""Await score submissions from all players in completed state."""
|
||||
scores: dict[MatchTeams | Player, int] = defaultdict(int)
|
||||
didnt_submit: list[Player] = []
|
||||
time_waited = 0.0 # allow up to 10s (total, not per player)
|
||||
|
||||
ffa = self.team_type in (MatchTeamTypes.head_to_head, MatchTeamTypes.tag_coop)
|
||||
|
||||
if self.use_pp_scoring:
|
||||
win_cond = "pp"
|
||||
else:
|
||||
win_cond = ("score", "acc", "max_combo", "score")[self.win_condition]
|
||||
|
||||
bmap = await Beatmap.from_md5(self.map_md5)
|
||||
|
||||
if not bmap:
|
||||
# map isn't submitted
|
||||
return {}, ()
|
||||
|
||||
for s in was_playing:
|
||||
# continue trying to fetch each player's
|
||||
# scores until they've all been submitted.
|
||||
while True:
|
||||
assert s.player is not None
|
||||
rc_score = s.player.recent_score
|
||||
|
||||
max_age = datetime.now() - timedelta(
|
||||
seconds=bmap.total_length + time_waited + 0.5,
|
||||
)
|
||||
|
||||
if (
|
||||
rc_score
|
||||
and rc_score.bmap
|
||||
and rc_score.bmap.md5 == self.map_md5
|
||||
and rc_score.server_time > max_age
|
||||
):
|
||||
# score found, add to our scores dict if != 0.
|
||||
score: int = getattr(rc_score, win_cond)
|
||||
if score:
|
||||
key: MatchTeams | Player = s.player if ffa else s.team
|
||||
scores[key] += score
|
||||
|
||||
break
|
||||
|
||||
# wait 0.5s and try again
|
||||
await asyncio.sleep(0.5)
|
||||
time_waited += 0.5
|
||||
|
||||
if time_waited > 10:
|
||||
# inform the match this user didn't
|
||||
# submit a score in time, and skip them.
|
||||
didnt_submit.append(s.player)
|
||||
break
|
||||
|
||||
# all scores retrieved, update the match.
|
||||
return scores, didnt_submit
|
||||
|
||||
async def update_matchpoints(self, was_playing: Sequence[Slot]) -> None:
|
||||
"""\
|
||||
Determine the winner from `scores`, increment & inform players.
|
||||
|
||||
This automatically works with the match settings (such as
|
||||
win condition, teams, & co-op) to determine the appropriate
|
||||
winner, and will use any team names included in the match name,
|
||||
along with the match name (fmt: OWC2020: (Team1) vs. (Team2)).
|
||||
|
||||
For the examples, we'll use accuracy as a win condition.
|
||||
|
||||
Teams, match title: `OWC2015: (United States) vs. (China)`.
|
||||
United States takes the point! (293.32% vs 292.12%)
|
||||
Total Score: United States | 7 - 2 | China
|
||||
United States takes the match, finishing OWC2015 with a score of 7 - 2!
|
||||
|
||||
FFA, the top <=3 players will be listed for the total score.
|
||||
Justice takes the point! (94.32% [Match avg. 91.22%])
|
||||
Total Score: Justice - 3 | cmyui - 2 | FrostiDrinks - 2
|
||||
Justice takes the match, finishing with a score of 4 - 2!
|
||||
"""
|
||||
|
||||
scores, didnt_submit = await self.await_submissions(was_playing)
|
||||
|
||||
for player in didnt_submit:
|
||||
self.chat.send_bot(f"{player} didn't submit a score (timeout: 10s).")
|
||||
|
||||
if not scores:
|
||||
self.chat.send_bot("Scores could not be calculated.")
|
||||
return None
|
||||
|
||||
ffa = self.team_type in (
|
||||
MatchTeamTypes.head_to_head,
|
||||
MatchTeamTypes.tag_coop,
|
||||
)
|
||||
|
||||
# all scores are equal, it was a tie.
|
||||
if len(scores) != 1 and len(set(scores.values())) == 1:
|
||||
self.winners.append(None)
|
||||
self.chat.send_bot("The point has ended in a tie!")
|
||||
return None
|
||||
|
||||
# Find the winner & increment their matchpoints.
|
||||
winner: Player | MatchTeams = max(scores, key=lambda k: scores[k])
|
||||
self.winners.append(winner)
|
||||
self.match_points[winner] += 1
|
||||
|
||||
msg: list[str] = []
|
||||
|
||||
def add_suffix(score: int | float) -> str | int | float:
|
||||
if self.use_pp_scoring:
|
||||
return f"{score:.2f}pp"
|
||||
elif self.win_condition == MatchWinConditions.accuracy:
|
||||
return f"{score:.2f}%"
|
||||
elif self.win_condition == MatchWinConditions.combo:
|
||||
return f"{score}x"
|
||||
else:
|
||||
return str(score)
|
||||
|
||||
if ffa:
|
||||
from app.objects.player import Player
|
||||
|
||||
assert isinstance(winner, Player)
|
||||
|
||||
msg.append(
|
||||
f"{winner.name} takes the point! ({add_suffix(scores[winner])} "
|
||||
f"[Match avg. {add_suffix(sum(scores.values()) / len(scores))}])",
|
||||
)
|
||||
|
||||
wmp = self.match_points[winner]
|
||||
|
||||
# check if match point #1 has enough points to win.
|
||||
if self.winning_pts and wmp == self.winning_pts:
|
||||
# we have a champion, announce & reset our match.
|
||||
self.is_scrimming = False
|
||||
self.reset_scrim()
|
||||
self.bans.clear()
|
||||
|
||||
m = f"{winner.name} takes the match! Congratulations!"
|
||||
else:
|
||||
# no winner, just announce the match points so far.
|
||||
# for ffa, we'll only announce the top <=3 players.
|
||||
m_points = sorted(self.match_points.items(), key=lambda x: x[1])
|
||||
m = f"Total Score: {' | '.join([f'{k.name} - {v}' for k, v in m_points])}"
|
||||
|
||||
msg.append(m)
|
||||
del m
|
||||
|
||||
else: # teams
|
||||
assert isinstance(winner, MatchTeams)
|
||||
|
||||
r_match = regexes.TOURNEY_MATCHNAME.match(self.name)
|
||||
if r_match:
|
||||
match_name = r_match["name"]
|
||||
team_names = {
|
||||
MatchTeams.blue: r_match["T1"],
|
||||
MatchTeams.red: r_match["T2"],
|
||||
}
|
||||
else:
|
||||
match_name = self.name
|
||||
team_names = {MatchTeams.blue: "Blue", MatchTeams.red: "Red"}
|
||||
|
||||
# teams are binary, so we have a loser.
|
||||
if winner is MatchTeams.blue:
|
||||
loser = MatchTeams.red
|
||||
else:
|
||||
loser = MatchTeams.blue
|
||||
|
||||
# from match name if available, else blue/red.
|
||||
wname = team_names[winner]
|
||||
lname = team_names[loser]
|
||||
|
||||
# scores from the recent play
|
||||
# (according to win condition)
|
||||
ws = add_suffix(scores[winner])
|
||||
ls = add_suffix(scores[loser])
|
||||
|
||||
# total win/loss score in the match.
|
||||
wmp = self.match_points[winner]
|
||||
lmp = self.match_points[loser]
|
||||
|
||||
# announce the score for the most recent play.
|
||||
msg.append(f"{wname} takes the point! ({ws} vs. {ls})")
|
||||
|
||||
# check if the winner has enough match points to win the match.
|
||||
if self.winning_pts and wmp == self.winning_pts:
|
||||
# we have a champion, announce & reset our match.
|
||||
self.is_scrimming = False
|
||||
self.reset_scrim()
|
||||
|
||||
msg.append(
|
||||
f"{wname} takes the match, finishing {match_name} "
|
||||
f"with a score of {wmp} - {lmp}! Congratulations!",
|
||||
)
|
||||
else:
|
||||
# no winner, just announce the match points so far.
|
||||
msg.append(f"Total Score: {wname} | {wmp} - {lmp} | {lname}")
|
||||
|
||||
if didnt_submit:
|
||||
self.chat.send_bot(
|
||||
"If you'd like to perform a rematch, "
|
||||
"please use the `!mp rematch` command.",
|
||||
)
|
||||
|
||||
for line in msg:
|
||||
self.chat.send_bot(line)
|
8
app/objects/models.py
Normal file
8
app/objects/models.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OsuBeatmapRequestForm(BaseModel):
|
||||
Filenames: list[str]
|
||||
Ids: list[int]
|
1017
app/objects/player.py
Normal file
1017
app/objects/player.py
Normal file
File diff suppressed because it is too large
Load Diff
453
app/objects/score.py
Normal file
453
app/objects/score.py
Normal file
@@ -0,0 +1,453 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from enum import IntEnum
|
||||
from enum import unique
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import app.state
|
||||
import app.usecases.performance
|
||||
import app.utils
|
||||
from app.constants.clientflags import ClientFlags
|
||||
from app.constants.gamemodes import GameMode
|
||||
from app.constants.mods import Mods
|
||||
from app.objects.beatmap import Beatmap
|
||||
from app.repositories import scores as scores_repo
|
||||
from app.usecases.performance import ScoreParams
|
||||
from app.utils import escape_enum
|
||||
from app.utils import pymysql_encode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.objects.player import Player
|
||||
|
||||
BEATMAPS_PATH = Path.cwd() / ".data/osu"
|
||||
|
||||
|
||||
@unique
|
||||
class Grade(IntEnum):
|
||||
# NOTE: these are implemented in the opposite order
|
||||
# as osu! to make more sense with <> operators.
|
||||
N = 0
|
||||
F = 1
|
||||
D = 2
|
||||
C = 3
|
||||
B = 4
|
||||
A = 5
|
||||
S = 6 # S
|
||||
SH = 7 # HD S
|
||||
X = 8 # SS
|
||||
XH = 9 # HD SS
|
||||
|
||||
@classmethod
|
||||
@functools.cache
|
||||
def from_str(cls, s: str) -> Grade:
|
||||
return {
|
||||
"xh": Grade.XH,
|
||||
"x": Grade.X,
|
||||
"sh": Grade.SH,
|
||||
"s": Grade.S,
|
||||
"a": Grade.A,
|
||||
"b": Grade.B,
|
||||
"c": Grade.C,
|
||||
"d": Grade.D,
|
||||
"f": Grade.F,
|
||||
"n": Grade.N,
|
||||
}[s.lower()]
|
||||
|
||||
def __format__(self, format_spec: str) -> str:
|
||||
if format_spec == "stats_column":
|
||||
return f"{self.name.lower()}_count"
|
||||
else:
|
||||
raise ValueError(f"Invalid format specifier {format_spec}")
|
||||
|
||||
|
||||
@unique
|
||||
@pymysql_encode(escape_enum)
|
||||
class SubmissionStatus(IntEnum):
|
||||
# TODO: make a system more like bancho's?
|
||||
FAILED = 0
|
||||
SUBMITTED = 1
|
||||
BEST = 2
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return {
|
||||
self.FAILED: "Failed",
|
||||
self.SUBMITTED: "Submitted",
|
||||
self.BEST: "Best",
|
||||
}[self]
|
||||
|
||||
|
||||
class Score:
|
||||
"""\
|
||||
Server side representation of an osu! score; any gamemode.
|
||||
|
||||
Possibly confusing attributes
|
||||
-----------
|
||||
bmap: `Beatmap | None`
|
||||
A beatmap obj representing the osu map.
|
||||
|
||||
player: `Player | None`
|
||||
A player obj of the player who submitted the score.
|
||||
|
||||
grade: `Grade`
|
||||
The letter grade in the score.
|
||||
|
||||
rank: `int`
|
||||
The leaderboard placement of the score.
|
||||
|
||||
perfect: `bool`
|
||||
Whether the score is a full-combo.
|
||||
|
||||
time_elapsed: `int`
|
||||
The total elapsed time of the play (in milliseconds).
|
||||
|
||||
client_flags: `int`
|
||||
osu!'s old anticheat flags.
|
||||
|
||||
prev_best: `Score | None`
|
||||
The previous best score before this play was submitted.
|
||||
NOTE: just because a score has a `prev_best` attribute does
|
||||
mean the score is our best score on the map! the `status`
|
||||
value will always be accurate for any score.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# TODO: check whether the reamining Optional's should be
|
||||
self.id: int | None = None
|
||||
self.bmap: Beatmap | None = None
|
||||
self.player: Player | None = None
|
||||
|
||||
self.mode: GameMode
|
||||
self.mods: Mods
|
||||
|
||||
self.pp: float
|
||||
self.sr: float
|
||||
self.score: int
|
||||
self.max_combo: int
|
||||
self.acc: float
|
||||
|
||||
# TODO: perhaps abstract these differently
|
||||
# since they're mode dependant? feels weird..
|
||||
self.n300: int
|
||||
self.n100: int # n150 for taiko
|
||||
self.n50: int
|
||||
self.nmiss: int
|
||||
self.ngeki: int
|
||||
self.nkatu: int
|
||||
|
||||
self.grade: Grade
|
||||
|
||||
self.passed: bool
|
||||
self.perfect: bool
|
||||
self.status: SubmissionStatus
|
||||
|
||||
self.client_time: datetime
|
||||
self.server_time: datetime
|
||||
self.time_elapsed: int
|
||||
|
||||
self.client_flags: ClientFlags
|
||||
self.client_checksum: str
|
||||
|
||||
self.rank: int | None = None
|
||||
self.prev_best: Score | None = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
# TODO: i really need to clean up my reprs
|
||||
try:
|
||||
assert self.bmap is not None
|
||||
return (
|
||||
f"<{self.acc:.2f}% {self.max_combo}x {self.nmiss}M "
|
||||
f"#{self.rank} on {self.bmap.full_name} for {self.pp:,.2f}pp>"
|
||||
)
|
||||
except:
|
||||
return super().__repr__()
|
||||
|
||||
"""Classmethods to fetch a score object from various data types."""
|
||||
|
||||
@classmethod
|
||||
async def from_sql(cls, score_id: int) -> Score | None:
|
||||
"""Create a score object from sql using its scoreid."""
|
||||
rec = await scores_repo.fetch_one(score_id)
|
||||
|
||||
if rec is None:
|
||||
return None
|
||||
|
||||
s = cls()
|
||||
|
||||
s.id = rec["id"]
|
||||
s.bmap = await Beatmap.from_md5(rec["map_md5"])
|
||||
s.player = await app.state.sessions.players.from_cache_or_sql(id=rec["userid"])
|
||||
|
||||
s.sr = 0.0 # TODO
|
||||
|
||||
s.pp = rec["pp"]
|
||||
s.score = rec["score"]
|
||||
s.max_combo = rec["max_combo"]
|
||||
s.mods = Mods(rec["mods"])
|
||||
s.acc = rec["acc"]
|
||||
s.n300 = rec["n300"]
|
||||
s.n100 = rec["n100"]
|
||||
s.n50 = rec["n50"]
|
||||
s.nmiss = rec["nmiss"]
|
||||
s.ngeki = rec["ngeki"]
|
||||
s.nkatu = rec["nkatu"]
|
||||
s.grade = Grade.from_str(rec["grade"])
|
||||
s.perfect = rec["perfect"] == 1
|
||||
s.status = SubmissionStatus(rec["status"])
|
||||
s.passed = s.status != SubmissionStatus.FAILED
|
||||
s.mode = GameMode(rec["mode"])
|
||||
s.server_time = rec["play_time"]
|
||||
s.time_elapsed = rec["time_elapsed"]
|
||||
s.client_flags = ClientFlags(rec["client_flags"])
|
||||
s.client_checksum = rec["online_checksum"]
|
||||
|
||||
if s.bmap:
|
||||
s.rank = await s.calculate_placement()
|
||||
|
||||
return s
|
||||
|
||||
@classmethod
|
||||
def from_submission(cls, data: list[str]) -> Score:
|
||||
"""Create a score object from an osu! submission string."""
|
||||
s = cls()
|
||||
|
||||
""" parse the following format
|
||||
# 0 online_checksum
|
||||
# 1 n300
|
||||
# 2 n100
|
||||
# 3 n50
|
||||
# 4 ngeki
|
||||
# 5 nkatu
|
||||
# 6 nmiss
|
||||
# 7 score
|
||||
# 8 max_combo
|
||||
# 9 perfect
|
||||
# 10 grade
|
||||
# 11 mods
|
||||
# 12 passed
|
||||
# 13 gamemode
|
||||
# 14 play_time # yyMMddHHmmss
|
||||
# 15 osu_version + (" " * client_flags)
|
||||
"""
|
||||
|
||||
s.client_checksum = data[0]
|
||||
s.n300 = int(data[1])
|
||||
s.n100 = int(data[2])
|
||||
s.n50 = int(data[3])
|
||||
s.ngeki = int(data[4])
|
||||
s.nkatu = int(data[5])
|
||||
s.nmiss = int(data[6])
|
||||
s.score = int(data[7])
|
||||
s.max_combo = int(data[8])
|
||||
s.perfect = data[9] == "True"
|
||||
s.grade = Grade.from_str(data[10])
|
||||
s.mods = Mods(int(data[11]))
|
||||
s.passed = data[12] == "True"
|
||||
s.mode = GameMode.from_params(int(data[13]), s.mods)
|
||||
s.client_time = datetime.strptime(data[14], "%y%m%d%H%M%S")
|
||||
s.client_flags = ClientFlags(data[15].count(" ") & ~4)
|
||||
|
||||
s.server_time = datetime.now()
|
||||
|
||||
return s
|
||||
|
||||
def compute_online_checksum(
|
||||
self,
|
||||
osu_version: str,
|
||||
osu_client_hash: str,
|
||||
storyboard_checksum: str,
|
||||
) -> str:
|
||||
"""Validate the online checksum of the score."""
|
||||
assert self.player is not None
|
||||
assert self.bmap is not None
|
||||
|
||||
return hashlib.md5(
|
||||
"chickenmcnuggets{0}o15{1}{2}smustard{3}{4}uu{5}{6}{7}{8}{9}{10}{11}Q{12}{13}{15}{14:%y%m%d%H%M%S}{16}{17}".format(
|
||||
self.n100 + self.n300,
|
||||
self.n50,
|
||||
self.ngeki,
|
||||
self.nkatu,
|
||||
self.nmiss,
|
||||
self.bmap.md5,
|
||||
self.max_combo,
|
||||
self.perfect,
|
||||
self.player.name,
|
||||
self.score,
|
||||
self.grade.name,
|
||||
int(self.mods),
|
||||
self.passed,
|
||||
self.mode.as_vanilla,
|
||||
self.client_time,
|
||||
osu_version, # 20210520
|
||||
osu_client_hash,
|
||||
storyboard_checksum,
|
||||
# yyMMddHHmmss
|
||||
).encode(),
|
||||
).hexdigest()
|
||||
|
||||
"""Methods to calculate internal data for a score."""
|
||||
|
||||
async def calculate_placement(self) -> int:
|
||||
assert self.bmap is not None
|
||||
|
||||
if self.mode >= GameMode.RELAX_OSU:
|
||||
scoring_metric = "pp"
|
||||
score = self.pp
|
||||
else:
|
||||
scoring_metric = "score"
|
||||
score = self.score
|
||||
|
||||
num_better_scores: int | None = await app.state.services.database.fetch_val(
|
||||
"SELECT COUNT(*) AS c FROM scores s "
|
||||
"INNER JOIN users u ON u.id = s.userid "
|
||||
"WHERE s.map_md5 = :map_md5 AND s.mode = :mode "
|
||||
"AND s.status = 2 AND u.priv & 1 "
|
||||
f"AND s.{scoring_metric} > :score",
|
||||
{
|
||||
"map_md5": self.bmap.md5,
|
||||
"mode": self.mode,
|
||||
"score": score,
|
||||
},
|
||||
column=0, # COUNT(*)
|
||||
)
|
||||
assert num_better_scores is not None
|
||||
return num_better_scores + 1
|
||||
|
||||
def calculate_performance(self, beatmap_id: int) -> tuple[float, float]:
|
||||
"""Calculate PP and star rating for our score."""
|
||||
mode_vn = self.mode.as_vanilla
|
||||
|
||||
score_args = ScoreParams(
|
||||
mode=mode_vn,
|
||||
mods=int(self.mods),
|
||||
combo=self.max_combo,
|
||||
ngeki=self.ngeki,
|
||||
n300=self.n300,
|
||||
nkatu=self.nkatu,
|
||||
n100=self.n100,
|
||||
n50=self.n50,
|
||||
nmiss=self.nmiss,
|
||||
)
|
||||
|
||||
result = app.usecases.performance.calculate_performances(
|
||||
osu_file_path=str(BEATMAPS_PATH / f"{beatmap_id}.osu"),
|
||||
scores=[score_args],
|
||||
)
|
||||
|
||||
return result[0]["performance"]["pp"], result[0]["difficulty"]["stars"]
|
||||
|
||||
async def calculate_status(self) -> None:
|
||||
"""Calculate the submission status of a submitted score."""
|
||||
assert self.player is not None
|
||||
assert self.bmap is not None
|
||||
|
||||
recs = await scores_repo.fetch_many(
|
||||
user_id=self.player.id,
|
||||
map_md5=self.bmap.md5,
|
||||
mode=self.mode,
|
||||
status=SubmissionStatus.BEST,
|
||||
)
|
||||
|
||||
if recs:
|
||||
rec = recs[0]
|
||||
|
||||
# we have a score on the map.
|
||||
# save it as our previous best score.
|
||||
self.prev_best = await Score.from_sql(rec["id"])
|
||||
assert self.prev_best is not None
|
||||
|
||||
# if our new score is better, update
|
||||
# both of our score's submission statuses.
|
||||
# NOTE: this will be updated in sql later on in submission
|
||||
if self.pp > rec["pp"]:
|
||||
self.status = SubmissionStatus.BEST
|
||||
self.prev_best.status = SubmissionStatus.SUBMITTED
|
||||
else:
|
||||
self.status = SubmissionStatus.SUBMITTED
|
||||
else:
|
||||
# this is our first score on the map.
|
||||
self.status = SubmissionStatus.BEST
|
||||
|
||||
def calculate_accuracy(self) -> float:
|
||||
"""Calculate the accuracy of our score."""
|
||||
mode_vn = self.mode.as_vanilla
|
||||
|
||||
if mode_vn == 0: # osu!
|
||||
total = self.n300 + self.n100 + self.n50 + self.nmiss
|
||||
|
||||
if total == 0:
|
||||
return 0.0
|
||||
|
||||
return (
|
||||
100.0
|
||||
* ((self.n300 * 300.0) + (self.n100 * 100.0) + (self.n50 * 50.0))
|
||||
/ (total * 300.0)
|
||||
)
|
||||
|
||||
elif mode_vn == 1: # osu!taiko
|
||||
total = self.n300 + self.n100 + self.nmiss
|
||||
|
||||
if total == 0:
|
||||
return 0.0
|
||||
|
||||
return 100.0 * ((self.n100 * 0.5) + self.n300) / total
|
||||
|
||||
elif mode_vn == 2: # osu!catch
|
||||
total = self.n300 + self.n100 + self.n50 + self.nkatu + self.nmiss
|
||||
|
||||
if total == 0:
|
||||
return 0.0
|
||||
|
||||
return 100.0 * (self.n300 + self.n100 + self.n50) / total
|
||||
|
||||
elif mode_vn == 3: # osu!mania
|
||||
total = (
|
||||
self.n300 + self.n100 + self.n50 + self.ngeki + self.nkatu + self.nmiss
|
||||
)
|
||||
|
||||
if total == 0:
|
||||
return 0.0
|
||||
|
||||
if self.mods & Mods.SCOREV2:
|
||||
return (
|
||||
100.0
|
||||
* (
|
||||
(self.n50 * 50.0)
|
||||
+ (self.n100 * 100.0)
|
||||
+ (self.nkatu * 200.0)
|
||||
+ (self.n300 * 300.0)
|
||||
+ (self.ngeki * 305.0)
|
||||
)
|
||||
/ (total * 305.0)
|
||||
)
|
||||
|
||||
return (
|
||||
100.0
|
||||
* (
|
||||
(self.n50 * 50.0)
|
||||
+ (self.n100 * 100.0)
|
||||
+ (self.nkatu * 200.0)
|
||||
+ ((self.n300 + self.ngeki) * 300.0)
|
||||
)
|
||||
/ (total * 300.0)
|
||||
)
|
||||
else:
|
||||
raise Exception(f"Invalid vanilla mode {mode_vn}")
|
||||
|
||||
""" Methods for updating a score. """
|
||||
|
||||
async def increment_replay_views(self) -> None:
|
||||
# TODO: move replay views to be per-score rather than per-user
|
||||
assert self.player is not None
|
||||
|
||||
# TODO: apparently cached stats don't store replay views?
|
||||
# need to refactor that to be able to use stats_repo here
|
||||
await app.state.services.database.execute(
|
||||
f"UPDATE stats "
|
||||
"SET replay_views = replay_views + 1 "
|
||||
"WHERE id = :user_id AND mode = :mode",
|
||||
{"user_id": self.player.id, "mode": self.mode},
|
||||
)
|
1289
app/packets.py
Normal file
1289
app/packets.py
Normal file
File diff suppressed because it is too large
Load Diff
15
app/repositories/__init__.py
Normal file
15
app/repositories/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.orm import DeclarativeMeta
|
||||
from sqlalchemy.orm import registry
|
||||
|
||||
mapper_registry = registry()
|
||||
|
||||
|
||||
class Base(metaclass=DeclarativeMeta):
|
||||
__abstract__ = True
|
||||
|
||||
registry = mapper_registry
|
||||
metadata = mapper_registry.metadata
|
||||
|
||||
__init__ = mapper_registry.constructor
|
173
app/repositories/achievements.py
Normal file
173
app/repositories/achievements.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
import app.state.services
|
||||
from app._typing import UNSET
|
||||
from app._typing import _UnsetSentinel
|
||||
from app.repositories import Base
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.objects.score import Score
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
|
||||
|
||||
class AchievementsTable(Base):
|
||||
__tablename__ = "achievements"
|
||||
|
||||
id = Column("id", Integer, primary_key=True, nullable=False, autoincrement=True)
|
||||
file = Column("file", String(128), nullable=False)
|
||||
name = Column("name", String(128, collation="utf8"), nullable=False)
|
||||
desc = Column("desc", String(256, collation="utf8"), nullable=False)
|
||||
cond = Column("cond", String(64), nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index("achievements_desc_uindex", desc, unique=True),
|
||||
Index("achievements_file_uindex", file, unique=True),
|
||||
Index("achievements_name_uindex", name, unique=True),
|
||||
)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
AchievementsTable.id,
|
||||
AchievementsTable.file,
|
||||
AchievementsTable.name,
|
||||
AchievementsTable.desc,
|
||||
AchievementsTable.cond,
|
||||
)
|
||||
|
||||
|
||||
class Achievement(TypedDict):
|
||||
id: int
|
||||
file: str
|
||||
name: str
|
||||
desc: str
|
||||
cond: Callable[[Score, int], bool]
|
||||
|
||||
|
||||
async def create(
|
||||
file: str,
|
||||
name: str,
|
||||
desc: str,
|
||||
cond: str,
|
||||
) -> Achievement:
|
||||
"""Create a new achievement."""
|
||||
insert_stmt = insert(AchievementsTable).values(
|
||||
file=file,
|
||||
name=name,
|
||||
desc=desc,
|
||||
cond=cond,
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(AchievementsTable.id == rec_id)
|
||||
achievement = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert achievement is not None
|
||||
|
||||
achievement["cond"] = eval(f'lambda score, mode_vn: {achievement["cond"]}')
|
||||
return cast(Achievement, achievement)
|
||||
|
||||
|
||||
async def fetch_one(
|
||||
id: int | None = None,
|
||||
name: str | None = None,
|
||||
) -> Achievement | None:
|
||||
"""Fetch a single achievement."""
|
||||
if id is None and name is None:
|
||||
raise ValueError("Must provide at least one parameter.")
|
||||
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
|
||||
if id is not None:
|
||||
select_stmt = select_stmt.where(AchievementsTable.id == id)
|
||||
if name is not None:
|
||||
select_stmt = select_stmt.where(AchievementsTable.name == name)
|
||||
|
||||
achievement = await app.state.services.database.fetch_one(select_stmt)
|
||||
if achievement is None:
|
||||
return None
|
||||
|
||||
achievement["cond"] = eval(f'lambda score, mode_vn: {achievement["cond"]}')
|
||||
return cast(Achievement, achievement)
|
||||
|
||||
|
||||
async def fetch_count() -> int:
|
||||
"""Fetch the number of achievements."""
|
||||
select_stmt = select(func.count().label("count")).select_from(AchievementsTable)
|
||||
|
||||
rec = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert rec is not None
|
||||
return cast(int, rec["count"])
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[Achievement]:
|
||||
"""Fetch a list of achievements."""
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if page is not None and page_size is not None:
|
||||
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
achievements = await app.state.services.database.fetch_all(select_stmt)
|
||||
for achievement in achievements:
|
||||
achievement["cond"] = eval(f'lambda score, mode_vn: {achievement["cond"]}')
|
||||
|
||||
return cast(list[Achievement], achievements)
|
||||
|
||||
|
||||
async def partial_update(
|
||||
id: int,
|
||||
file: str | _UnsetSentinel = UNSET,
|
||||
name: str | _UnsetSentinel = UNSET,
|
||||
desc: str | _UnsetSentinel = UNSET,
|
||||
cond: str | _UnsetSentinel = UNSET,
|
||||
) -> Achievement | None:
|
||||
"""Update an existing achievement."""
|
||||
update_stmt = update(AchievementsTable).where(AchievementsTable.id == id)
|
||||
if not isinstance(file, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(file=file)
|
||||
if not isinstance(name, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(name=name)
|
||||
if not isinstance(desc, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(desc=desc)
|
||||
if not isinstance(cond, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(cond=cond)
|
||||
|
||||
await app.state.services.database.execute(update_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(AchievementsTable.id == id)
|
||||
achievement = await app.state.services.database.fetch_one(select_stmt)
|
||||
if achievement is None:
|
||||
return None
|
||||
|
||||
achievement["cond"] = eval(f'lambda score, mode_vn: {achievement["cond"]}')
|
||||
return cast(Achievement, achievement)
|
||||
|
||||
|
||||
async def delete_one(
|
||||
id: int,
|
||||
) -> Achievement | None:
|
||||
"""Delete an existing achievement."""
|
||||
select_stmt = select(*READ_PARAMS).where(AchievementsTable.id == id)
|
||||
achievement = await app.state.services.database.fetch_one(select_stmt)
|
||||
if achievement is None:
|
||||
return None
|
||||
|
||||
delete_stmt = delete(AchievementsTable).where(AchievementsTable.id == id)
|
||||
await app.state.services.database.execute(delete_stmt)
|
||||
|
||||
achievement["cond"] = eval(f'lambda score, mode_vn: {achievement["cond"]}')
|
||||
return cast(Achievement, achievement)
|
184
app/repositories/channels.py
Normal file
184
app/repositories/channels.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.mysql import TINYINT
|
||||
|
||||
import app.state.services
|
||||
from app._typing import UNSET
|
||||
from app._typing import _UnsetSentinel
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class ChannelsTable(Base):
|
||||
__tablename__ = "channels"
|
||||
|
||||
id = Column("id", Integer, primary_key=True, nullable=False, autoincrement=True)
|
||||
name = Column("name", String(32), nullable=False)
|
||||
topic = Column("topic", String(256), nullable=False)
|
||||
read_priv = Column("read_priv", Integer, nullable=False, server_default="1")
|
||||
write_priv = Column("write_priv", Integer, nullable=False, server_default="2")
|
||||
auto_join = Column("auto_join", TINYINT(1), nullable=False, server_default="0")
|
||||
|
||||
__table_args__ = (
|
||||
Index("channels_name_uindex", name, unique=True),
|
||||
Index("channels_auto_join_index", auto_join),
|
||||
)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
ChannelsTable.id,
|
||||
ChannelsTable.name,
|
||||
ChannelsTable.topic,
|
||||
ChannelsTable.read_priv,
|
||||
ChannelsTable.write_priv,
|
||||
ChannelsTable.auto_join,
|
||||
)
|
||||
|
||||
|
||||
class Channel(TypedDict):
|
||||
id: int
|
||||
name: str
|
||||
topic: str
|
||||
read_priv: int
|
||||
write_priv: int
|
||||
auto_join: bool
|
||||
|
||||
|
||||
async def create(
|
||||
name: str,
|
||||
topic: str,
|
||||
read_priv: int,
|
||||
write_priv: int,
|
||||
auto_join: bool,
|
||||
) -> Channel:
|
||||
"""Create a new channel."""
|
||||
insert_stmt = insert(ChannelsTable).values(
|
||||
name=name,
|
||||
topic=topic,
|
||||
read_priv=read_priv,
|
||||
write_priv=write_priv,
|
||||
auto_join=auto_join,
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(ChannelsTable.id == rec_id)
|
||||
channel = await app.state.services.database.fetch_one(select_stmt)
|
||||
|
||||
assert channel is not None
|
||||
return cast(Channel, channel)
|
||||
|
||||
|
||||
async def fetch_one(
|
||||
id: int | None = None,
|
||||
name: str | None = None,
|
||||
) -> Channel | None:
|
||||
"""Fetch a single channel."""
|
||||
if id is None and name is None:
|
||||
raise ValueError("Must provide at least one parameter.")
|
||||
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
|
||||
if id is not None:
|
||||
select_stmt = select_stmt.where(ChannelsTable.id == id)
|
||||
if name is not None:
|
||||
select_stmt = select_stmt.where(ChannelsTable.name == name)
|
||||
|
||||
channel = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Channel | None, channel)
|
||||
|
||||
|
||||
async def fetch_count(
|
||||
read_priv: int | None = None,
|
||||
write_priv: int | None = None,
|
||||
auto_join: bool | None = None,
|
||||
) -> int:
|
||||
if read_priv is None and write_priv is None and auto_join is None:
|
||||
raise ValueError("Must provide at least one parameter.")
|
||||
|
||||
select_stmt = select(func.count().label("count")).select_from(ChannelsTable)
|
||||
|
||||
if read_priv is not None:
|
||||
select_stmt = select_stmt.where(ChannelsTable.read_priv == read_priv)
|
||||
if write_priv is not None:
|
||||
select_stmt = select_stmt.where(ChannelsTable.write_priv == write_priv)
|
||||
if auto_join is not None:
|
||||
select_stmt = select_stmt.where(ChannelsTable.auto_join == auto_join)
|
||||
|
||||
rec = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert rec is not None
|
||||
return cast(int, rec["count"])
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
read_priv: int | None = None,
|
||||
write_priv: int | None = None,
|
||||
auto_join: bool | None = None,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[Channel]:
|
||||
"""Fetch multiple channels from the database."""
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
|
||||
if read_priv is not None:
|
||||
select_stmt = select_stmt.where(ChannelsTable.read_priv == read_priv)
|
||||
if write_priv is not None:
|
||||
select_stmt = select_stmt.where(ChannelsTable.write_priv == write_priv)
|
||||
if auto_join is not None:
|
||||
select_stmt = select_stmt.where(ChannelsTable.auto_join == auto_join)
|
||||
|
||||
if page is not None and page_size is not None:
|
||||
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
channels = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[Channel], channels)
|
||||
|
||||
|
||||
async def partial_update(
|
||||
name: str,
|
||||
topic: str | _UnsetSentinel = UNSET,
|
||||
read_priv: int | _UnsetSentinel = UNSET,
|
||||
write_priv: int | _UnsetSentinel = UNSET,
|
||||
auto_join: bool | _UnsetSentinel = UNSET,
|
||||
) -> Channel | None:
|
||||
"""Update a channel in the database."""
|
||||
update_stmt = update(ChannelsTable).where(ChannelsTable.name == name)
|
||||
|
||||
if not isinstance(topic, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(topic=topic)
|
||||
if not isinstance(read_priv, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(read_priv=read_priv)
|
||||
if not isinstance(write_priv, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(write_priv=write_priv)
|
||||
if not isinstance(auto_join, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(auto_join=auto_join)
|
||||
|
||||
await app.state.services.database.execute(update_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(ChannelsTable.name == name)
|
||||
channel = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Channel | None, channel)
|
||||
|
||||
|
||||
async def delete_one(
|
||||
name: str,
|
||||
) -> Channel | None:
|
||||
"""Delete a channel from the database."""
|
||||
select_stmt = select(*READ_PARAMS).where(ChannelsTable.name == name)
|
||||
channel = await app.state.services.database.fetch_one(select_stmt)
|
||||
if channel is None:
|
||||
return None
|
||||
|
||||
delete_stmt = delete(ChannelsTable).where(ChannelsTable.name == name)
|
||||
await app.state.services.database.execute(delete_stmt)
|
||||
return cast(Channel | None, channel)
|
156
app/repositories/clans.py
Normal file
156
app/repositories/clans.py
Normal file
@@ -0,0 +1,156 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
|
||||
import app.state.services
|
||||
from app._typing import UNSET
|
||||
from app._typing import _UnsetSentinel
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class ClansTable(Base):
|
||||
__tablename__ = "clans"
|
||||
|
||||
id = Column("id", Integer, primary_key=True, nullable=False, autoincrement=True)
|
||||
name = Column("name", String(16, collation="utf8"), nullable=False)
|
||||
tag = Column("tag", String(6, collation="utf8"), nullable=False)
|
||||
owner = Column("owner", Integer, nullable=False)
|
||||
created_at = Column("created_at", DateTime, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index("clans_name_uindex", name, unique=False),
|
||||
Index("clans_owner_uindex", owner, unique=True),
|
||||
Index("clans_tag_uindex", tag, unique=True),
|
||||
)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
ClansTable.id,
|
||||
ClansTable.name,
|
||||
ClansTable.tag,
|
||||
ClansTable.owner,
|
||||
ClansTable.created_at,
|
||||
)
|
||||
|
||||
|
||||
class Clan(TypedDict):
|
||||
id: int
|
||||
name: str
|
||||
tag: str
|
||||
owner: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
async def create(
|
||||
name: str,
|
||||
tag: str,
|
||||
owner: int,
|
||||
) -> Clan:
|
||||
"""Create a new clan in the database."""
|
||||
insert_stmt = insert(ClansTable).values(
|
||||
name=name,
|
||||
tag=tag,
|
||||
owner=owner,
|
||||
created_at=func.now(),
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(ClansTable.id == rec_id)
|
||||
clan = await app.state.services.database.fetch_one(select_stmt)
|
||||
|
||||
assert clan is not None
|
||||
return cast(Clan, clan)
|
||||
|
||||
|
||||
async def fetch_one(
|
||||
id: int | None = None,
|
||||
name: str | None = None,
|
||||
tag: str | None = None,
|
||||
owner: int | None = None,
|
||||
) -> Clan | None:
|
||||
"""Fetch a single clan from the database."""
|
||||
if id is None and name is None and tag is None and owner is None:
|
||||
raise ValueError("Must provide at least one parameter.")
|
||||
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
|
||||
if id is not None:
|
||||
select_stmt = select_stmt.where(ClansTable.id == id)
|
||||
if name is not None:
|
||||
select_stmt = select_stmt.where(ClansTable.name == name)
|
||||
if tag is not None:
|
||||
select_stmt = select_stmt.where(ClansTable.tag == tag)
|
||||
if owner is not None:
|
||||
select_stmt = select_stmt.where(ClansTable.owner == owner)
|
||||
|
||||
clan = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Clan | None, clan)
|
||||
|
||||
|
||||
async def fetch_count() -> int:
|
||||
"""Fetch the number of clans in the database."""
|
||||
select_stmt = select(func.count().label("count")).select_from(ClansTable)
|
||||
rec = await app.state.services.database.fetch_one(select_stmt)
|
||||
|
||||
assert rec is not None
|
||||
return cast(int, rec["count"])
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[Clan]:
|
||||
"""Fetch many clans from the database."""
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if page is not None and page_size is not None:
|
||||
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
clans = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[Clan], clans)
|
||||
|
||||
|
||||
async def partial_update(
|
||||
id: int,
|
||||
name: str | _UnsetSentinel = UNSET,
|
||||
tag: str | _UnsetSentinel = UNSET,
|
||||
owner: int | _UnsetSentinel = UNSET,
|
||||
) -> Clan | None:
|
||||
"""Update a clan in the database."""
|
||||
update_stmt = update(ClansTable).where(ClansTable.id == id)
|
||||
if not isinstance(name, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(name=name)
|
||||
if not isinstance(tag, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(tag=tag)
|
||||
if not isinstance(owner, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(owner=owner)
|
||||
|
||||
await app.state.services.database.execute(update_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(ClansTable.id == id)
|
||||
clan = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Clan | None, clan)
|
||||
|
||||
|
||||
async def delete_one(id: int) -> Clan | None:
|
||||
"""Delete a clan from the database."""
|
||||
select_stmt = select(*READ_PARAMS).where(ClansTable.id == id)
|
||||
clan = await app.state.services.database.fetch_one(select_stmt)
|
||||
if clan is None:
|
||||
return None
|
||||
|
||||
delete_stmt = delete(ClansTable).where(ClansTable.id == id)
|
||||
await app.state.services.database.execute(delete_stmt)
|
||||
return cast(Clan, clan)
|
133
app/repositories/client_hashes.py
Normal file
133
app/repositories/client_hashes.py
Normal file
@@ -0,0 +1,133 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import CHAR
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.mysql import Insert as MysqlInsert
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from sqlalchemy.sql import ColumnElement
|
||||
from sqlalchemy.types import Boolean
|
||||
|
||||
import app.state.services
|
||||
from app.repositories import Base
|
||||
from app.repositories.users import UsersTable
|
||||
|
||||
|
||||
class ClientHashesTable(Base):
|
||||
__tablename__ = "client_hashes"
|
||||
|
||||
userid = Column("userid", Integer, nullable=False, primary_key=True)
|
||||
osupath = Column("osupath", CHAR(32), nullable=False, primary_key=True)
|
||||
adapters = Column("adapters", CHAR(32), nullable=False, primary_key=True)
|
||||
uninstall_id = Column("uninstall_id", CHAR(32), nullable=False, primary_key=True)
|
||||
disk_serial = Column("disk_serial", CHAR(32), nullable=False, primary_key=True)
|
||||
latest_time = Column("latest_time", DateTime, nullable=False)
|
||||
occurrences = Column("occurrences", Integer, nullable=False, server_default="0")
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
ClientHashesTable.userid,
|
||||
ClientHashesTable.osupath,
|
||||
ClientHashesTable.adapters,
|
||||
ClientHashesTable.uninstall_id,
|
||||
ClientHashesTable.disk_serial,
|
||||
ClientHashesTable.latest_time,
|
||||
ClientHashesTable.occurrences,
|
||||
)
|
||||
|
||||
|
||||
class ClientHash(TypedDict):
|
||||
userid: int
|
||||
osupath: str
|
||||
adapters: str
|
||||
uninstall_id: str
|
||||
disk_serial: str
|
||||
latest_time: datetime
|
||||
occurrences: int
|
||||
|
||||
|
||||
class ClientHashWithPlayer(ClientHash):
|
||||
name: str
|
||||
priv: int
|
||||
|
||||
|
||||
async def create(
|
||||
userid: int,
|
||||
osupath: str,
|
||||
adapters: str,
|
||||
uninstall_id: str,
|
||||
disk_serial: str,
|
||||
) -> ClientHash:
|
||||
"""Create a new client hash entry in the database."""
|
||||
insert_stmt: MysqlInsert = (
|
||||
mysql_insert(ClientHashesTable)
|
||||
.values(
|
||||
userid=userid,
|
||||
osupath=osupath,
|
||||
adapters=adapters,
|
||||
uninstall_id=uninstall_id,
|
||||
disk_serial=disk_serial,
|
||||
latest_time=func.now(),
|
||||
occurrences=1,
|
||||
)
|
||||
.on_duplicate_key_update(
|
||||
latest_time=func.now(),
|
||||
occurrences=ClientHashesTable.occurrences + 1,
|
||||
)
|
||||
)
|
||||
|
||||
await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS)
|
||||
.where(ClientHashesTable.userid == userid)
|
||||
.where(ClientHashesTable.osupath == osupath)
|
||||
.where(ClientHashesTable.adapters == adapters)
|
||||
.where(ClientHashesTable.uninstall_id == uninstall_id)
|
||||
.where(ClientHashesTable.disk_serial == disk_serial)
|
||||
)
|
||||
client_hash = await app.state.services.database.fetch_one(select_stmt)
|
||||
|
||||
assert client_hash is not None
|
||||
return cast(ClientHash, client_hash)
|
||||
|
||||
|
||||
async def fetch_any_hardware_matches_for_user(
|
||||
userid: int,
|
||||
running_under_wine: bool,
|
||||
adapters: str,
|
||||
uninstall_id: str,
|
||||
disk_serial: str | None = None,
|
||||
) -> list[ClientHashWithPlayer]:
|
||||
"""\
|
||||
Fetch a list of matching hardware addresses where any of
|
||||
`adapters`, `uninstall_id` or `disk_serial` match other users
|
||||
from the database.
|
||||
"""
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS, UsersTable.name, UsersTable.priv)
|
||||
.join(UsersTable, ClientHashesTable.userid == UsersTable.id)
|
||||
.where(ClientHashesTable.userid != userid)
|
||||
)
|
||||
|
||||
if running_under_wine:
|
||||
select_stmt = select_stmt.where(ClientHashesTable.uninstall_id == uninstall_id)
|
||||
else:
|
||||
# make disk serial optional in the OR
|
||||
oneof_filters: list[ColumnElement[Boolean]] = []
|
||||
oneof_filters.append(ClientHashesTable.adapters == adapters)
|
||||
oneof_filters.append(ClientHashesTable.uninstall_id == uninstall_id)
|
||||
if disk_serial is not None:
|
||||
oneof_filters.append(ClientHashesTable.disk_serial == disk_serial)
|
||||
select_stmt = select_stmt.where(or_(*oneof_filters))
|
||||
|
||||
client_hashes = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[ClientHashWithPlayer], client_hashes)
|
125
app/repositories/comments.py
Normal file
125
app/repositories/comments.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import CHAR
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.mysql import FLOAT
|
||||
|
||||
import app.state.services
|
||||
from app.repositories import Base
|
||||
from app.repositories.users import UsersTable
|
||||
|
||||
|
||||
class TargetType(StrEnum):
|
||||
REPLAY = "replay"
|
||||
BEATMAP = "map"
|
||||
SONG = "song"
|
||||
|
||||
|
||||
class CommentsTable(Base):
|
||||
__tablename__ = "comments"
|
||||
|
||||
id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
|
||||
target_id = Column("target_id", nullable=False)
|
||||
target_type = Column(Enum(TargetType, name="target_type"), nullable=False)
|
||||
userid = Column("userid", Integer, nullable=False)
|
||||
time = Column("time", FLOAT(precision=6, scale=3), nullable=False)
|
||||
comment = Column("comment", String(80, collation="utf8"), nullable=False)
|
||||
colour = Column("colour", CHAR(6), nullable=True)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
CommentsTable.id,
|
||||
CommentsTable.target_id,
|
||||
CommentsTable.target_type,
|
||||
CommentsTable.userid,
|
||||
CommentsTable.time,
|
||||
CommentsTable.comment,
|
||||
CommentsTable.colour,
|
||||
)
|
||||
|
||||
|
||||
class Comment(TypedDict):
|
||||
id: int
|
||||
target_id: int
|
||||
target_type: TargetType
|
||||
userid: int
|
||||
time: float
|
||||
comment: str
|
||||
colour: str | None
|
||||
|
||||
|
||||
async def create(
|
||||
target_id: int,
|
||||
target_type: TargetType,
|
||||
userid: int,
|
||||
time: float,
|
||||
comment: str,
|
||||
colour: str | None,
|
||||
) -> Comment:
|
||||
"""Create a new comment entry in the database."""
|
||||
insert_stmt = insert(CommentsTable).values(
|
||||
target_id=target_id,
|
||||
target_type=target_type,
|
||||
userid=userid,
|
||||
time=time,
|
||||
comment=comment,
|
||||
colour=colour,
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(CommentsTable.id == rec_id)
|
||||
_comment = await app.state.services.database.fetch_one(select_stmt)
|
||||
|
||||
assert _comment is not None
|
||||
return cast(Comment, _comment)
|
||||
|
||||
|
||||
class CommentWithUserPrivileges(Comment):
|
||||
priv: int
|
||||
|
||||
|
||||
async def fetch_all_relevant_to_replay(
|
||||
score_id: int | None = None,
|
||||
map_set_id: int | None = None,
|
||||
map_id: int | None = None,
|
||||
) -> list[CommentWithUserPrivileges]:
|
||||
"""\
|
||||
Fetch all comments from the database where any of the following match:
|
||||
- `score_id`
|
||||
- `map_set_id`
|
||||
- `map_id`
|
||||
"""
|
||||
select_stmt = (
|
||||
select(READ_PARAMS, UsersTable.priv)
|
||||
.join(UsersTable, CommentsTable.userid == UsersTable.id)
|
||||
.where(
|
||||
or_(
|
||||
and_(
|
||||
CommentsTable.target_type == TargetType.REPLAY,
|
||||
CommentsTable.target_id == score_id,
|
||||
),
|
||||
and_(
|
||||
CommentsTable.target_type == TargetType.SONG,
|
||||
CommentsTable.target_id == map_set_id,
|
||||
),
|
||||
and_(
|
||||
CommentsTable.target_type == TargetType.BEATMAP,
|
||||
CommentsTable.target_id == map_id,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
comments = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[CommentWithUserPrivileges], comments)
|
75
app/repositories/favourites.py
Normal file
75
app/repositories/favourites.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
|
||||
import app.state.services
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class FavouritesTable(Base):
|
||||
__tablename__ = "favourites"
|
||||
|
||||
userid = Column("userid", Integer, nullable=False, primary_key=True)
|
||||
setid = Column("setid", Integer, nullable=False, primary_key=True)
|
||||
created_at = Column("created_at", Integer, nullable=False, server_default="0")
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
FavouritesTable.userid,
|
||||
FavouritesTable.setid,
|
||||
FavouritesTable.created_at,
|
||||
)
|
||||
|
||||
|
||||
class Favourite(TypedDict):
|
||||
userid: int
|
||||
setid: int
|
||||
created_at: int
|
||||
|
||||
|
||||
async def create(
|
||||
userid: int,
|
||||
setid: int,
|
||||
) -> Favourite:
|
||||
"""Create a new favourite mapset entry in the database."""
|
||||
insert_stmt = insert(FavouritesTable).values(
|
||||
userid=userid,
|
||||
setid=setid,
|
||||
created_at=func.unix_timestamp(),
|
||||
)
|
||||
await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS)
|
||||
.where(FavouritesTable.userid == userid)
|
||||
.where(FavouritesTable.setid == setid)
|
||||
)
|
||||
favourite = await app.state.services.database.fetch_one(select_stmt)
|
||||
|
||||
assert favourite is not None
|
||||
return cast(Favourite, favourite)
|
||||
|
||||
|
||||
async def fetch_all(userid: int) -> list[Favourite]:
|
||||
"""Fetch all favourites from a player."""
|
||||
select_stmt = select(*READ_PARAMS).where(FavouritesTable.userid == userid)
|
||||
favourites = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[Favourite], favourites)
|
||||
|
||||
|
||||
async def fetch_one(userid: int, setid: int) -> Favourite | None:
|
||||
"""Check if a mapset is already a favourite."""
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS)
|
||||
.where(FavouritesTable.userid == userid)
|
||||
.where(FavouritesTable.setid == setid)
|
||||
)
|
||||
favourite = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Favourite | None, favourite)
|
128
app/repositories/ingame_logins.py
Normal file
128
app/repositories/ingame_logins.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Date
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
|
||||
import app.state.services
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class IngameLoginsTable(Base):
|
||||
__tablename__ = "ingame_logins"
|
||||
|
||||
id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
|
||||
userid = Column("userid", Integer, nullable=False)
|
||||
ip = Column("ip", String(45), nullable=False)
|
||||
osu_ver = Column("osu_ver", Date, nullable=False)
|
||||
osu_stream = Column("osu_stream", String(11), nullable=False)
|
||||
datetime = Column("datetime", DateTime, nullable=False)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
IngameLoginsTable.id,
|
||||
IngameLoginsTable.userid,
|
||||
IngameLoginsTable.ip,
|
||||
IngameLoginsTable.osu_ver,
|
||||
IngameLoginsTable.osu_stream,
|
||||
IngameLoginsTable.datetime,
|
||||
)
|
||||
|
||||
|
||||
class IngameLogin(TypedDict):
|
||||
id: int
|
||||
userid: str
|
||||
ip: str
|
||||
osu_ver: date
|
||||
osu_stream: str
|
||||
datetime: datetime
|
||||
|
||||
|
||||
class InGameLoginUpdateFields(TypedDict, total=False):
|
||||
userid: str
|
||||
ip: str
|
||||
osu_ver: date
|
||||
osu_stream: str
|
||||
|
||||
|
||||
async def create(
|
||||
user_id: int,
|
||||
ip: str,
|
||||
osu_ver: date,
|
||||
osu_stream: str,
|
||||
) -> IngameLogin:
|
||||
"""Create a new login entry in the database."""
|
||||
insert_stmt = insert(IngameLoginsTable).values(
|
||||
userid=user_id,
|
||||
ip=ip,
|
||||
osu_ver=osu_ver,
|
||||
osu_stream=osu_stream,
|
||||
datetime=func.now(),
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(IngameLoginsTable.id == rec_id)
|
||||
ingame_login = await app.state.services.database.fetch_one(select_stmt)
|
||||
|
||||
assert ingame_login is not None
|
||||
return cast(IngameLogin, ingame_login)
|
||||
|
||||
|
||||
async def fetch_one(id: int) -> IngameLogin | None:
|
||||
"""Fetch a login entry from the database."""
|
||||
select_stmt = select(*READ_PARAMS).where(IngameLoginsTable.id == id)
|
||||
ingame_login = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(IngameLogin | None, ingame_login)
|
||||
|
||||
|
||||
async def fetch_count(
|
||||
user_id: int | None = None,
|
||||
ip: str | None = None,
|
||||
) -> int:
|
||||
"""Fetch the number of logins in the database."""
|
||||
select_stmt = select(func.count().label("count")).select_from(IngameLoginsTable)
|
||||
if user_id is not None:
|
||||
select_stmt = select_stmt.where(IngameLoginsTable.userid == user_id)
|
||||
if ip is not None:
|
||||
select_stmt = select_stmt.where(IngameLoginsTable.ip == ip)
|
||||
|
||||
rec = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert rec is not None
|
||||
return cast(int, rec["count"])
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
user_id: int | None = None,
|
||||
ip: str | None = None,
|
||||
osu_ver: date | None = None,
|
||||
osu_stream: str | None = None,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[IngameLogin]:
|
||||
"""Fetch a list of logins from the database."""
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
|
||||
if user_id is not None:
|
||||
select_stmt = select_stmt.where(IngameLoginsTable.userid == user_id)
|
||||
if ip is not None:
|
||||
select_stmt = select_stmt.where(IngameLoginsTable.ip == ip)
|
||||
if osu_ver is not None:
|
||||
select_stmt = select_stmt.where(IngameLoginsTable.osu_ver == osu_ver)
|
||||
if osu_stream is not None:
|
||||
select_stmt = select_stmt.where(IngameLoginsTable.osu_stream == osu_stream)
|
||||
|
||||
if page is not None and page_size is not None:
|
||||
select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
ingame_logins = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[IngameLogin], ingame_logins)
|
70
app/repositories/logs.py
Normal file
70
app/repositories/logs.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
|
||||
import app.state.services
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class LogTable(Base):
|
||||
__tablename__ = "logs"
|
||||
|
||||
id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
|
||||
_from = Column("from", Integer, nullable=False)
|
||||
to = Column("to", Integer, nullable=False)
|
||||
action = Column("action", String(32), nullable=False)
|
||||
msg = Column("msg", String(2048, collation="utf8"), nullable=True)
|
||||
time = Column("time", DateTime, nullable=False, onupdate=func.now())
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
LogTable.id,
|
||||
LogTable._from.label("from"),
|
||||
LogTable.to,
|
||||
LogTable.action,
|
||||
LogTable.msg,
|
||||
LogTable.time,
|
||||
)
|
||||
|
||||
|
||||
class Log(TypedDict):
|
||||
id: int
|
||||
_from: int
|
||||
to: int
|
||||
action: str
|
||||
msg: str | None
|
||||
time: datetime
|
||||
|
||||
|
||||
async def create(
|
||||
_from: int,
|
||||
to: int,
|
||||
action: str,
|
||||
msg: str,
|
||||
) -> Log:
|
||||
"""Create a new log entry in the database."""
|
||||
insert_stmt = insert(LogTable).values(
|
||||
{
|
||||
"from": _from,
|
||||
"to": to,
|
||||
"action": action,
|
||||
"msg": msg,
|
||||
"time": func.now(),
|
||||
},
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(LogTable.id == rec_id)
|
||||
log = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert log is not None
|
||||
return cast(Log, log)
|
113
app/repositories/mail.py
Normal file
113
app/repositories/mail.py
Normal file
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.mysql import TINYINT
|
||||
|
||||
import app.state.services
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class MailTable(Base):
|
||||
__tablename__ = "mail"
|
||||
|
||||
id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
|
||||
from_id = Column("from_id", Integer, nullable=False)
|
||||
to_id = Column("to_id", Integer, nullable=False)
|
||||
msg = Column("msg", String(2048, collation="utf8"), nullable=False)
|
||||
time = Column("time", Integer, nullable=True)
|
||||
read = Column("read", TINYINT(1), nullable=False, server_default="0")
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
MailTable.id,
|
||||
MailTable.from_id,
|
||||
MailTable.to_id,
|
||||
MailTable.msg,
|
||||
MailTable.time,
|
||||
MailTable.read,
|
||||
)
|
||||
|
||||
|
||||
class Mail(TypedDict):
|
||||
id: int
|
||||
from_id: int
|
||||
to_id: int
|
||||
msg: str
|
||||
time: int
|
||||
read: bool
|
||||
|
||||
|
||||
class MailWithUsernames(Mail):
|
||||
from_name: str
|
||||
to_name: str
|
||||
|
||||
|
||||
async def create(from_id: int, to_id: int, msg: str) -> Mail:
|
||||
"""Create a new mail entry in the database."""
|
||||
insert_stmt = insert(MailTable).values(
|
||||
from_id=from_id,
|
||||
to_id=to_id,
|
||||
msg=msg,
|
||||
time=func.unix_timestamp(),
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(MailTable.id == rec_id)
|
||||
mail = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert mail is not None
|
||||
return cast(Mail, mail)
|
||||
|
||||
|
||||
from app.repositories.users import UsersTable
|
||||
|
||||
|
||||
async def fetch_all_mail_to_user(
|
||||
user_id: int,
|
||||
read: bool | None = None,
|
||||
) -> list[MailWithUsernames]:
|
||||
"""Fetch all of mail to a given target from the database."""
|
||||
from_subquery = select(UsersTable.name).where(UsersTable.id == MailTable.from_id)
|
||||
to_subquery = select(UsersTable.name).where(UsersTable.id == MailTable.to_id)
|
||||
|
||||
select_stmt = select(
|
||||
*READ_PARAMS,
|
||||
from_subquery.label("from_name"),
|
||||
to_subquery.label("to_name"),
|
||||
).where(MailTable.to_id == user_id)
|
||||
|
||||
if read is not None:
|
||||
select_stmt = select_stmt.where(MailTable.read == read)
|
||||
|
||||
mail = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[MailWithUsernames], mail)
|
||||
|
||||
|
||||
async def mark_conversation_as_read(to_id: int, from_id: int) -> list[Mail]:
|
||||
"""Mark any mail in a user's conversation with another user as read."""
|
||||
select_stmt = select(*READ_PARAMS).where(
|
||||
MailTable.to_id == to_id,
|
||||
MailTable.from_id == from_id,
|
||||
MailTable.read == False,
|
||||
)
|
||||
mail = await app.state.services.database.fetch_all(select_stmt)
|
||||
if not mail:
|
||||
return []
|
||||
|
||||
update_stmt = (
|
||||
update(MailTable)
|
||||
.where(MailTable.to_id == to_id)
|
||||
.where(MailTable.from_id == from_id)
|
||||
.where(MailTable.read == False)
|
||||
.values(read=True)
|
||||
)
|
||||
await app.state.services.database.execute(update_stmt)
|
||||
return cast(list[Mail], mail)
|
97
app/repositories/map_requests.py
Normal file
97
app/repositories/map_requests.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.mysql import TINYINT
|
||||
|
||||
import app.state.services
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class MapRequestsTable(Base):
|
||||
__tablename__ = "map_requests"
|
||||
|
||||
id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
|
||||
map_id = Column("map_id", Integer, nullable=False)
|
||||
player_id = Column("player_id", Integer, nullable=False)
|
||||
datetime = Column("datetime", DateTime, nullable=False)
|
||||
active = Column("active", TINYINT(1), nullable=False)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
MapRequestsTable.id,
|
||||
MapRequestsTable.map_id,
|
||||
MapRequestsTable.player_id,
|
||||
MapRequestsTable.datetime,
|
||||
)
|
||||
|
||||
|
||||
class MapRequest(TypedDict):
|
||||
id: int
|
||||
map_id: int
|
||||
player_id: int
|
||||
datetime: datetime
|
||||
active: bool
|
||||
|
||||
|
||||
async def create(
|
||||
map_id: int,
|
||||
player_id: int,
|
||||
active: bool,
|
||||
) -> MapRequest:
|
||||
"""Create a new map request entry in the database."""
|
||||
insert_stmt = insert(MapRequestsTable).values(
|
||||
map_id=map_id,
|
||||
player_id=player_id,
|
||||
datetime=func.now(),
|
||||
active=active,
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(MapRequestsTable.id == rec_id)
|
||||
map_request = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert map_request is not None
|
||||
|
||||
return cast(MapRequest, map_request)
|
||||
|
||||
|
||||
async def fetch_all(
|
||||
map_id: int | None = None,
|
||||
player_id: int | None = None,
|
||||
active: bool | None = None,
|
||||
) -> list[MapRequest]:
|
||||
"""Fetch a list of map requests from the database."""
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if map_id is not None:
|
||||
select_stmt = select_stmt.where(MapRequestsTable.map_id == map_id)
|
||||
if player_id is not None:
|
||||
select_stmt = select_stmt.where(MapRequestsTable.player_id == player_id)
|
||||
if active is not None:
|
||||
select_stmt = select_stmt.where(MapRequestsTable.active == active)
|
||||
|
||||
map_requests = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[MapRequest], map_requests)
|
||||
|
||||
|
||||
async def mark_batch_as_inactive(map_ids: list[Any]) -> list[MapRequest]:
|
||||
"""Mark a map request as inactive."""
|
||||
update_stmt = (
|
||||
update(MapRequestsTable)
|
||||
.where(MapRequestsTable.map_id.in_(map_ids))
|
||||
.values(active=False)
|
||||
)
|
||||
await app.state.services.database.execute(update_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(MapRequestsTable.map_id.in_(map_ids))
|
||||
map_requests = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[MapRequest], map_requests)
|
370
app/repositories/maps.py
Normal file
370
app/repositories/maps.py
Normal file
@@ -0,0 +1,370 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.mysql import FLOAT
|
||||
from sqlalchemy.dialects.mysql import TINYINT
|
||||
|
||||
import app.state.services
|
||||
from app._typing import UNSET
|
||||
from app._typing import _UnsetSentinel
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class MapServer(StrEnum):
|
||||
OSU = "osu!"
|
||||
PRIVATE = "private"
|
||||
|
||||
|
||||
class MapsTable(Base):
|
||||
__tablename__ = "maps"
|
||||
|
||||
server = Column(
|
||||
Enum(MapServer, name="server"),
|
||||
nullable=False,
|
||||
server_default="osu!",
|
||||
primary_key=True,
|
||||
)
|
||||
id = Column(Integer, nullable=False, primary_key=True)
|
||||
set_id = Column(Integer, nullable=False)
|
||||
status = Column(Integer, nullable=False)
|
||||
md5 = Column(String(32), nullable=False)
|
||||
artist = Column(String(128, collation="utf8"), nullable=False)
|
||||
title = Column(String(128, collation="utf8"), nullable=False)
|
||||
version = Column(String(128, collation="utf8"), nullable=False)
|
||||
creator = Column(String(19, collation="utf8"), nullable=False)
|
||||
filename = Column(String(256, collation="utf8"), nullable=False)
|
||||
last_update = Column(DateTime, nullable=False)
|
||||
total_length = Column(Integer, nullable=False)
|
||||
max_combo = Column(Integer, nullable=False)
|
||||
frozen = Column(TINYINT(1), nullable=False, server_default="0")
|
||||
plays = Column(Integer, nullable=False, server_default="0")
|
||||
passes = Column(Integer, nullable=False, server_default="0")
|
||||
mode = Column(TINYINT(1), nullable=False, server_default="0")
|
||||
bpm = Column(FLOAT(12, 2), nullable=False, server_default="0.00")
|
||||
cs = Column(FLOAT(4, 2), nullable=False, server_default="0.00")
|
||||
ar = Column(FLOAT(4, 2), nullable=False, server_default="0.00")
|
||||
od = Column(FLOAT(4, 2), nullable=False, server_default="0.00")
|
||||
hp = Column(FLOAT(4, 2), nullable=False, server_default="0.00")
|
||||
diff = Column(FLOAT(6, 3), nullable=False, server_default="0.000")
|
||||
|
||||
__table_args__ = (
|
||||
Index("maps_set_id_index", "set_id"),
|
||||
Index("maps_status_index", "status"),
|
||||
Index("maps_filename_index", "filename"),
|
||||
Index("maps_plays_index", "plays"),
|
||||
Index("maps_mode_index", "mode"),
|
||||
Index("maps_frozen_index", "frozen"),
|
||||
Index("maps_md5_uindex", "md5", unique=True),
|
||||
Index("maps_id_uindex", "id", unique=True),
|
||||
)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
MapsTable.id,
|
||||
MapsTable.server,
|
||||
MapsTable.set_id,
|
||||
MapsTable.status,
|
||||
MapsTable.md5,
|
||||
MapsTable.artist,
|
||||
MapsTable.title,
|
||||
MapsTable.version,
|
||||
MapsTable.creator,
|
||||
MapsTable.filename,
|
||||
MapsTable.last_update,
|
||||
MapsTable.total_length,
|
||||
MapsTable.max_combo,
|
||||
MapsTable.frozen,
|
||||
MapsTable.plays,
|
||||
MapsTable.passes,
|
||||
MapsTable.mode,
|
||||
MapsTable.bpm,
|
||||
MapsTable.cs,
|
||||
MapsTable.ar,
|
||||
MapsTable.od,
|
||||
MapsTable.hp,
|
||||
MapsTable.diff,
|
||||
)
|
||||
|
||||
|
||||
class Map(TypedDict):
|
||||
id: int
|
||||
server: str
|
||||
set_id: int
|
||||
status: int
|
||||
md5: str
|
||||
artist: str
|
||||
title: str
|
||||
version: str
|
||||
creator: str
|
||||
filename: str
|
||||
last_update: datetime
|
||||
total_length: int
|
||||
max_combo: int
|
||||
frozen: bool
|
||||
plays: int
|
||||
passes: int
|
||||
mode: int
|
||||
bpm: float
|
||||
cs: float
|
||||
ar: float
|
||||
od: float
|
||||
hp: float
|
||||
diff: float
|
||||
|
||||
|
||||
async def create(
|
||||
id: int,
|
||||
server: str,
|
||||
set_id: int,
|
||||
status: int,
|
||||
md5: str,
|
||||
artist: str,
|
||||
title: str,
|
||||
version: str,
|
||||
creator: str,
|
||||
filename: str,
|
||||
last_update: datetime,
|
||||
total_length: int,
|
||||
max_combo: int,
|
||||
frozen: bool,
|
||||
plays: int,
|
||||
passes: int,
|
||||
mode: int,
|
||||
bpm: float,
|
||||
cs: float,
|
||||
ar: float,
|
||||
od: float,
|
||||
hp: float,
|
||||
diff: float,
|
||||
) -> Map:
|
||||
"""Create a new beatmap entry in the database."""
|
||||
insert_stmt = insert(MapsTable).values(
|
||||
id=id,
|
||||
server=server,
|
||||
set_id=set_id,
|
||||
status=status,
|
||||
md5=md5,
|
||||
artist=artist,
|
||||
title=title,
|
||||
version=version,
|
||||
creator=creator,
|
||||
filename=filename,
|
||||
last_update=last_update,
|
||||
total_length=total_length,
|
||||
max_combo=max_combo,
|
||||
frozen=frozen,
|
||||
plays=plays,
|
||||
passes=passes,
|
||||
mode=mode,
|
||||
bpm=bpm,
|
||||
cs=cs,
|
||||
ar=ar,
|
||||
od=od,
|
||||
hp=hp,
|
||||
diff=diff,
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(MapsTable.id == rec_id)
|
||||
map = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert map is not None
|
||||
return cast(Map, map)
|
||||
|
||||
|
||||
async def fetch_one(
|
||||
id: int | None = None,
|
||||
md5: str | None = None,
|
||||
filename: str | None = None,
|
||||
) -> Map | None:
|
||||
"""Fetch a beatmap entry from the database."""
|
||||
if id is None and md5 is None and filename is None:
|
||||
raise ValueError("Must provide at least one parameter.")
|
||||
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if id is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.id == id)
|
||||
if md5 is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.md5 == md5)
|
||||
if filename is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.filename == filename)
|
||||
|
||||
map = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Map | None, map)
|
||||
|
||||
|
||||
async def fetch_count(
|
||||
server: str | None = None,
|
||||
set_id: int | None = None,
|
||||
status: int | None = None,
|
||||
artist: str | None = None,
|
||||
creator: str | None = None,
|
||||
filename: str | None = None,
|
||||
mode: int | None = None,
|
||||
frozen: bool | None = None,
|
||||
) -> int:
|
||||
"""Fetch the number of maps in the database."""
|
||||
select_stmt = select(func.count().label("count")).select_from(MapsTable)
|
||||
if server is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.server == server)
|
||||
if set_id is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.set_id == set_id)
|
||||
if status is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.status == status)
|
||||
if artist is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.artist == artist)
|
||||
if creator is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.creator == creator)
|
||||
if filename is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.filename == filename)
|
||||
if mode is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.mode == mode)
|
||||
if frozen is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.frozen == frozen)
|
||||
|
||||
rec = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert rec is not None
|
||||
return cast(int, rec["count"])
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
server: str | None = None,
|
||||
set_id: int | None = None,
|
||||
status: int | None = None,
|
||||
artist: str | None = None,
|
||||
creator: str | None = None,
|
||||
filename: str | None = None,
|
||||
mode: int | None = None,
|
||||
frozen: bool | None = None,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[Map]:
|
||||
"""Fetch a list of maps from the database."""
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if server is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.server == server)
|
||||
if set_id is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.set_id == set_id)
|
||||
if status is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.status == status)
|
||||
if artist is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.artist == artist)
|
||||
if creator is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.creator == creator)
|
||||
if filename is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.filename == filename)
|
||||
if mode is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.mode == mode)
|
||||
if frozen is not None:
|
||||
select_stmt = select_stmt.where(MapsTable.frozen == frozen)
|
||||
|
||||
if page is not None and page_size is not None:
|
||||
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
maps = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[Map], maps)
|
||||
|
||||
|
||||
async def partial_update(
|
||||
id: int,
|
||||
server: str | _UnsetSentinel = UNSET,
|
||||
set_id: int | _UnsetSentinel = UNSET,
|
||||
status: int | _UnsetSentinel = UNSET,
|
||||
md5: str | _UnsetSentinel = UNSET,
|
||||
artist: str | _UnsetSentinel = UNSET,
|
||||
title: str | _UnsetSentinel = UNSET,
|
||||
version: str | _UnsetSentinel = UNSET,
|
||||
creator: str | _UnsetSentinel = UNSET,
|
||||
filename: str | _UnsetSentinel = UNSET,
|
||||
last_update: datetime | _UnsetSentinel = UNSET,
|
||||
total_length: int | _UnsetSentinel = UNSET,
|
||||
max_combo: int | _UnsetSentinel = UNSET,
|
||||
frozen: bool | _UnsetSentinel = UNSET,
|
||||
plays: int | _UnsetSentinel = UNSET,
|
||||
passes: int | _UnsetSentinel = UNSET,
|
||||
mode: int | _UnsetSentinel = UNSET,
|
||||
bpm: float | _UnsetSentinel = UNSET,
|
||||
cs: float | _UnsetSentinel = UNSET,
|
||||
ar: float | _UnsetSentinel = UNSET,
|
||||
od: float | _UnsetSentinel = UNSET,
|
||||
hp: float | _UnsetSentinel = UNSET,
|
||||
diff: float | _UnsetSentinel = UNSET,
|
||||
) -> Map | None:
|
||||
"""Update a beatmap entry in the database."""
|
||||
update_stmt = update(MapsTable).where(MapsTable.id == id)
|
||||
if not isinstance(server, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(server=server)
|
||||
if not isinstance(set_id, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(set_id=set_id)
|
||||
if not isinstance(status, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(status=status)
|
||||
if not isinstance(md5, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(md5=md5)
|
||||
if not isinstance(artist, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(artist=artist)
|
||||
if not isinstance(title, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(title=title)
|
||||
if not isinstance(version, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(version=version)
|
||||
if not isinstance(creator, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(creator=creator)
|
||||
if not isinstance(filename, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(filename=filename)
|
||||
if not isinstance(last_update, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(last_update=last_update)
|
||||
if not isinstance(total_length, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(total_length=total_length)
|
||||
if not isinstance(max_combo, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(max_combo=max_combo)
|
||||
if not isinstance(frozen, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(frozen=frozen)
|
||||
if not isinstance(plays, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(plays=plays)
|
||||
if not isinstance(passes, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(passes=passes)
|
||||
if not isinstance(mode, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(mode=mode)
|
||||
if not isinstance(bpm, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(bpm=bpm)
|
||||
if not isinstance(cs, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(cs=cs)
|
||||
if not isinstance(ar, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(ar=ar)
|
||||
if not isinstance(od, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(od=od)
|
||||
if not isinstance(hp, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(hp=hp)
|
||||
if not isinstance(diff, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(diff=diff)
|
||||
|
||||
await app.state.services.database.execute(update_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(MapsTable.id == id)
|
||||
map = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Map | None, map)
|
||||
|
||||
|
||||
async def delete_one(id: int) -> Map | None:
|
||||
"""Delete a beatmap entry from the database."""
|
||||
select_stmt = select(*READ_PARAMS).where(MapsTable.id == id)
|
||||
map = await app.state.services.database.fetch_one(select_stmt)
|
||||
if map is None:
|
||||
return None
|
||||
|
||||
delete_stmt = delete(MapsTable).where(MapsTable.id == id)
|
||||
await app.state.services.database.execute(delete_stmt)
|
||||
return cast(Map, map)
|
85
app/repositories/ratings.py
Normal file
85
app/repositories/ratings.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.mysql import TINYINT
|
||||
|
||||
import app.state.services
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class RatingsTable(Base):
|
||||
__tablename__ = "ratings"
|
||||
|
||||
userid = Column("userid", Integer, nullable=False, primary_key=True)
|
||||
map_md5 = Column("map_md5", String(32), nullable=False, primary_key=True)
|
||||
rating = Column("rating", TINYINT(2), nullable=False)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
RatingsTable.userid,
|
||||
RatingsTable.map_md5,
|
||||
RatingsTable.rating,
|
||||
)
|
||||
|
||||
|
||||
class Rating(TypedDict):
|
||||
userid: int
|
||||
map_md5: str
|
||||
rating: int
|
||||
|
||||
|
||||
async def create(userid: int, map_md5: str, rating: int) -> Rating:
|
||||
"""Create a new rating."""
|
||||
insert_stmt = insert(RatingsTable).values(
|
||||
userid=userid,
|
||||
map_md5=map_md5,
|
||||
rating=rating,
|
||||
)
|
||||
await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS)
|
||||
.where(RatingsTable.userid == userid)
|
||||
.where(RatingsTable.map_md5 == map_md5)
|
||||
)
|
||||
_rating = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert _rating is not None
|
||||
return cast(Rating, _rating)
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
userid: int | None = None,
|
||||
map_md5: str | None = None,
|
||||
page: int | None = 1,
|
||||
page_size: int | None = 50,
|
||||
) -> list[Rating]:
|
||||
"""Fetch multiple ratings, optionally with filter params and pagination."""
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if userid is not None:
|
||||
select_stmt = select_stmt.where(RatingsTable.userid == userid)
|
||||
if map_md5 is not None:
|
||||
select_stmt = select_stmt.where(RatingsTable.map_md5 == map_md5)
|
||||
|
||||
if page is not None and page_size is not None:
|
||||
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
ratings = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[Rating], ratings)
|
||||
|
||||
|
||||
async def fetch_one(userid: int, map_md5: str) -> Rating | None:
|
||||
"""Fetch a single rating for a given user and map."""
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS)
|
||||
.where(RatingsTable.userid == userid)
|
||||
.where(RatingsTable.map_md5 == map_md5)
|
||||
)
|
||||
rating = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Rating | None, rating)
|
246
app/repositories/scores.py
Normal file
246
app/repositories/scores.py
Normal file
@@ -0,0 +1,246 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.mysql import FLOAT
|
||||
from sqlalchemy.dialects.mysql import TINYINT
|
||||
|
||||
import app.state.services
|
||||
from app._typing import UNSET
|
||||
from app._typing import _UnsetSentinel
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class ScoresTable(Base):
|
||||
__tablename__ = "scores"
|
||||
|
||||
id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
|
||||
map_md5 = Column("map_md5", String(32), nullable=False)
|
||||
score = Column("score", Integer, nullable=False)
|
||||
pp = Column("pp", FLOAT(precision=6, scale=3), nullable=False)
|
||||
acc = Column("acc", FLOAT(precision=6, scale=3), nullable=False)
|
||||
max_combo = Column("max_combo", Integer, nullable=False)
|
||||
mods = Column("mods", Integer, nullable=False)
|
||||
n300 = Column("n300", Integer, nullable=False)
|
||||
n100 = Column("n100", Integer, nullable=False)
|
||||
n50 = Column("n50", Integer, nullable=False)
|
||||
nmiss = Column("nmiss", Integer, nullable=False)
|
||||
ngeki = Column("ngeki", Integer, nullable=False)
|
||||
nkatu = Column("nkatu", Integer, nullable=False)
|
||||
grade = Column("grade", String(2), nullable=False, server_default="N")
|
||||
status = Column("status", Integer, nullable=False)
|
||||
mode = Column("mode", Integer, nullable=False)
|
||||
play_time = Column("play_time", DateTime, nullable=False)
|
||||
time_elapsed = Column("time_elapsed", Integer, nullable=False)
|
||||
client_flags = Column("client_flags", Integer, nullable=False)
|
||||
userid = Column("userid", Integer, nullable=False)
|
||||
perfect = Column("perfect", TINYINT(1), nullable=False)
|
||||
online_checksum = Column("online_checksum", String(32), nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index("scores_map_md5_index", map_md5),
|
||||
Index("scores_score_index", score),
|
||||
Index("scores_pp_index", pp),
|
||||
Index("scores_mods_index", mods),
|
||||
Index("scores_status_index", status),
|
||||
Index("scores_mode_index", mode),
|
||||
Index("scores_play_time_index", play_time),
|
||||
Index("scores_userid_index", userid),
|
||||
Index("scores_online_checksum_index", online_checksum),
|
||||
)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
ScoresTable.id,
|
||||
ScoresTable.map_md5,
|
||||
ScoresTable.score,
|
||||
ScoresTable.pp,
|
||||
ScoresTable.acc,
|
||||
ScoresTable.max_combo,
|
||||
ScoresTable.mods,
|
||||
ScoresTable.n300,
|
||||
ScoresTable.n100,
|
||||
ScoresTable.n50,
|
||||
ScoresTable.nmiss,
|
||||
ScoresTable.ngeki,
|
||||
ScoresTable.nkatu,
|
||||
ScoresTable.grade,
|
||||
ScoresTable.status,
|
||||
ScoresTable.mode,
|
||||
ScoresTable.play_time,
|
||||
ScoresTable.time_elapsed,
|
||||
ScoresTable.client_flags,
|
||||
ScoresTable.userid,
|
||||
ScoresTable.perfect,
|
||||
ScoresTable.online_checksum,
|
||||
)
|
||||
|
||||
|
||||
class Score(TypedDict):
|
||||
id: int
|
||||
map_md5: str
|
||||
score: int
|
||||
pp: float
|
||||
acc: float
|
||||
max_combo: int
|
||||
mods: int
|
||||
n300: int
|
||||
n100: int
|
||||
n50: int
|
||||
nmiss: int
|
||||
ngeki: int
|
||||
nkatu: int
|
||||
grade: str
|
||||
status: int
|
||||
mode: int
|
||||
play_time: datetime
|
||||
time_elapsed: int
|
||||
client_flags: int
|
||||
userid: int
|
||||
perfect: int
|
||||
online_checksum: str
|
||||
|
||||
|
||||
async def create(
|
||||
map_md5: str,
|
||||
score: int,
|
||||
pp: float,
|
||||
acc: float,
|
||||
max_combo: int,
|
||||
mods: int,
|
||||
n300: int,
|
||||
n100: int,
|
||||
n50: int,
|
||||
nmiss: int,
|
||||
ngeki: int,
|
||||
nkatu: int,
|
||||
grade: str,
|
||||
status: int,
|
||||
mode: int,
|
||||
play_time: datetime,
|
||||
time_elapsed: int,
|
||||
client_flags: int,
|
||||
user_id: int,
|
||||
perfect: int,
|
||||
online_checksum: str,
|
||||
) -> Score:
|
||||
insert_stmt = insert(ScoresTable).values(
|
||||
map_md5=map_md5,
|
||||
score=score,
|
||||
pp=pp,
|
||||
acc=acc,
|
||||
max_combo=max_combo,
|
||||
mods=mods,
|
||||
n300=n300,
|
||||
n100=n100,
|
||||
n50=n50,
|
||||
nmiss=nmiss,
|
||||
ngeki=ngeki,
|
||||
nkatu=nkatu,
|
||||
grade=grade,
|
||||
status=status,
|
||||
mode=mode,
|
||||
play_time=play_time,
|
||||
time_elapsed=time_elapsed,
|
||||
client_flags=client_flags,
|
||||
userid=user_id,
|
||||
perfect=perfect,
|
||||
online_checksum=online_checksum,
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(ScoresTable.id == rec_id)
|
||||
_score = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert _score is not None
|
||||
return cast(Score, _score)
|
||||
|
||||
|
||||
async def fetch_one(id: int) -> Score | None:
|
||||
select_stmt = select(*READ_PARAMS).where(ScoresTable.id == id)
|
||||
_score = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Score | None, _score)
|
||||
|
||||
|
||||
async def fetch_count(
|
||||
map_md5: str | None = None,
|
||||
mods: int | None = None,
|
||||
status: int | None = None,
|
||||
mode: int | None = None,
|
||||
user_id: int | None = None,
|
||||
) -> int:
|
||||
select_stmt = select(func.count().label("count")).select_from(ScoresTable)
|
||||
if map_md5 is not None:
|
||||
select_stmt = select_stmt.where(ScoresTable.map_md5 == map_md5)
|
||||
if mods is not None:
|
||||
select_stmt = select_stmt.where(ScoresTable.mods == mods)
|
||||
if status is not None:
|
||||
select_stmt = select_stmt.where(ScoresTable.status == status)
|
||||
if mode is not None:
|
||||
select_stmt = select_stmt.where(ScoresTable.mode == mode)
|
||||
if user_id is not None:
|
||||
select_stmt = select_stmt.where(ScoresTable.userid == user_id)
|
||||
|
||||
rec = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert rec is not None
|
||||
return cast(int, rec["count"])
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
map_md5: str | None = None,
|
||||
mods: int | None = None,
|
||||
status: int | None = None,
|
||||
mode: int | None = None,
|
||||
user_id: int | None = None,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[Score]:
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if map_md5 is not None:
|
||||
select_stmt = select_stmt.where(ScoresTable.map_md5 == map_md5)
|
||||
if mods is not None:
|
||||
select_stmt = select_stmt.where(ScoresTable.mods == mods)
|
||||
if status is not None:
|
||||
select_stmt = select_stmt.where(ScoresTable.status == status)
|
||||
if mode is not None:
|
||||
select_stmt = select_stmt.where(ScoresTable.mode == mode)
|
||||
if user_id is not None:
|
||||
select_stmt = select_stmt.where(ScoresTable.userid == user_id)
|
||||
|
||||
if page is not None and page_size is not None:
|
||||
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
scores = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[Score], scores)
|
||||
|
||||
|
||||
async def partial_update(
|
||||
id: int,
|
||||
pp: float | _UnsetSentinel = UNSET,
|
||||
status: int | _UnsetSentinel = UNSET,
|
||||
) -> Score | None:
|
||||
"""Update an existing score."""
|
||||
update_stmt = update(ScoresTable).where(ScoresTable.id == id)
|
||||
if not isinstance(pp, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(pp=pp)
|
||||
if not isinstance(status, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(status=status)
|
||||
|
||||
await app.state.services.database.execute(update_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(ScoresTable.id == id)
|
||||
_score = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Score | None, _score)
|
||||
|
||||
|
||||
# TODO: delete
|
237
app/repositories/stats.py
Normal file
237
app/repositories/stats.py
Normal file
@@ -0,0 +1,237 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.mysql import FLOAT
|
||||
from sqlalchemy.dialects.mysql import TINYINT
|
||||
|
||||
import app.state.services
|
||||
from app._typing import UNSET
|
||||
from app._typing import _UnsetSentinel
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class StatsTable(Base):
|
||||
__tablename__ = "stats"
|
||||
|
||||
id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
|
||||
mode = Column("mode", TINYINT(1), primary_key=True)
|
||||
tscore = Column("tscore", Integer, nullable=False, server_default="0")
|
||||
rscore = Column("rscore", Integer, nullable=False, server_default="0")
|
||||
pp = Column("pp", Integer, nullable=False, server_default="0")
|
||||
plays = Column("plays", Integer, nullable=False, server_default="0")
|
||||
playtime = Column("playtime", Integer, nullable=False, server_default="0")
|
||||
acc = Column(
|
||||
"acc",
|
||||
FLOAT(precision=6, scale=3),
|
||||
nullable=False,
|
||||
server_default="0.000",
|
||||
)
|
||||
max_combo = Column("max_combo", Integer, nullable=False, server_default="0")
|
||||
total_hits = Column("total_hits", Integer, nullable=False, server_default="0")
|
||||
replay_views = Column("replay_views", Integer, nullable=False, server_default="0")
|
||||
xh_count = Column("xh_count", Integer, nullable=False, server_default="0")
|
||||
x_count = Column("x_count", Integer, nullable=False, server_default="0")
|
||||
sh_count = Column("sh_count", Integer, nullable=False, server_default="0")
|
||||
s_count = Column("s_count", Integer, nullable=False, server_default="0")
|
||||
a_count = Column("a_count", Integer, nullable=False, server_default="0")
|
||||
|
||||
__table_args__ = (
|
||||
Index("stats_mode_index", mode),
|
||||
Index("stats_pp_index", pp),
|
||||
Index("stats_tscore_index", tscore),
|
||||
Index("stats_rscore_index", rscore),
|
||||
)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
StatsTable.id,
|
||||
StatsTable.mode,
|
||||
StatsTable.tscore,
|
||||
StatsTable.rscore,
|
||||
StatsTable.pp,
|
||||
StatsTable.plays,
|
||||
StatsTable.playtime,
|
||||
StatsTable.acc,
|
||||
StatsTable.max_combo,
|
||||
StatsTable.total_hits,
|
||||
StatsTable.replay_views,
|
||||
StatsTable.xh_count,
|
||||
StatsTable.x_count,
|
||||
StatsTable.sh_count,
|
||||
StatsTable.s_count,
|
||||
StatsTable.a_count,
|
||||
)
|
||||
|
||||
|
||||
class Stat(TypedDict):
|
||||
id: int
|
||||
mode: int
|
||||
tscore: int
|
||||
rscore: int
|
||||
pp: int
|
||||
plays: int
|
||||
playtime: int
|
||||
acc: float
|
||||
max_combo: int
|
||||
total_hits: int
|
||||
replay_views: int
|
||||
xh_count: int
|
||||
x_count: int
|
||||
sh_count: int
|
||||
s_count: int
|
||||
a_count: int
|
||||
|
||||
|
||||
async def create(player_id: int, mode: int) -> Stat:
|
||||
"""Create a new player stats entry in the database."""
|
||||
insert_stmt = insert(StatsTable).values(id=player_id, mode=mode)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(StatsTable.id == rec_id)
|
||||
stat = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert stat is not None
|
||||
return cast(Stat, stat)
|
||||
|
||||
|
||||
async def create_all_modes(player_id: int) -> list[Stat]:
|
||||
"""Create new player stats entries for each game mode in the database."""
|
||||
insert_stmt = insert(StatsTable).values(
|
||||
[
|
||||
{"id": player_id, "mode": mode}
|
||||
for mode in (
|
||||
0, # vn!std
|
||||
1, # vn!taiko
|
||||
2, # vn!catch
|
||||
3, # vn!mania
|
||||
4, # rx!std
|
||||
5, # rx!taiko
|
||||
6, # rx!catch
|
||||
8, # ap!std
|
||||
)
|
||||
],
|
||||
)
|
||||
await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(StatsTable.id == player_id)
|
||||
stats = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[Stat], stats)
|
||||
|
||||
|
||||
async def fetch_one(player_id: int, mode: int) -> Stat | None:
|
||||
"""Fetch a player stats entry from the database."""
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS)
|
||||
.where(StatsTable.id == player_id)
|
||||
.where(StatsTable.mode == mode)
|
||||
)
|
||||
stat = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Stat | None, stat)
|
||||
|
||||
|
||||
async def fetch_count(
|
||||
player_id: int | None = None,
|
||||
mode: int | None = None,
|
||||
) -> int:
|
||||
select_stmt = select(func.count().label("count")).select_from(StatsTable)
|
||||
if player_id is not None:
|
||||
select_stmt = select_stmt.where(StatsTable.id == player_id)
|
||||
if mode is not None:
|
||||
select_stmt = select_stmt.where(StatsTable.mode == mode)
|
||||
|
||||
rec = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert rec is not None
|
||||
return cast(int, rec["count"])
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
player_id: int | None = None,
|
||||
mode: int | None = None,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[Stat]:
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if player_id is not None:
|
||||
select_stmt = select_stmt.where(StatsTable.id == player_id)
|
||||
if mode is not None:
|
||||
select_stmt = select_stmt.where(StatsTable.mode == mode)
|
||||
if page is not None and page_size is not None:
|
||||
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
stats = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[Stat], stats)
|
||||
|
||||
|
||||
async def partial_update(
|
||||
player_id: int,
|
||||
mode: int,
|
||||
tscore: int | _UnsetSentinel = UNSET,
|
||||
rscore: int | _UnsetSentinel = UNSET,
|
||||
pp: int | _UnsetSentinel = UNSET,
|
||||
plays: int | _UnsetSentinel = UNSET,
|
||||
playtime: int | _UnsetSentinel = UNSET,
|
||||
acc: float | _UnsetSentinel = UNSET,
|
||||
max_combo: int | _UnsetSentinel = UNSET,
|
||||
total_hits: int | _UnsetSentinel = UNSET,
|
||||
replay_views: int | _UnsetSentinel = UNSET,
|
||||
xh_count: int | _UnsetSentinel = UNSET,
|
||||
x_count: int | _UnsetSentinel = UNSET,
|
||||
sh_count: int | _UnsetSentinel = UNSET,
|
||||
s_count: int | _UnsetSentinel = UNSET,
|
||||
a_count: int | _UnsetSentinel = UNSET,
|
||||
) -> Stat | None:
|
||||
"""Update a player stats entry in the database."""
|
||||
update_stmt = (
|
||||
update(StatsTable)
|
||||
.where(StatsTable.id == player_id)
|
||||
.where(StatsTable.mode == mode)
|
||||
)
|
||||
if not isinstance(tscore, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(tscore=tscore)
|
||||
if not isinstance(rscore, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(rscore=rscore)
|
||||
if not isinstance(pp, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(pp=pp)
|
||||
if not isinstance(plays, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(plays=plays)
|
||||
if not isinstance(playtime, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(playtime=playtime)
|
||||
if not isinstance(acc, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(acc=acc)
|
||||
if not isinstance(max_combo, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(max_combo=max_combo)
|
||||
if not isinstance(total_hits, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(total_hits=total_hits)
|
||||
if not isinstance(replay_views, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(replay_views=replay_views)
|
||||
if not isinstance(xh_count, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(xh_count=xh_count)
|
||||
if not isinstance(x_count, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(x_count=x_count)
|
||||
if not isinstance(sh_count, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(sh_count=sh_count)
|
||||
if not isinstance(s_count, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(s_count=s_count)
|
||||
if not isinstance(a_count, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(a_count=a_count)
|
||||
|
||||
await app.state.services.database.execute(update_stmt)
|
||||
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS)
|
||||
.where(StatsTable.id == player_id)
|
||||
.where(StatsTable.mode == mode)
|
||||
)
|
||||
stat = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(Stat | None, stat)
|
||||
|
||||
|
||||
# TODO: delete?
|
137
app/repositories/tourney_pool_maps.py
Normal file
137
app/repositories/tourney_pool_maps.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
|
||||
import app.state.services
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class TourneyPoolMapsTable(Base):
|
||||
__tablename__ = "tourney_pool_maps"
|
||||
|
||||
map_id = Column("map_id", Integer, nullable=False, primary_key=True)
|
||||
pool_id = Column("pool_id", Integer, nullable=False, primary_key=True)
|
||||
mods = Column("mods", Integer, nullable=False)
|
||||
slot = Column("slot", Integer, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index("tourney_pool_maps_mods_slot_index", mods, slot),
|
||||
Index("tourney_pool_maps_tourney_pools_id_fk", pool_id),
|
||||
)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
TourneyPoolMapsTable.map_id,
|
||||
TourneyPoolMapsTable.pool_id,
|
||||
TourneyPoolMapsTable.mods,
|
||||
TourneyPoolMapsTable.slot,
|
||||
)
|
||||
|
||||
|
||||
class TourneyPoolMap(TypedDict):
|
||||
map_id: int
|
||||
pool_id: int
|
||||
mods: int
|
||||
slot: int
|
||||
|
||||
|
||||
async def create(map_id: int, pool_id: int, mods: int, slot: int) -> TourneyPoolMap:
|
||||
"""Create a new map pool entry in the database."""
|
||||
insert_stmt = insert(TourneyPoolMapsTable).values(
|
||||
map_id=map_id,
|
||||
pool_id=pool_id,
|
||||
mods=mods,
|
||||
slot=slot,
|
||||
)
|
||||
await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS)
|
||||
.where(TourneyPoolMapsTable.map_id == map_id)
|
||||
.where(TourneyPoolMapsTable.pool_id == pool_id)
|
||||
)
|
||||
tourney_pool_map = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert tourney_pool_map is not None
|
||||
return cast(TourneyPoolMap, tourney_pool_map)
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
pool_id: int | None = None,
|
||||
mods: int | None = None,
|
||||
slot: int | None = None,
|
||||
page: int | None = 1,
|
||||
page_size: int | None = 50,
|
||||
) -> list[TourneyPoolMap]:
|
||||
"""Fetch a list of map pool entries from the database."""
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if pool_id is not None:
|
||||
select_stmt = select_stmt.where(TourneyPoolMapsTable.pool_id == pool_id)
|
||||
if mods is not None:
|
||||
select_stmt = select_stmt.where(TourneyPoolMapsTable.mods == mods)
|
||||
if slot is not None:
|
||||
select_stmt = select_stmt.where(TourneyPoolMapsTable.slot == slot)
|
||||
if page and page_size:
|
||||
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
tourney_pool_maps = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[TourneyPoolMap], tourney_pool_maps)
|
||||
|
||||
|
||||
async def fetch_by_pool_and_pick(
|
||||
pool_id: int,
|
||||
mods: int,
|
||||
slot: int,
|
||||
) -> TourneyPoolMap | None:
|
||||
"""Fetch a map pool entry by pool and pick from the database."""
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS)
|
||||
.where(TourneyPoolMapsTable.pool_id == pool_id)
|
||||
.where(TourneyPoolMapsTable.mods == mods)
|
||||
.where(TourneyPoolMapsTable.slot == slot)
|
||||
)
|
||||
tourney_pool_map = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(TourneyPoolMap | None, tourney_pool_map)
|
||||
|
||||
|
||||
async def delete_map_from_pool(pool_id: int, map_id: int) -> TourneyPoolMap | None:
|
||||
"""Delete a map pool entry from a given tourney pool from the database."""
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS)
|
||||
.where(TourneyPoolMapsTable.pool_id == pool_id)
|
||||
.where(TourneyPoolMapsTable.map_id == map_id)
|
||||
)
|
||||
|
||||
tourney_pool_map = await app.state.services.database.fetch_one(select_stmt)
|
||||
if tourney_pool_map is None:
|
||||
return None
|
||||
|
||||
delete_stmt = (
|
||||
delete(TourneyPoolMapsTable)
|
||||
.where(TourneyPoolMapsTable.pool_id == pool_id)
|
||||
.where(TourneyPoolMapsTable.map_id == map_id)
|
||||
)
|
||||
|
||||
await app.state.services.database.execute(delete_stmt)
|
||||
return cast(TourneyPoolMap, tourney_pool_map)
|
||||
|
||||
|
||||
async def delete_all_in_pool(pool_id: int) -> list[TourneyPoolMap]:
|
||||
"""Delete all map pool entries from a given tourney pool from the database."""
|
||||
select_stmt = select(*READ_PARAMS).where(TourneyPoolMapsTable.pool_id == pool_id)
|
||||
tourney_pool_maps = await app.state.services.database.fetch_all(select_stmt)
|
||||
if not tourney_pool_maps:
|
||||
return []
|
||||
|
||||
delete_stmt = delete(TourneyPoolMapsTable).where(
|
||||
TourneyPoolMapsTable.pool_id == pool_id,
|
||||
)
|
||||
await app.state.services.database.execute(delete_stmt)
|
||||
return cast(list[TourneyPoolMap], tourney_pool_maps)
|
104
app/repositories/tourney_pools.py
Normal file
104
app/repositories/tourney_pools.py
Normal file
@@ -0,0 +1,104 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
|
||||
import app.state.services
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class TourneyPoolsTable(Base):
|
||||
__tablename__ = "tourney_pools"
|
||||
|
||||
id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
|
||||
name = Column("name", String(16), nullable=False)
|
||||
created_at = Column("created_at", DateTime, nullable=False)
|
||||
created_by = Column("created_by", Integer, nullable=False)
|
||||
|
||||
__table_args__ = (Index("tourney_pools_users_id_fk", created_by),)
|
||||
|
||||
|
||||
class TourneyPool(TypedDict):
|
||||
id: int
|
||||
name: str
|
||||
created_at: datetime
|
||||
created_by: int
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
TourneyPoolsTable.id,
|
||||
TourneyPoolsTable.name,
|
||||
TourneyPoolsTable.created_at,
|
||||
TourneyPoolsTable.created_by,
|
||||
)
|
||||
|
||||
|
||||
async def create(name: str, created_by: int) -> TourneyPool:
|
||||
"""Create a new tourney pool entry in the database."""
|
||||
insert_stmt = insert(TourneyPoolsTable).values(
|
||||
name=name,
|
||||
created_at=func.now(),
|
||||
created_by=created_by,
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(TourneyPoolsTable.id == rec_id)
|
||||
tourney_pool = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert tourney_pool is not None
|
||||
return cast(TourneyPool, tourney_pool)
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
id: int | None = None,
|
||||
created_by: int | None = None,
|
||||
page: int | None = 1,
|
||||
page_size: int | None = 50,
|
||||
) -> list[TourneyPool]:
|
||||
"""Fetch many tourney pools from the database."""
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if id is not None:
|
||||
select_stmt = select_stmt.where(TourneyPoolsTable.id == id)
|
||||
if created_by is not None:
|
||||
select_stmt = select_stmt.where(TourneyPoolsTable.created_by == created_by)
|
||||
if page and page_size:
|
||||
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
tourney_pools = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[TourneyPool], tourney_pools)
|
||||
|
||||
|
||||
async def fetch_by_name(name: str) -> TourneyPool | None:
|
||||
"""Fetch a tourney pool by name from the database."""
|
||||
select_stmt = select(*READ_PARAMS).where(TourneyPoolsTable.name == name)
|
||||
tourney_pool = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(TourneyPool | None, tourney_pool)
|
||||
|
||||
|
||||
async def fetch_by_id(id: int) -> TourneyPool | None:
|
||||
"""Fetch a tourney pool by id from the database."""
|
||||
select_stmt = select(*READ_PARAMS).where(TourneyPoolsTable.id == id)
|
||||
tourney_pool = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(TourneyPool | None, tourney_pool)
|
||||
|
||||
|
||||
async def delete_by_id(id: int) -> TourneyPool | None:
|
||||
"""Delete a tourney pool by id from the database."""
|
||||
select_stmt = select(*READ_PARAMS).where(TourneyPoolsTable.id == id)
|
||||
tourney_pool = await app.state.services.database.fetch_one(select_stmt)
|
||||
if tourney_pool is None:
|
||||
return None
|
||||
|
||||
delete_stmt = delete(TourneyPoolsTable).where(TourneyPoolsTable.id == id)
|
||||
await app.state.services.database.execute(delete_stmt)
|
||||
return cast(TourneyPool, tourney_pool)
|
79
app/repositories/user_achievements.py
Normal file
79
app/repositories/user_achievements.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
|
||||
import app.state.services
|
||||
from app._typing import UNSET
|
||||
from app._typing import _UnsetSentinel
|
||||
from app.repositories import Base
|
||||
|
||||
|
||||
class UserAchievementsTable(Base):
|
||||
__tablename__ = "user_achievements"
|
||||
|
||||
userid = Column("userid", Integer, nullable=False, primary_key=True)
|
||||
achid = Column("achid", Integer, nullable=False, primary_key=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index("user_achievements_achid_index", achid),
|
||||
Index("user_achievements_userid_index", userid),
|
||||
)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
UserAchievementsTable.userid,
|
||||
UserAchievementsTable.achid,
|
||||
)
|
||||
|
||||
|
||||
class UserAchievement(TypedDict):
|
||||
userid: int
|
||||
achid: int
|
||||
|
||||
|
||||
async def create(user_id: int, achievement_id: int) -> UserAchievement:
|
||||
"""Creates a new user achievement entry."""
|
||||
insert_stmt = insert(UserAchievementsTable).values(
|
||||
userid=user_id,
|
||||
achid=achievement_id,
|
||||
)
|
||||
await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = (
|
||||
select(*READ_PARAMS)
|
||||
.where(UserAchievementsTable.userid == user_id)
|
||||
.where(UserAchievementsTable.achid == achievement_id)
|
||||
)
|
||||
user_achievement = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert user_achievement is not None
|
||||
return cast(UserAchievement, user_achievement)
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
user_id: int | _UnsetSentinel = UNSET,
|
||||
achievement_id: int | _UnsetSentinel = UNSET,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[UserAchievement]:
|
||||
"""Fetch a list of user achievements."""
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if not isinstance(user_id, _UnsetSentinel):
|
||||
select_stmt = select_stmt.where(UserAchievementsTable.userid == user_id)
|
||||
if not isinstance(achievement_id, _UnsetSentinel):
|
||||
select_stmt = select_stmt.where(UserAchievementsTable.achid == achievement_id)
|
||||
|
||||
if page and page_size:
|
||||
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
user_achievements = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[UserAchievement], user_achievements)
|
||||
|
||||
|
||||
# TODO: delete?
|
270
app/repositories/users.py
Normal file
270
app/repositories/users.py
Normal file
@@ -0,0 +1,270 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.mysql import TINYINT
|
||||
|
||||
import app.state.services
|
||||
from app._typing import UNSET
|
||||
from app._typing import _UnsetSentinel
|
||||
from app.repositories import Base
|
||||
from app.utils import make_safe_name
|
||||
|
||||
|
||||
class UsersTable(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True, nullable=False, autoincrement=True)
|
||||
name = Column(String(32, collation="utf8"), nullable=False)
|
||||
safe_name = Column(String(32, collation="utf8"), nullable=False)
|
||||
email = Column(String(254), nullable=False)
|
||||
priv = Column(Integer, nullable=False, server_default="1")
|
||||
pw_bcrypt = Column(String(60), nullable=False)
|
||||
country = Column(String(2), nullable=False, server_default="xx")
|
||||
silence_end = Column(Integer, nullable=False, server_default="0")
|
||||
donor_end = Column(Integer, nullable=False, server_default="0")
|
||||
creation_time = Column(Integer, nullable=False, server_default="0")
|
||||
latest_activity = Column(Integer, nullable=False, server_default="0")
|
||||
clan_id = Column(Integer, nullable=False, server_default="0")
|
||||
clan_priv = Column(TINYINT, nullable=False, server_default="0")
|
||||
preferred_mode = Column(Integer, nullable=False, server_default="0")
|
||||
play_style = Column(Integer, nullable=False, server_default="0")
|
||||
custom_badge_name = Column(String(16, collation="utf8"))
|
||||
custom_badge_icon = Column(String(64))
|
||||
userpage_content = Column(String(2048, collation="utf8"))
|
||||
api_key = Column(String(36))
|
||||
|
||||
__table_args__ = (
|
||||
Index("users_priv_index", priv),
|
||||
Index("users_clan_id_index", clan_id),
|
||||
Index("users_clan_priv_index", clan_priv),
|
||||
Index("users_country_index", country),
|
||||
Index("users_api_key_uindex", api_key, unique=True),
|
||||
Index("users_email_uindex", email, unique=True),
|
||||
Index("users_name_uindex", name, unique=True),
|
||||
Index("users_safe_name_uindex", safe_name, unique=True),
|
||||
)
|
||||
|
||||
|
||||
READ_PARAMS = (
|
||||
UsersTable.id,
|
||||
UsersTable.name,
|
||||
UsersTable.safe_name,
|
||||
UsersTable.priv,
|
||||
UsersTable.country,
|
||||
UsersTable.silence_end,
|
||||
UsersTable.donor_end,
|
||||
UsersTable.creation_time,
|
||||
UsersTable.latest_activity,
|
||||
UsersTable.clan_id,
|
||||
UsersTable.clan_priv,
|
||||
UsersTable.preferred_mode,
|
||||
UsersTable.play_style,
|
||||
UsersTable.custom_badge_name,
|
||||
UsersTable.custom_badge_icon,
|
||||
UsersTable.userpage_content,
|
||||
)
|
||||
|
||||
|
||||
class User(TypedDict):
|
||||
id: int
|
||||
name: str
|
||||
safe_name: str
|
||||
priv: int
|
||||
pw_bcrypt: str
|
||||
country: str
|
||||
silence_end: int
|
||||
donor_end: int
|
||||
creation_time: int
|
||||
latest_activity: int
|
||||
clan_id: int
|
||||
clan_priv: int
|
||||
preferred_mode: int
|
||||
play_style: int
|
||||
custom_badge_name: str | None
|
||||
custom_badge_icon: str | None
|
||||
userpage_content: str | None
|
||||
api_key: str | None
|
||||
|
||||
|
||||
async def create(
|
||||
name: str,
|
||||
email: str,
|
||||
pw_bcrypt: bytes,
|
||||
country: str,
|
||||
) -> User:
|
||||
"""Create a new user in the database."""
|
||||
insert_stmt = insert(UsersTable).values(
|
||||
name=name,
|
||||
safe_name=make_safe_name(name),
|
||||
email=email,
|
||||
pw_bcrypt=pw_bcrypt,
|
||||
country=country,
|
||||
creation_time=func.unix_timestamp(),
|
||||
latest_activity=func.unix_timestamp(),
|
||||
)
|
||||
rec_id = await app.state.services.database.execute(insert_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(UsersTable.id == rec_id)
|
||||
user = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert user is not None
|
||||
return cast(User, user)
|
||||
|
||||
|
||||
async def fetch_one(
|
||||
id: int | None = None,
|
||||
name: str | None = None,
|
||||
email: str | None = None,
|
||||
fetch_all_fields: bool = False, # TODO: probably remove this if possible
|
||||
) -> User | None:
|
||||
"""Fetch a single user from the database."""
|
||||
if id is None and name is None and email is None:
|
||||
raise ValueError("Must provide at least one parameter.")
|
||||
|
||||
if fetch_all_fields:
|
||||
select_stmt = select(UsersTable)
|
||||
else:
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
|
||||
if id is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.id == id)
|
||||
if name is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.safe_name == make_safe_name(name))
|
||||
if email is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.email == email)
|
||||
|
||||
user = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(User | None, user)
|
||||
|
||||
|
||||
async def fetch_count(
|
||||
priv: int | None = None,
|
||||
country: str | None = None,
|
||||
clan_id: int | None = None,
|
||||
clan_priv: int | None = None,
|
||||
preferred_mode: int | None = None,
|
||||
play_style: int | None = None,
|
||||
) -> int:
|
||||
"""Fetch the number of users in the database."""
|
||||
select_stmt = select(func.count().label("count")).select_from(UsersTable)
|
||||
if priv is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.priv == priv)
|
||||
if country is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.country == country)
|
||||
if clan_id is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.clan_id == clan_id)
|
||||
if clan_priv is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.clan_priv == clan_priv)
|
||||
if preferred_mode is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.preferred_mode == preferred_mode)
|
||||
if play_style is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.play_style == play_style)
|
||||
|
||||
rec = await app.state.services.database.fetch_one(select_stmt)
|
||||
assert rec is not None
|
||||
return cast(int, rec["count"])
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
priv: int | None = None,
|
||||
country: str | None = None,
|
||||
clan_id: int | None = None,
|
||||
clan_priv: int | None = None,
|
||||
preferred_mode: int | None = None,
|
||||
play_style: int | None = None,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[User]:
|
||||
"""Fetch multiple users from the database."""
|
||||
select_stmt = select(*READ_PARAMS)
|
||||
if priv is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.priv == priv)
|
||||
if country is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.country == country)
|
||||
if clan_id is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.clan_id == clan_id)
|
||||
if clan_priv is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.clan_priv == clan_priv)
|
||||
if preferred_mode is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.preferred_mode == preferred_mode)
|
||||
if play_style is not None:
|
||||
select_stmt = select_stmt.where(UsersTable.play_style == play_style)
|
||||
|
||||
if page is not None and page_size is not None:
|
||||
select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
|
||||
|
||||
users = await app.state.services.database.fetch_all(select_stmt)
|
||||
return cast(list[User], users)
|
||||
|
||||
|
||||
async def partial_update(
|
||||
id: int,
|
||||
name: str | _UnsetSentinel = UNSET,
|
||||
email: str | _UnsetSentinel = UNSET,
|
||||
priv: int | _UnsetSentinel = UNSET,
|
||||
country: str | _UnsetSentinel = UNSET,
|
||||
silence_end: int | _UnsetSentinel = UNSET,
|
||||
donor_end: int | _UnsetSentinel = UNSET,
|
||||
creation_time: _UnsetSentinel | _UnsetSentinel = UNSET,
|
||||
latest_activity: int | _UnsetSentinel = UNSET,
|
||||
clan_id: int | _UnsetSentinel = UNSET,
|
||||
clan_priv: int | _UnsetSentinel = UNSET,
|
||||
preferred_mode: int | _UnsetSentinel = UNSET,
|
||||
play_style: int | _UnsetSentinel = UNSET,
|
||||
custom_badge_name: str | None | _UnsetSentinel = UNSET,
|
||||
custom_badge_icon: str | None | _UnsetSentinel = UNSET,
|
||||
userpage_content: str | None | _UnsetSentinel = UNSET,
|
||||
api_key: str | None | _UnsetSentinel = UNSET,
|
||||
) -> User | None:
|
||||
"""Update a user in the database."""
|
||||
update_stmt = update(UsersTable).where(UsersTable.id == id)
|
||||
if not isinstance(name, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(name=name, safe_name=make_safe_name(name))
|
||||
if not isinstance(email, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(email=email)
|
||||
if not isinstance(priv, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(priv=priv)
|
||||
if not isinstance(country, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(country=country)
|
||||
if not isinstance(silence_end, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(silence_end=silence_end)
|
||||
if not isinstance(donor_end, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(donor_end=donor_end)
|
||||
if not isinstance(creation_time, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(creation_time=creation_time)
|
||||
if not isinstance(latest_activity, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(latest_activity=latest_activity)
|
||||
if not isinstance(clan_id, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(clan_id=clan_id)
|
||||
if not isinstance(clan_priv, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(clan_priv=clan_priv)
|
||||
if not isinstance(preferred_mode, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(preferred_mode=preferred_mode)
|
||||
if not isinstance(play_style, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(play_style=play_style)
|
||||
if not isinstance(custom_badge_name, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(custom_badge_name=custom_badge_name)
|
||||
if not isinstance(custom_badge_icon, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(custom_badge_icon=custom_badge_icon)
|
||||
if not isinstance(userpage_content, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(userpage_content=userpage_content)
|
||||
if not isinstance(api_key, _UnsetSentinel):
|
||||
update_stmt = update_stmt.values(api_key=api_key)
|
||||
|
||||
await app.state.services.database.execute(update_stmt)
|
||||
|
||||
select_stmt = select(*READ_PARAMS).where(UsersTable.id == id)
|
||||
user = await app.state.services.database.fetch_one(select_stmt)
|
||||
return cast(User | None, user)
|
||||
|
||||
|
||||
# TODO: delete?
|
73
app/settings.py
Normal file
73
app/settings.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tomllib
|
||||
from urllib.parse import quote
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.settings_utils import read_bool
|
||||
from app.settings_utils import read_list
|
||||
|
||||
load_dotenv()
|
||||
|
||||
APP_HOST = os.environ["APP_HOST"]
|
||||
APP_PORT = int(os.environ["APP_PORT"])
|
||||
|
||||
DB_HOST = os.environ["DB_HOST"]
|
||||
DB_PORT = int(os.environ["DB_PORT"])
|
||||
DB_USER = os.environ["DB_USER"]
|
||||
DB_PASS = quote(os.environ["DB_PASS"])
|
||||
DB_NAME = os.environ["DB_NAME"]
|
||||
DB_DSN = f"mysql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}"
|
||||
|
||||
REDIS_HOST = os.environ["REDIS_HOST"]
|
||||
REDIS_PORT = int(os.environ["REDIS_PORT"])
|
||||
REDIS_USER = os.environ["REDIS_USER"]
|
||||
REDIS_PASS = quote(os.environ["REDIS_PASS"])
|
||||
REDIS_DB = int(os.environ["REDIS_DB"])
|
||||
|
||||
REDIS_AUTH_STRING = f"{REDIS_USER}:{REDIS_PASS}@" if REDIS_USER and REDIS_PASS else ""
|
||||
REDIS_DSN = f"redis://{REDIS_AUTH_STRING}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"
|
||||
|
||||
OSU_API_KEY = os.environ.get("OSU_API_KEY") or None
|
||||
|
||||
DOMAIN = os.environ["DOMAIN"]
|
||||
MIRROR_SEARCH_ENDPOINT = os.environ["MIRROR_SEARCH_ENDPOINT"]
|
||||
MIRROR_DOWNLOAD_ENDPOINT = os.environ["MIRROR_DOWNLOAD_ENDPOINT"]
|
||||
|
||||
COMMAND_PREFIX = os.environ["COMMAND_PREFIX"]
|
||||
|
||||
SEASONAL_BGS = read_list(os.environ["SEASONAL_BGS"])
|
||||
|
||||
MENU_ICON_URL = os.environ["MENU_ICON_URL"]
|
||||
MENU_ONCLICK_URL = os.environ["MENU_ONCLICK_URL"]
|
||||
|
||||
DATADOG_API_KEY = os.environ["DATADOG_API_KEY"]
|
||||
DATADOG_APP_KEY = os.environ["DATADOG_APP_KEY"]
|
||||
|
||||
DEBUG = read_bool(os.environ["DEBUG"])
|
||||
REDIRECT_OSU_URLS = read_bool(os.environ["REDIRECT_OSU_URLS"])
|
||||
|
||||
PP_CACHED_ACCURACIES = [int(acc) for acc in read_list(os.environ["PP_CACHED_ACCS"])]
|
||||
|
||||
DISALLOWED_NAMES = read_list(os.environ["DISALLOWED_NAMES"])
|
||||
DISALLOWED_PASSWORDS = read_list(os.environ["DISALLOWED_PASSWORDS"])
|
||||
DISALLOW_OLD_CLIENTS = read_bool(os.environ["DISALLOW_OLD_CLIENTS"])
|
||||
DISALLOW_INGAME_REGISTRATION = read_bool(os.environ["DISALLOW_INGAME_REGISTRATION"])
|
||||
|
||||
DISCORD_AUDIT_LOG_WEBHOOK = os.environ["DISCORD_AUDIT_LOG_WEBHOOK"]
|
||||
|
||||
AUTOMATICALLY_REPORT_PROBLEMS = read_bool(os.environ["AUTOMATICALLY_REPORT_PROBLEMS"])
|
||||
|
||||
LOG_WITH_COLORS = read_bool(os.environ["LOG_WITH_COLORS"])
|
||||
|
||||
# advanced dev settings
|
||||
|
||||
## WARNING touch this once you've
|
||||
## read through what it enables.
|
||||
## you could put your server at risk.
|
||||
DEVELOPER_MODE = read_bool(os.environ["DEVELOPER_MODE"])
|
||||
|
||||
with open("pyproject.toml", "rb") as f:
|
||||
VERSION = tomllib.load(f)["tool"]["poetry"]["version"]
|
48
app/settings_utils.py
Normal file
48
app/settings_utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import date
|
||||
|
||||
from app.logging import Ansi
|
||||
from app.logging import log
|
||||
|
||||
|
||||
def read_bool(value: str) -> bool:
|
||||
return value.lower() in ("true", "1", "yes")
|
||||
|
||||
|
||||
def read_list(value: str) -> list[str]:
|
||||
return [v.strip() for v in value.split(",")]
|
||||
|
||||
|
||||
def support_deprecated_vars(
|
||||
new_name: str,
|
||||
deprecated_name: str,
|
||||
*,
|
||||
until: date,
|
||||
allow_empty_string: bool = False,
|
||||
) -> str:
|
||||
val1 = os.getenv(new_name)
|
||||
if val1:
|
||||
return val1
|
||||
|
||||
val2 = os.getenv(deprecated_name)
|
||||
if val2:
|
||||
if until < date.today():
|
||||
raise ValueError(
|
||||
f'The "{deprecated_name}" config option has been deprecated as of {until.isoformat()} and is no longer supported. Use {new_name} instead.',
|
||||
)
|
||||
|
||||
log(
|
||||
f'The "{deprecated_name}" config option has been deprecated and will be supported until {until.isoformat()}. Use {new_name} instead.',
|
||||
Ansi.LYELLOW,
|
||||
)
|
||||
return val2
|
||||
|
||||
if allow_empty_string:
|
||||
if val1 is not None:
|
||||
return val1
|
||||
if val2 is not None:
|
||||
return val2
|
||||
|
||||
raise KeyError(f"{new_name} is not set in the environment")
|
24
app/state/__init__.py
Normal file
24
app/state/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Literal
|
||||
|
||||
from . import cache
|
||||
from . import services
|
||||
from . import sessions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from asyncio import AbstractEventLoop
|
||||
|
||||
from app.packets import BasePacket
|
||||
from app.packets import ClientPackets
|
||||
|
||||
loop: AbstractEventLoop
|
||||
score_submission_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
packets: dict[Literal["all", "restricted"], dict[ClientPackets, type[BasePacket]]] = {
|
||||
"all": {},
|
||||
"restricted": {},
|
||||
}
|
||||
shutting_down = False
|
14
app/state/cache.py
Normal file
14
app/state/cache.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.objects.beatmap import Beatmap
|
||||
from app.objects.beatmap import BeatmapSet
|
||||
|
||||
|
||||
bcrypt: dict[bytes, bytes] = {} # {bcrypt: md5, ...}
|
||||
beatmap: dict[str | int, Beatmap] = {} # {md5: map, id: map, ...}
|
||||
beatmapset: dict[int, BeatmapSet] = {} # {bsid: map_set}
|
||||
unsubmitted: set[str] = set() # {md5, ...}
|
||||
needs_update: set[str] = set() # {md5, ...}
|
492
app/state/services.py
Normal file
492
app/state/services.py
Normal file
@@ -0,0 +1,492 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import pickle
|
||||
import re
|
||||
import secrets
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import MutableMapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypedDict
|
||||
|
||||
import datadog as datadog_module
|
||||
import datadog.threadstats.base as datadog_client
|
||||
import httpx
|
||||
import pymysql
|
||||
from redis import asyncio as aioredis
|
||||
|
||||
import app.settings
|
||||
import app.state
|
||||
from app._typing import IPAddress
|
||||
from app.adapters.database import Database
|
||||
from app.logging import Ansi
|
||||
from app.logging import log
|
||||
|
||||
STRANGE_LOG_DIR = Path.cwd() / ".data/logs"
|
||||
|
||||
VERSION_RGX = re.compile(r"^# v(?P<ver>\d+\.\d+\.\d+)$")
|
||||
SQL_UPDATES_FILE = Path.cwd() / "migrations/migrations.sql"
|
||||
|
||||
|
||||
""" session objects """
|
||||
|
||||
http_client = httpx.AsyncClient()
|
||||
database = Database(app.settings.DB_DSN)
|
||||
redis: aioredis.Redis = aioredis.from_url(app.settings.REDIS_DSN)
|
||||
|
||||
datadog: datadog_client.ThreadStats | None = None
|
||||
if str(app.settings.DATADOG_API_KEY) and str(app.settings.DATADOG_APP_KEY):
|
||||
datadog_module.initialize(
|
||||
api_key=str(app.settings.DATADOG_API_KEY),
|
||||
app_key=str(app.settings.DATADOG_APP_KEY),
|
||||
)
|
||||
datadog = datadog_client.ThreadStats()
|
||||
|
||||
ip_resolver: IPResolver
|
||||
|
||||
""" session usecases """
|
||||
|
||||
|
||||
class Country(TypedDict):
|
||||
acronym: str
|
||||
numeric: int
|
||||
|
||||
|
||||
class Geolocation(TypedDict):
|
||||
latitude: float
|
||||
longitude: float
|
||||
country: Country
|
||||
|
||||
|
||||
# fmt: off
|
||||
country_codes = {
|
||||
"oc": 1, "eu": 2, "ad": 3, "ae": 4, "af": 5, "ag": 6, "ai": 7, "al": 8,
|
||||
"am": 9, "an": 10, "ao": 11, "aq": 12, "ar": 13, "as": 14, "at": 15, "au": 16,
|
||||
"aw": 17, "az": 18, "ba": 19, "bb": 20, "bd": 21, "be": 22, "bf": 23, "bg": 24,
|
||||
"bh": 25, "bi": 26, "bj": 27, "bm": 28, "bn": 29, "bo": 30, "br": 31, "bs": 32,
|
||||
"bt": 33, "bv": 34, "bw": 35, "by": 36, "bz": 37, "ca": 38, "cc": 39, "cd": 40,
|
||||
"cf": 41, "cg": 42, "ch": 43, "ci": 44, "ck": 45, "cl": 46, "cm": 47, "cn": 48,
|
||||
"co": 49, "cr": 50, "cu": 51, "cv": 52, "cx": 53, "cy": 54, "cz": 55, "de": 56,
|
||||
"dj": 57, "dk": 58, "dm": 59, "do": 60, "dz": 61, "ec": 62, "ee": 63, "eg": 64,
|
||||
"eh": 65, "er": 66, "es": 67, "et": 68, "fi": 69, "fj": 70, "fk": 71, "fm": 72,
|
||||
"fo": 73, "fr": 74, "fx": 75, "ga": 76, "gb": 77, "gd": 78, "ge": 79, "gf": 80,
|
||||
"gh": 81, "gi": 82, "gl": 83, "gm": 84, "gn": 85, "gp": 86, "gq": 87, "gr": 88,
|
||||
"gs": 89, "gt": 90, "gu": 91, "gw": 92, "gy": 93, "hk": 94, "hm": 95, "hn": 96,
|
||||
"hr": 97, "ht": 98, "hu": 99, "id": 100, "ie": 101, "il": 102, "in": 103, "io": 104,
|
||||
"iq": 105, "ir": 106, "is": 107, "it": 108, "jm": 109, "jo": 110, "jp": 111, "ke": 112,
|
||||
"kg": 113, "kh": 114, "ki": 115, "km": 116, "kn": 117, "kp": 118, "kr": 119, "kw": 120,
|
||||
"ky": 121, "kz": 122, "la": 123, "lb": 124, "lc": 125, "li": 126, "lk": 127, "lr": 128,
|
||||
"ls": 129, "lt": 130, "lu": 131, "lv": 132, "ly": 133, "ma": 134, "mc": 135, "md": 136,
|
||||
"mg": 137, "mh": 138, "mk": 139, "ml": 140, "mm": 141, "mn": 142, "mo": 143, "mp": 144,
|
||||
"mq": 145, "mr": 146, "ms": 147, "mt": 148, "mu": 149, "mv": 150, "mw": 151, "mx": 152,
|
||||
"my": 153, "mz": 154, "na": 155, "nc": 156, "ne": 157, "nf": 158, "ng": 159, "ni": 160,
|
||||
"nl": 161, "no": 162, "np": 163, "nr": 164, "nu": 165, "nz": 166, "om": 167, "pa": 168,
|
||||
"pe": 169, "pf": 170, "pg": 171, "ph": 172, "pk": 173, "pl": 174, "pm": 175, "pn": 176,
|
||||
"pr": 177, "ps": 178, "pt": 179, "pw": 180, "py": 181, "qa": 182, "re": 183, "ro": 184,
|
||||
"ru": 185, "rw": 186, "sa": 187, "sb": 188, "sc": 189, "sd": 190, "se": 191, "sg": 192,
|
||||
"sh": 193, "si": 194, "sj": 195, "sk": 196, "sl": 197, "sm": 198, "sn": 199, "so": 200,
|
||||
"sr": 201, "st": 202, "sv": 203, "sy": 204, "sz": 205, "tc": 206, "td": 207, "tf": 208,
|
||||
"tg": 209, "th": 210, "tj": 211, "tk": 212, "tm": 213, "tn": 214, "to": 215, "tl": 216,
|
||||
"tr": 217, "tt": 218, "tv": 219, "tw": 220, "tz": 221, "ua": 222, "ug": 223, "um": 224,
|
||||
"us": 225, "uy": 226, "uz": 227, "va": 228, "vc": 229, "ve": 230, "vg": 231, "vi": 232,
|
||||
"vn": 233, "vu": 234, "wf": 235, "ws": 236, "ye": 237, "yt": 238, "rs": 239, "za": 240,
|
||||
"zm": 241, "me": 242, "zw": 243, "xx": 244, "a2": 245, "o1": 246, "ax": 247, "gg": 248,
|
||||
"im": 249, "je": 250, "bl": 251, "mf": 252,
|
||||
}
|
||||
# fmt: on
|
||||
|
||||
|
||||
class IPResolver:
|
||||
def __init__(self) -> None:
|
||||
self.cache: MutableMapping[str, IPAddress] = {}
|
||||
|
||||
def get_ip(self, headers: Mapping[str, str]) -> IPAddress:
|
||||
"""Resolve the IP address from the headers."""
|
||||
ip_str = headers.get("CF-Connecting-IP")
|
||||
if ip_str is None:
|
||||
forwards = headers["X-Forwarded-For"].split(",")
|
||||
|
||||
if len(forwards) != 1:
|
||||
ip_str = forwards[0]
|
||||
else:
|
||||
ip_str = headers["X-Real-IP"]
|
||||
|
||||
ip = self.cache.get(ip_str)
|
||||
if ip is None:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
self.cache[ip_str] = ip
|
||||
|
||||
return ip
|
||||
|
||||
|
||||
async def fetch_geoloc(
|
||||
ip: IPAddress,
|
||||
headers: Mapping[str, str] | None = None,
|
||||
) -> Geolocation | None:
|
||||
"""Attempt to fetch geolocation data by any means necessary."""
|
||||
geoloc = None
|
||||
if headers is not None:
|
||||
geoloc = _fetch_geoloc_from_headers(headers)
|
||||
|
||||
if geoloc is None:
|
||||
geoloc = await _fetch_geoloc_from_ip(ip)
|
||||
|
||||
return geoloc
|
||||
|
||||
|
||||
def _fetch_geoloc_from_headers(headers: Mapping[str, str]) -> Geolocation | None:
|
||||
"""Attempt to fetch geolocation data from http headers."""
|
||||
geoloc = __fetch_geoloc_cloudflare(headers)
|
||||
|
||||
if geoloc is None:
|
||||
geoloc = __fetch_geoloc_nginx(headers)
|
||||
|
||||
return geoloc
|
||||
|
||||
|
||||
def __fetch_geoloc_cloudflare(headers: Mapping[str, str]) -> Geolocation | None:
|
||||
"""Attempt to fetch geolocation data from cloudflare headers."""
|
||||
if not all(
|
||||
key in headers for key in ("CF-IPCountry", "CF-IPLatitude", "CF-IPLongitude")
|
||||
):
|
||||
return None
|
||||
|
||||
country_code = headers["CF-IPCountry"].lower()
|
||||
latitude = float(headers["CF-IPLatitude"])
|
||||
longitude = float(headers["CF-IPLongitude"])
|
||||
|
||||
return {
|
||||
"latitude": latitude,
|
||||
"longitude": longitude,
|
||||
"country": {
|
||||
"acronym": country_code,
|
||||
"numeric": country_codes[country_code],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def __fetch_geoloc_nginx(headers: Mapping[str, str]) -> Geolocation | None:
|
||||
"""Attempt to fetch geolocation data from nginx headers."""
|
||||
if not all(
|
||||
key in headers for key in ("X-Country-Code", "X-Latitude", "X-Longitude")
|
||||
):
|
||||
return None
|
||||
|
||||
country_code = headers["X-Country-Code"].lower()
|
||||
latitude = float(headers["X-Latitude"])
|
||||
longitude = float(headers["X-Longitude"])
|
||||
|
||||
return {
|
||||
"latitude": latitude,
|
||||
"longitude": longitude,
|
||||
"country": {
|
||||
"acronym": country_code,
|
||||
"numeric": country_codes[country_code],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def _fetch_geoloc_from_ip(ip: IPAddress) -> Geolocation | None:
|
||||
"""Fetch geolocation data based on ip (using ip-api)."""
|
||||
if not ip.is_private:
|
||||
url = f"http://ip-api.com/line/{ip}"
|
||||
else:
|
||||
url = "http://ip-api.com/line/"
|
||||
|
||||
response = await http_client.get(
|
||||
url,
|
||||
params={
|
||||
"fields": ",".join(("status", "message", "countryCode", "lat", "lon")),
|
||||
},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
log("Failed to get geoloc data: request failed.", Ansi.LRED)
|
||||
return None
|
||||
|
||||
status, *lines = response.read().decode().split("\n")
|
||||
|
||||
if status != "success":
|
||||
err_msg = lines[0]
|
||||
if err_msg == "invalid query":
|
||||
err_msg += f" ({url})"
|
||||
|
||||
log(f"Failed to get geoloc data: {err_msg} for ip {ip}.", Ansi.LRED)
|
||||
return None
|
||||
|
||||
country_acronym = lines[0].lower()
|
||||
|
||||
return {
|
||||
"latitude": float(lines[1]),
|
||||
"longitude": float(lines[2]),
|
||||
"country": {
|
||||
"acronym": country_acronym,
|
||||
"numeric": country_codes[country_acronym],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def log_strange_occurrence(obj: object) -> None:
|
||||
pickled_obj: bytes = pickle.dumps(obj)
|
||||
uploaded = False
|
||||
|
||||
if app.settings.AUTOMATICALLY_REPORT_PROBLEMS:
|
||||
# automatically reporting problems to cmyui's server
|
||||
response = await http_client.post(
|
||||
url="https://log.cmyui.xyz/",
|
||||
headers={
|
||||
"Bancho-Version": app.settings.VERSION,
|
||||
"Bancho-Domain": app.settings.DOMAIN,
|
||||
},
|
||||
content=pickled_obj,
|
||||
)
|
||||
if response.status_code == 200 and response.read() == b"ok":
|
||||
uploaded = True
|
||||
log(
|
||||
"Logged strange occurrence to cmyui's server. "
|
||||
"Thank you for your participation! <3",
|
||||
Ansi.LBLUE,
|
||||
)
|
||||
else:
|
||||
log(
|
||||
f"Autoupload to cmyui's server failed (HTTP {response.status_code})",
|
||||
Ansi.LRED,
|
||||
)
|
||||
|
||||
if not uploaded:
|
||||
# log to a file locally, and prompt the user
|
||||
while True:
|
||||
log_file = STRANGE_LOG_DIR / f"strange_{secrets.token_hex(4)}.db"
|
||||
if not log_file.exists():
|
||||
break
|
||||
|
||||
log_file.touch(exist_ok=False)
|
||||
log_file.write_bytes(pickled_obj)
|
||||
|
||||
log(
|
||||
"Logged strange occurrence to" + "/".join(log_file.parts[-4:]),
|
||||
Ansi.LYELLOW,
|
||||
)
|
||||
log(
|
||||
"It would be greatly appreciated if you could forward this to the "
|
||||
"bancho.py development team. To do so, please email josh@akatsuki.gg",
|
||||
Ansi.LYELLOW,
|
||||
)
|
||||
|
||||
|
||||
# dependency management
|
||||
|
||||
|
||||
class Version:
|
||||
def __init__(self, major: int, minor: int, micro: int) -> None:
|
||||
self.major = major
|
||||
self.minor = minor
|
||||
self.micro = micro
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.major}.{self.minor}.{self.micro}"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.as_tuple.__hash__()
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, Version):
|
||||
return NotImplemented
|
||||
|
||||
return self.as_tuple == other.as_tuple
|
||||
|
||||
def __lt__(self, other: Version) -> bool:
|
||||
return self.as_tuple < other.as_tuple
|
||||
|
||||
def __le__(self, other: Version) -> bool:
|
||||
return self.as_tuple <= other.as_tuple
|
||||
|
||||
def __gt__(self, other: Version) -> bool:
|
||||
return self.as_tuple > other.as_tuple
|
||||
|
||||
def __ge__(self, other: Version) -> bool:
|
||||
return self.as_tuple >= other.as_tuple
|
||||
|
||||
@property
|
||||
def as_tuple(self) -> tuple[int, int, int]:
|
||||
return (self.major, self.minor, self.micro)
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, s: str) -> Version | None:
|
||||
split = s.split(".")
|
||||
if len(split) == 3:
|
||||
return cls(
|
||||
major=int(split[0]),
|
||||
minor=int(split[1]),
|
||||
micro=int(split[2]),
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _get_latest_dependency_versions() -> AsyncGenerator[
|
||||
tuple[str, Version, Version],
|
||||
None,
|
||||
]:
|
||||
"""Return the current installed & latest version for each dependency."""
|
||||
with open("requirements.txt") as f:
|
||||
dependencies = f.read().splitlines(keepends=False)
|
||||
|
||||
# TODO: use asyncio.gather() to do all requests at once? or chunk them
|
||||
|
||||
for dependency in dependencies:
|
||||
dependency_name, _, dependency_ver = dependency.partition("==")
|
||||
current_ver = Version.from_str(dependency_ver)
|
||||
|
||||
if not current_ver:
|
||||
# the module uses some more advanced (and often hard to parse)
|
||||
# versioning system, so we won't be able to report updates.
|
||||
continue
|
||||
|
||||
# TODO: split up and do the requests asynchronously
|
||||
url = f"https://pypi.org/pypi/{dependency_name}/json"
|
||||
response = await http_client.get(url)
|
||||
json = response.json()
|
||||
|
||||
if response.status_code == 200 and json:
|
||||
latest_ver = Version.from_str(json["info"]["version"])
|
||||
|
||||
if not latest_ver:
|
||||
# they've started using a more advanced versioning system.
|
||||
continue
|
||||
|
||||
yield (dependency_name, latest_ver, current_ver)
|
||||
else:
|
||||
yield (dependency_name, current_ver, current_ver)
|
||||
|
||||
|
||||
async def check_for_dependency_updates() -> None:
|
||||
"""Notify the developer of any dependency updates available."""
|
||||
updates_available = False
|
||||
|
||||
async for module, current_ver, latest_ver in _get_latest_dependency_versions():
|
||||
if latest_ver > current_ver:
|
||||
updates_available = True
|
||||
log(
|
||||
f"{module} has an update available "
|
||||
f"[{current_ver!r} -> {latest_ver!r}]",
|
||||
Ansi.LMAGENTA,
|
||||
)
|
||||
|
||||
if updates_available:
|
||||
log(
|
||||
"Python modules can be updated with "
|
||||
"`python3.11 -m pip install -U <modules>`.",
|
||||
Ansi.LMAGENTA,
|
||||
)
|
||||
|
||||
|
||||
# sql migrations
|
||||
|
||||
|
||||
async def _get_current_sql_structure_version() -> Version | None:
|
||||
"""Get the last launched version of the server."""
|
||||
res = await app.state.services.database.fetch_one(
|
||||
"SELECT ver_major, ver_minor, ver_micro "
|
||||
"FROM startups ORDER BY datetime DESC LIMIT 1",
|
||||
)
|
||||
|
||||
if res:
|
||||
return Version(res["ver_major"], res["ver_minor"], res["ver_micro"])
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def run_sql_migrations() -> None:
|
||||
"""Update the sql structure, if it has changed."""
|
||||
software_version = Version.from_str(app.settings.VERSION)
|
||||
if software_version is None:
|
||||
raise RuntimeError(f"Invalid bancho.py version '{app.settings.VERSION}'")
|
||||
|
||||
last_run_migration_version = await _get_current_sql_structure_version()
|
||||
if not last_run_migration_version:
|
||||
# Migrations have never run before - this is the first time starting the server.
|
||||
# We'll insert the current version into the database, so future versions know to migrate.
|
||||
await app.state.services.database.execute(
|
||||
"INSERT INTO startups (ver_major, ver_minor, ver_micro, datetime) "
|
||||
"VALUES (:major, :minor, :micro, NOW())",
|
||||
{
|
||||
"major": software_version.major,
|
||||
"minor": software_version.minor,
|
||||
"micro": software_version.micro,
|
||||
},
|
||||
)
|
||||
return # already up to date (server has never run before)
|
||||
|
||||
if software_version == last_run_migration_version:
|
||||
return # already up to date
|
||||
|
||||
# version changed; there may be sql changes.
|
||||
content = SQL_UPDATES_FILE.read_text()
|
||||
|
||||
queries: list[str] = []
|
||||
q_lines: list[str] = []
|
||||
|
||||
update_ver = None
|
||||
|
||||
for line in content.splitlines():
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line.startswith("#"):
|
||||
# may be normal comment or new version
|
||||
r_match = VERSION_RGX.fullmatch(line)
|
||||
if r_match:
|
||||
update_ver = Version.from_str(r_match["ver"])
|
||||
|
||||
continue
|
||||
elif not update_ver:
|
||||
continue
|
||||
|
||||
# we only need the updates between the
|
||||
# previous and new version of the server.
|
||||
if last_run_migration_version < update_ver <= software_version:
|
||||
if line.endswith(";"):
|
||||
if q_lines:
|
||||
q_lines.append(line)
|
||||
queries.append(" ".join(q_lines))
|
||||
q_lines = []
|
||||
else:
|
||||
queries.append(line)
|
||||
else:
|
||||
q_lines.append(line)
|
||||
|
||||
if queries:
|
||||
log(
|
||||
f"Updating mysql structure (v{last_run_migration_version!r} -> v{software_version!r}).",
|
||||
Ansi.LMAGENTA,
|
||||
)
|
||||
|
||||
# XXX: we can't use a transaction here with mysql as structural changes to
|
||||
# tables implicitly commit: https://dev.mysql.com/doc/refman/5.7/en/implicit-commit.html
|
||||
for query in queries:
|
||||
try:
|
||||
await app.state.services.database.execute(query)
|
||||
except pymysql.err.MySQLError as exc:
|
||||
log(f"Failed: {query}", Ansi.GRAY)
|
||||
log(repr(exc))
|
||||
log(
|
||||
"SQL failed to update - unless you've been "
|
||||
"modifying sql and know what caused this, "
|
||||
"please contact @cmyui on Discord.",
|
||||
Ansi.LRED,
|
||||
)
|
||||
raise KeyboardInterrupt from exc
|
||||
else:
|
||||
# all queries executed successfully
|
||||
await app.state.services.database.execute(
|
||||
"INSERT INTO startups (ver_major, ver_minor, ver_micro, datetime) "
|
||||
"VALUES (:major, :minor, :micro, NOW())",
|
||||
{
|
||||
"major": software_version.major,
|
||||
"minor": software_version.minor,
|
||||
"micro": software_version.micro,
|
||||
},
|
||||
)
|
54
app/state/sessions.py
Normal file
54
app/state/sessions.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any
|
||||
|
||||
from app.logging import Ansi
|
||||
from app.logging import log
|
||||
from app.objects.collections import Channels
|
||||
from app.objects.collections import Matches
|
||||
from app.objects.collections import Players
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.objects.player import Player
|
||||
|
||||
players = Players()
|
||||
channels = Channels()
|
||||
matches = Matches()
|
||||
|
||||
api_keys: dict[str, int] = {}
|
||||
|
||||
housekeeping_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
bot: Player
|
||||
|
||||
|
||||
# use cases
|
||||
|
||||
|
||||
async def cancel_housekeeping_tasks() -> None:
|
||||
log(
|
||||
f"-> Cancelling {len(housekeeping_tasks)} housekeeping tasks.",
|
||||
Ansi.LMAGENTA,
|
||||
)
|
||||
|
||||
# cancel housekeeping tasks
|
||||
for task in housekeeping_tasks:
|
||||
task.cancel()
|
||||
|
||||
await asyncio.gather(*housekeeping_tasks, return_exceptions=True)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
for task in housekeeping_tasks:
|
||||
if not task.cancelled():
|
||||
exception = task.exception()
|
||||
if exception:
|
||||
loop.call_exception_handler(
|
||||
{
|
||||
"message": "unhandled exception during loop shutdown",
|
||||
"exception": exception,
|
||||
"task": task,
|
||||
},
|
||||
)
|
27
app/timer.py
Normal file
27
app/timer.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from types import TracebackType
|
||||
|
||||
|
||||
class Timer:
|
||||
def __init__(self) -> None:
|
||||
self.start_time: float | None = None
|
||||
self.end_time: float | None = None
|
||||
|
||||
def __enter__(self) -> Timer:
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
self.end_time = time.time()
|
||||
|
||||
def elapsed(self) -> float:
|
||||
if self.start_time is None or self.end_time is None:
|
||||
raise ValueError("Timer has not been started or stopped.")
|
||||
return self.end_time - self.start_time
|
0
app/usecases/__init__.py
Normal file
0
app/usecases/__init__.py
Normal file
30
app/usecases/achievements.py
Normal file
30
app/usecases/achievements.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import app.repositories.achievements
|
||||
from app.repositories.achievements import Achievement
|
||||
|
||||
|
||||
async def create(
|
||||
file: str,
|
||||
name: str,
|
||||
desc: str,
|
||||
cond: str,
|
||||
) -> Achievement:
|
||||
achievement = await app.repositories.achievements.create(
|
||||
file,
|
||||
name,
|
||||
desc,
|
||||
cond,
|
||||
)
|
||||
return achievement
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[Achievement]:
|
||||
achievements = await app.repositories.achievements.fetch_many(
|
||||
page,
|
||||
page_size,
|
||||
)
|
||||
return achievements
|
138
app/usecases/performance.py
Normal file
138
app/usecases/performance.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import TypedDict
|
||||
|
||||
from akatsuki_pp_py import Beatmap
|
||||
from akatsuki_pp_py import Calculator
|
||||
|
||||
from app.constants.mods import Mods
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoreParams:
|
||||
mode: int
|
||||
mods: int | None = None
|
||||
combo: int | None = None
|
||||
|
||||
# caller may pass either acc OR 300/100/50/geki/katu/miss
|
||||
# passing both will result in a value error being raised
|
||||
acc: float | None = None
|
||||
|
||||
n300: int | None = None
|
||||
n100: int | None = None
|
||||
n50: int | None = None
|
||||
ngeki: int | None = None
|
||||
nkatu: int | None = None
|
||||
nmiss: int | None = None
|
||||
|
||||
|
||||
class PerformanceRating(TypedDict):
|
||||
pp: float
|
||||
pp_acc: float | None
|
||||
pp_aim: float | None
|
||||
pp_speed: float | None
|
||||
pp_flashlight: float | None
|
||||
effective_miss_count: float | None
|
||||
pp_difficulty: float | None
|
||||
|
||||
|
||||
class DifficultyRating(TypedDict):
|
||||
stars: float
|
||||
aim: float | None
|
||||
speed: float | None
|
||||
flashlight: float | None
|
||||
slider_factor: float | None
|
||||
speed_note_count: float | None
|
||||
stamina: float | None
|
||||
color: float | None
|
||||
rhythm: float | None
|
||||
peak: float | None
|
||||
|
||||
|
||||
class PerformanceResult(TypedDict):
|
||||
performance: PerformanceRating
|
||||
difficulty: DifficultyRating
|
||||
|
||||
|
||||
def calculate_performances(
|
||||
osu_file_path: str,
|
||||
scores: Iterable[ScoreParams],
|
||||
) -> list[PerformanceResult]:
|
||||
"""\
|
||||
Calculate performance for multiple scores on a single beatmap.
|
||||
|
||||
Typically most useful for mass-recalculation situations.
|
||||
|
||||
TODO: Some level of error handling & returning to caller should be
|
||||
implemented here to handle cases where e.g. the beatmap file is invalid
|
||||
or there an issue during calculation.
|
||||
"""
|
||||
calc_bmap = Beatmap(path=osu_file_path)
|
||||
|
||||
results: list[PerformanceResult] = []
|
||||
|
||||
for score in scores:
|
||||
if score.acc and (
|
||||
score.n300 or score.n100 or score.n50 or score.ngeki or score.nkatu
|
||||
):
|
||||
raise ValueError(
|
||||
"Must not specify accuracy AND 300/100/50/geki/katu. Only one or the other.",
|
||||
)
|
||||
|
||||
# rosupp ignores NC and requires DT
|
||||
if score.mods is not None:
|
||||
if score.mods & Mods.NIGHTCORE:
|
||||
score.mods |= Mods.DOUBLETIME
|
||||
|
||||
calculator = Calculator(
|
||||
mode=score.mode,
|
||||
mods=score.mods or 0,
|
||||
combo=score.combo,
|
||||
acc=score.acc,
|
||||
n300=score.n300,
|
||||
n100=score.n100,
|
||||
n50=score.n50,
|
||||
n_geki=score.ngeki,
|
||||
n_katu=score.nkatu,
|
||||
n_misses=score.nmiss,
|
||||
)
|
||||
result = calculator.performance(calc_bmap)
|
||||
|
||||
pp = result.pp
|
||||
|
||||
if math.isnan(pp) or math.isinf(pp):
|
||||
# TODO: report to logserver
|
||||
pp = 0.0
|
||||
else:
|
||||
pp = round(pp, 3)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"performance": {
|
||||
"pp": pp,
|
||||
"pp_acc": result.pp_acc,
|
||||
"pp_aim": result.pp_aim,
|
||||
"pp_speed": result.pp_speed,
|
||||
"pp_flashlight": result.pp_flashlight,
|
||||
"effective_miss_count": result.effective_miss_count,
|
||||
"pp_difficulty": result.pp_difficulty,
|
||||
},
|
||||
"difficulty": {
|
||||
"stars": result.difficulty.stars,
|
||||
"aim": result.difficulty.aim,
|
||||
"speed": result.difficulty.speed,
|
||||
"flashlight": result.difficulty.flashlight,
|
||||
"slider_factor": result.difficulty.slider_factor,
|
||||
"speed_note_count": result.difficulty.speed_note_count,
|
||||
"stamina": result.difficulty.stamina,
|
||||
"color": result.difficulty.color,
|
||||
"rhythm": result.difficulty.rhythm,
|
||||
"peak": result.difficulty.peak,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
return results
|
27
app/usecases/user_achievements.py
Normal file
27
app/usecases/user_achievements.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import app.repositories.user_achievements
|
||||
from app._typing import UNSET
|
||||
from app._typing import _UnsetSentinel
|
||||
from app.repositories.user_achievements import UserAchievement
|
||||
|
||||
|
||||
async def create(user_id: int, achievement_id: int) -> UserAchievement:
|
||||
user_achievement = await app.repositories.user_achievements.create(
|
||||
user_id,
|
||||
achievement_id,
|
||||
)
|
||||
return user_achievement
|
||||
|
||||
|
||||
async def fetch_many(
|
||||
user_id: int | _UnsetSentinel = UNSET,
|
||||
page: int | None = None,
|
||||
page_size: int | None = None,
|
||||
) -> list[UserAchievement]:
|
||||
user_achievements = await app.repositories.user_achievements.fetch_many(
|
||||
user_id=user_id,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return user_achievements
|
254
app/utils.py
Normal file
254
app/utils.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import inspect
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any
|
||||
from typing import TypedDict
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
import pymysql
|
||||
|
||||
import app.settings
|
||||
from app.logging import Ansi
|
||||
from app.logging import log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.repositories.users import User
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
DATA_PATH = Path.cwd() / ".data"
|
||||
ACHIEVEMENTS_ASSETS_PATH = DATA_PATH / "assets/medals/client"
|
||||
DEFAULT_AVATAR_PATH = DATA_PATH / "avatars/default.jpg"
|
||||
|
||||
|
||||
def make_safe_name(name: str) -> str:
|
||||
"""Return a name safe for usage in sql."""
|
||||
return name.lower().replace(" ", "_")
|
||||
|
||||
|
||||
def determine_highest_ranking_clan_member(members: list[User]) -> User:
|
||||
return next(iter(sorted(members, key=lambda m: m["clan_priv"], reverse=True)))
|
||||
|
||||
|
||||
def _download_achievement_images_osu(achievements_path: Path) -> bool:
|
||||
"""Download all used achievement images (one by one, from osu!)."""
|
||||
achs: list[str] = []
|
||||
|
||||
for resolution in ("", "@2x"):
|
||||
for mode in ("osu", "taiko", "fruits", "mania"):
|
||||
# only osu!std has 9 & 10 star pass/fc medals.
|
||||
for star_rating in range(1, 1 + (10 if mode == "osu" else 8)):
|
||||
achs.append(f"{mode}-skill-pass-{star_rating}{resolution}.png")
|
||||
achs.append(f"{mode}-skill-fc-{star_rating}{resolution}.png")
|
||||
|
||||
for combo in (500, 750, 1000, 2000):
|
||||
achs.append(f"osu-combo-{combo}{resolution}.png")
|
||||
|
||||
for mod in (
|
||||
"suddendeath",
|
||||
"hidden",
|
||||
"perfect",
|
||||
"hardrock",
|
||||
"doubletime",
|
||||
"flashlight",
|
||||
"easy",
|
||||
"nofail",
|
||||
"nightcore",
|
||||
"halftime",
|
||||
"spunout",
|
||||
):
|
||||
achs.append(f"all-intro-{mod}{resolution}.png")
|
||||
|
||||
log("Downloading achievement images from osu!.", Ansi.LCYAN)
|
||||
|
||||
for ach in achs:
|
||||
resp = httpx.get(f"https://assets.ppy.sh/medals/client/{ach}")
|
||||
if resp.status_code != 200:
|
||||
return False
|
||||
|
||||
log(f"Saving achievement: {ach}", Ansi.LCYAN)
|
||||
(achievements_path / ach).write_bytes(resp.content)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def download_achievement_images(achievements_path: Path) -> None:
|
||||
"""Download all used achievement images (using the best available source)."""
|
||||
|
||||
# download individual files from the official osu! servers
|
||||
downloaded = _download_achievement_images_osu(achievements_path)
|
||||
|
||||
if downloaded:
|
||||
log("Downloaded all achievement images.", Ansi.LGREEN)
|
||||
else:
|
||||
# TODO: make the code safe in this state
|
||||
log("Failed to download achievement images.", Ansi.LRED)
|
||||
achievements_path.rmdir()
|
||||
|
||||
# allow passthrough (don't hard crash).
|
||||
# the server will *mostly* work in this state.
|
||||
pass
|
||||
|
||||
|
||||
def download_default_avatar(default_avatar_path: Path) -> None:
|
||||
"""Download an avatar to use as the server's default."""
|
||||
resp = httpx.get("https://i.cmyui.xyz/U24XBZw-4wjVME-JaEz3.png")
|
||||
|
||||
if resp.status_code != 200:
|
||||
log("Failed to fetch default avatar.", Ansi.LRED)
|
||||
return
|
||||
|
||||
log("Downloaded default avatar.", Ansi.LGREEN)
|
||||
default_avatar_path.write_bytes(resp.content)
|
||||
|
||||
|
||||
def has_internet_connectivity(timeout: float = 1.0) -> bool:
|
||||
"""Check for an active internet connection."""
|
||||
COMMON_DNS_SERVERS = (
|
||||
# Cloudflare
|
||||
"1.1.1.1",
|
||||
"1.0.0.1",
|
||||
# Google
|
||||
"8.8.8.8",
|
||||
"8.8.4.4",
|
||||
)
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client:
|
||||
client.settimeout(timeout)
|
||||
for host in COMMON_DNS_SERVERS:
|
||||
try:
|
||||
client.connect((host, 53))
|
||||
except OSError:
|
||||
continue
|
||||
else:
|
||||
return True
|
||||
|
||||
# all connections failed
|
||||
return False
|
||||
|
||||
|
||||
class FrameInfo(TypedDict):
|
||||
function: str
|
||||
filename: str
|
||||
lineno: int
|
||||
charno: int
|
||||
locals: dict[str, str]
|
||||
|
||||
|
||||
def get_appropriate_stacktrace() -> list[FrameInfo]:
|
||||
"""Return information of all frames related to cmyui_pkg and below."""
|
||||
stack = inspect.stack()[1:]
|
||||
for idx, frame in enumerate(stack):
|
||||
if frame.function == "run":
|
||||
break
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
return [
|
||||
{
|
||||
"function": frame.function,
|
||||
"filename": Path(frame.filename).name,
|
||||
"lineno": frame.lineno,
|
||||
"charno": frame.index or 0,
|
||||
"locals": {k: repr(v) for k, v in frame.frame.f_locals.items()},
|
||||
}
|
||||
# reverse for python-like stacktrace
|
||||
# ordering; puts the most recent
|
||||
# call closest to the command line
|
||||
for frame in reversed(stack[:idx])
|
||||
]
|
||||
|
||||
|
||||
def pymysql_encode(
|
||||
conv: Callable[[Any, dict[object, object] | None], str],
|
||||
) -> Callable[[type[T]], type[T]]:
|
||||
"""Decorator to allow for adding to pymysql's encoders."""
|
||||
|
||||
def wrapper(cls: type[T]) -> type[T]:
|
||||
pymysql.converters.encoders[cls] = conv
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def escape_enum(
|
||||
val: Any,
|
||||
_: dict[object, object] | None = None,
|
||||
) -> str: # used for ^
|
||||
return str(int(val))
|
||||
|
||||
|
||||
def ensure_persistent_volumes_are_available() -> None:
|
||||
# create /.data directory
|
||||
DATA_PATH.mkdir(exist_ok=True)
|
||||
|
||||
# create /.data/... subdirectories
|
||||
for sub_dir in ("avatars", "logs", "osu", "osr", "ss"):
|
||||
subdir = DATA_PATH / sub_dir
|
||||
subdir.mkdir(exist_ok=True)
|
||||
|
||||
# download achievement images from osu!
|
||||
if not ACHIEVEMENTS_ASSETS_PATH.exists():
|
||||
ACHIEVEMENTS_ASSETS_PATH.mkdir(parents=True)
|
||||
download_achievement_images(ACHIEVEMENTS_ASSETS_PATH)
|
||||
|
||||
# download a default avatar image for new users
|
||||
if not DEFAULT_AVATAR_PATH.exists():
|
||||
download_default_avatar(DEFAULT_AVATAR_PATH)
|
||||
|
||||
|
||||
def is_running_as_admin() -> bool:
|
||||
try:
|
||||
return os.geteuid() == 0 # type: ignore[attr-defined, no-any-return, unused-ignore]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return ctypes.windll.shell32.IsUserAnAdmin() == 1 # type: ignore[attr-defined, no-any-return, unused-ignore]
|
||||
except AttributeError:
|
||||
raise Exception(
|
||||
f"{sys.platform} is not currently supported on bancho.py, please create a github issue!",
|
||||
)
|
||||
|
||||
|
||||
def display_startup_dialog() -> None:
|
||||
"""Print any general information or warnings to the console."""
|
||||
if app.settings.DEVELOPER_MODE:
|
||||
log("running in advanced mode", Ansi.LYELLOW)
|
||||
if app.settings.DEBUG:
|
||||
log("running in debug mode", Ansi.LMAGENTA)
|
||||
|
||||
# running on root/admin grants the software potentally dangerous and
|
||||
# unnecessary power over the operating system and is not advised.
|
||||
if is_running_as_admin():
|
||||
log(
|
||||
"It is not recommended to run bancho.py as root/admin, especially in production."
|
||||
+ (
|
||||
" You are at increased risk as developer mode is enabled."
|
||||
if app.settings.DEVELOPER_MODE
|
||||
else ""
|
||||
),
|
||||
Ansi.LYELLOW,
|
||||
)
|
||||
|
||||
if not has_internet_connectivity():
|
||||
log("No internet connectivity detected", Ansi.LYELLOW)
|
||||
|
||||
|
||||
def has_jpeg_headers_and_trailers(data_view: memoryview) -> bool:
|
||||
return data_view[:4] == b"\xff\xd8\xff\xe0" and data_view[6:11] == b"JFIF\x00"
|
||||
|
||||
|
||||
def has_png_headers_and_trailers(data_view: memoryview) -> bool:
|
||||
return (
|
||||
data_view[:8] == b"\x89PNG\r\n\x1a\n"
|
||||
and data_view[-8:] == b"\x49END\xae\x42\x60\x82"
|
||||
)
|
99
docker-compose.test.yml
Normal file
99
docker-compose.test.yml
Normal file
@@ -0,0 +1,99 @@
|
||||
services:
|
||||
## shared services
|
||||
|
||||
mysql-test:
|
||||
image: mysql:latest
|
||||
# ports:
|
||||
# - ${DB_PORT}:${DB_PORT}
|
||||
environment:
|
||||
MYSQL_USER: ${DB_USER}
|
||||
MYSQL_PASSWORD: ${DB_PASS}
|
||||
MYSQL_DATABASE: ${DB_NAME}
|
||||
MYSQL_HOST: ${DB_HOST}
|
||||
MYSQL_PORT: ${DB_PORT}
|
||||
MYSQL_ROOT_PASSWORD: ${DB_PASS}
|
||||
volumes:
|
||||
- ./migrations/base.sql:/docker-entrypoint-initdb.d/init.sql:ro
|
||||
- test-db-data:/var/lib/mysql
|
||||
networks:
|
||||
- test-network
|
||||
healthcheck:
|
||||
test: "/usr/bin/mysql --user=$$MYSQL_USER --password=$$MYSQL_PASSWORD --execute \"SHOW DATABASES;\""
|
||||
interval: 2s
|
||||
timeout: 20s
|
||||
retries: 10
|
||||
|
||||
redis-test:
|
||||
image: bitnami/redis:latest
|
||||
# ports:
|
||||
# - ${REDIS_PORT}:${REDIS_PORT}
|
||||
user: root
|
||||
volumes:
|
||||
- test-redis-data:/bitnami/redis/data
|
||||
networks:
|
||||
- test-network
|
||||
environment:
|
||||
- ALLOW_EMPTY_PASSWORD=yes
|
||||
- REDIS_PASSWORD=${REDIS_PASS}
|
||||
|
||||
## application services
|
||||
|
||||
bancho-test:
|
||||
# we also have a public image: osuakatsuki/bancho.py:latest
|
||||
image: bancho:latest
|
||||
depends_on:
|
||||
mysql-test:
|
||||
condition: service_healthy
|
||||
redis-test:
|
||||
condition: service_started
|
||||
tty: true
|
||||
init: true
|
||||
volumes:
|
||||
- .:/srv/root
|
||||
- test-data:/srv/root/.data
|
||||
networks:
|
||||
- test-network
|
||||
environment:
|
||||
- APP_HOST=${APP_HOST}
|
||||
- APP_PORT=${APP_PORT}
|
||||
- DB_USER=${DB_USER}
|
||||
- DB_PASS=${DB_PASS}
|
||||
- DB_NAME=${DB_NAME}
|
||||
- DB_HOST=${DB_HOST}
|
||||
- DB_PORT=${DB_PORT}
|
||||
- REDIS_USER=${REDIS_USER}
|
||||
- REDIS_PASS=${REDIS_PASS}
|
||||
- REDIS_HOST=${REDIS_HOST}
|
||||
- REDIS_PORT=${REDIS_PORT}
|
||||
- REDIS_DB=${REDIS_DB}
|
||||
- OSU_API_KEY=${OSU_API_KEY}
|
||||
- MIRROR_SEARCH_ENDPOINT=${MIRROR_SEARCH_ENDPOINT}
|
||||
- MIRROR_DOWNLOAD_ENDPOINT=${MIRROR_DOWNLOAD_ENDPOINT}
|
||||
- DOMAIN=${DOMAIN}
|
||||
- COMMAND_PREFIX=${COMMAND_PREFIX}
|
||||
- SEASONAL_BGS=${SEASONAL_BGS}
|
||||
- MENU_ICON_URL=${MENU_ICON_URL}
|
||||
- MENU_ONCLICK_URL=${MENU_ONCLICK_URL}
|
||||
- DATADOG_API_KEY=${DATADOG_API_KEY}
|
||||
- DATADOG_APP_KEY=${DATADOG_APP_KEY}
|
||||
- DEBUG=${DEBUG}
|
||||
- REDIRECT_OSU_URLS=${REDIRECT_OSU_URLS}
|
||||
- PP_CACHED_ACCS=${PP_CACHED_ACCS}
|
||||
- DISALLOWED_NAMES=${DISALLOWED_NAMES}
|
||||
- DISALLOWED_PASSWORDS=${DISALLOWED_PASSWORDS}
|
||||
- DISALLOW_OLD_CLIENTS=${DISALLOW_OLD_CLIENTS}
|
||||
- DISALLOW_INGAME_REGISTRATION=${DISALLOW_INGAME_REGISTRATION}
|
||||
- DISCORD_AUDIT_LOG_WEBHOOK=${DISCORD_AUDIT_LOG_WEBHOOK}
|
||||
- AUTOMATICALLY_REPORT_PROBLEMS=${AUTOMATICALLY_REPORT_PROBLEMS}
|
||||
- LOG_WITH_COLORS=${LOG_WITH_COLORS}
|
||||
- SSL_CERT_PATH=${SSL_CERT_PATH}
|
||||
- SSL_KEY_PATH=${SSL_KEY_PATH}
|
||||
- DEVELOPER_MODE=${DEVELOPER_MODE}
|
||||
|
||||
volumes:
|
||||
test-data:
|
||||
test-db-data:
|
||||
test-redis-data:
|
||||
|
||||
networks:
|
||||
test-network:
|
92
docker-compose.yml
Normal file
92
docker-compose.yml
Normal file
@@ -0,0 +1,92 @@
|
||||
services:
|
||||
## shared services
|
||||
|
||||
mysql:
|
||||
image: mysql:latest
|
||||
# ports:
|
||||
# - ${DB_PORT}:${DB_PORT}
|
||||
environment:
|
||||
MYSQL_USER: ${DB_USER}
|
||||
MYSQL_PASSWORD: ${DB_PASS}
|
||||
MYSQL_DATABASE: ${DB_NAME}
|
||||
MYSQL_HOST: ${DB_HOST}
|
||||
MYSQL_PORT: ${DB_PORT}
|
||||
MYSQL_RANDOM_ROOT_PASSWORD: "true"
|
||||
volumes:
|
||||
- ./migrations/base.sql:/docker-entrypoint-initdb.d/init.sql:ro
|
||||
- db-data:/var/lib/mysql
|
||||
healthcheck:
|
||||
test: "/usr/bin/mysql --user=$$MYSQL_USER --password=$$MYSQL_PASSWORD --execute \"SHOW DATABASES;\""
|
||||
interval: 2s
|
||||
timeout: 20s
|
||||
retries: 10
|
||||
|
||||
redis:
|
||||
image: bitnami/redis:latest
|
||||
# ports:
|
||||
# - ${REDIS_PORT}:${REDIS_PORT}
|
||||
user: root
|
||||
volumes:
|
||||
- redis-data:/bitnami/redis/data
|
||||
environment:
|
||||
- ALLOW_EMPTY_PASSWORD=yes
|
||||
- REDIS_PASSWORD=${REDIS_PASS}
|
||||
|
||||
## application services
|
||||
|
||||
bancho:
|
||||
# we also have a public image: osuakatsuki/bancho.py:latest
|
||||
image: bancho:latest
|
||||
ports:
|
||||
- ${APP_PORT}:${APP_PORT}
|
||||
depends_on:
|
||||
mysql:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_started
|
||||
tty: true
|
||||
init: true
|
||||
volumes:
|
||||
- .:/srv/root
|
||||
- data:/srv/root/.data
|
||||
environment:
|
||||
- APP_HOST=${APP_HOST}
|
||||
- APP_PORT=${APP_PORT}
|
||||
- DB_USER=${DB_USER}
|
||||
- DB_PASS=${DB_PASS}
|
||||
- DB_NAME=${DB_NAME}
|
||||
- DB_HOST=${DB_HOST}
|
||||
- DB_PORT=${DB_PORT}
|
||||
- REDIS_USER=${REDIS_USER}
|
||||
- REDIS_PASS=${REDIS_PASS}
|
||||
- REDIS_HOST=${REDIS_HOST}
|
||||
- REDIS_PORT=${REDIS_PORT}
|
||||
- REDIS_DB=${REDIS_DB}
|
||||
- OSU_API_KEY=${OSU_API_KEY}
|
||||
- MIRROR_SEARCH_ENDPOINT=${MIRROR_SEARCH_ENDPOINT}
|
||||
- MIRROR_DOWNLOAD_ENDPOINT=${MIRROR_DOWNLOAD_ENDPOINT}
|
||||
- DOMAIN=${DOMAIN}
|
||||
- COMMAND_PREFIX=${COMMAND_PREFIX}
|
||||
- SEASONAL_BGS=${SEASONAL_BGS}
|
||||
- MENU_ICON_URL=${MENU_ICON_URL}
|
||||
- MENU_ONCLICK_URL=${MENU_ONCLICK_URL}
|
||||
- DATADOG_API_KEY=${DATADOG_API_KEY}
|
||||
- DATADOG_APP_KEY=${DATADOG_APP_KEY}
|
||||
- DEBUG=${DEBUG}
|
||||
- REDIRECT_OSU_URLS=${REDIRECT_OSU_URLS}
|
||||
- PP_CACHED_ACCS=${PP_CACHED_ACCS}
|
||||
- DISALLOWED_NAMES=${DISALLOWED_NAMES}
|
||||
- DISALLOWED_PASSWORDS=${DISALLOWED_PASSWORDS}
|
||||
- DISALLOW_OLD_CLIENTS=${DISALLOW_OLD_CLIENTS}
|
||||
- DISALLOW_INGAME_REGISTRATION=${DISALLOW_INGAME_REGISTRATION}
|
||||
- DISCORD_AUDIT_LOG_WEBHOOK=${DISCORD_AUDIT_LOG_WEBHOOK}
|
||||
- AUTOMATICALLY_REPORT_PROBLEMS=${AUTOMATICALLY_REPORT_PROBLEMS}
|
||||
- LOG_WITH_COLORS=${LOG_WITH_COLORS}
|
||||
- SSL_CERT_PATH=${SSL_CERT_PATH}
|
||||
- SSL_KEY_PATH=${SSL_KEY_PATH}
|
||||
- DEVELOPER_MODE=${DEVELOPER_MODE}
|
||||
|
||||
volumes:
|
||||
data:
|
||||
db-data:
|
||||
redis-data:
|
30
ext/Caddyfile
Normal file
30
ext/Caddyfile
Normal file
@@ -0,0 +1,30 @@
|
||||
# Comment this out if you need to explicitly
|
||||
# use self-signed certs.
|
||||
# NOTE: Not necessary if using a '.local' domain
|
||||
#
|
||||
# {
|
||||
# local_certs
|
||||
# }
|
||||
|
||||
c.{$DOMAIN}, ce.{$DOMAIN}, c4.{$DOMAIN}, osu.{$DOMAIN}, b.{$DOMAIN}, api.{$DOMAIN} {
|
||||
encode gzip
|
||||
reverse_proxy * 127.0.0.1:{$APP_PORT} {
|
||||
header_up X-Real-IP {remote_host}
|
||||
}
|
||||
|
||||
request_body {
|
||||
max_size 20MB
|
||||
}
|
||||
}
|
||||
|
||||
assets.{$DOMAIN} {
|
||||
encode gzip
|
||||
root * {$DATA_DIRECTORY}/assets
|
||||
file_server
|
||||
}
|
||||
|
||||
a.{$DOMAIN} {
|
||||
encode gzip
|
||||
root * {$DATA_DIRECTORY}/avatars
|
||||
try_files {path} {file.base}.png {file.base}.jpg {file.base}.gif {file.base}.jpeg {file.base}.jfif default.jpg =404
|
||||
}
|
54
ext/nginx.conf.example
Normal file
54
ext/nginx.conf.example
Normal file
@@ -0,0 +1,54 @@
|
||||
# c[e4]?.ppy.sh is used for bancho
|
||||
# osu.ppy.sh is used for /web, /api, etc.
|
||||
# a.ppy.sh is used for osu! avatars
|
||||
|
||||
upstream bancho {
|
||||
server 127.0.0.1:${APP_PORT};
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl;
|
||||
server_name c.${DOMAIN} ce.${DOMAIN} c4.${DOMAIN} osu.${DOMAIN} b.${DOMAIN} api.${DOMAIN};
|
||||
client_max_body_size 20M;
|
||||
|
||||
ssl_certificate ${SSL_CERT_PATH};
|
||||
ssl_certificate_key ${SSL_KEY_PATH};
|
||||
ssl_ciphers "EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH:@SECLEVEL=1";
|
||||
|
||||
location / {
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header Host $http_host;
|
||||
add_header Access-Control-Allow-Origin *;
|
||||
proxy_redirect off;
|
||||
proxy_pass http://bancho;
|
||||
}
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl;
|
||||
server_name assets.${DOMAIN};
|
||||
|
||||
ssl_certificate ${SSL_CERT_PATH};
|
||||
ssl_certificate_key ${SSL_KEY_PATH};
|
||||
ssl_ciphers "EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH:@SECLEVEL=1";
|
||||
|
||||
location / {
|
||||
default_type image/png;
|
||||
root ${DATA_DIRECTORY}/assets;
|
||||
}
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl;
|
||||
server_name a.${DOMAIN};
|
||||
|
||||
ssl_certificate ${SSL_CERT_PATH};
|
||||
ssl_certificate_key ${SSL_KEY_PATH};
|
||||
ssl_ciphers "EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH:@SECLEVEL=1";
|
||||
|
||||
location / {
|
||||
root ${DATA_DIRECTORY}/avatars;
|
||||
try_files $uri $uri.png $uri.jpg $uri.gif $uri.jpeg $uri.jfif /default.jpg = 404;
|
||||
}
|
||||
}
|
36
logging.yaml.example
Normal file
36
logging.yaml.example
Normal file
@@ -0,0 +1,36 @@
|
||||
version: 1
|
||||
disable_existing_loggers: true
|
||||
loggers:
|
||||
httpx:
|
||||
level: WARNING
|
||||
handlers: [console]
|
||||
propagate: no
|
||||
httpcore:
|
||||
level: WARNING
|
||||
handlers: [console]
|
||||
propagate: no
|
||||
multipart.multipart:
|
||||
level: ERROR
|
||||
handlers: [console]
|
||||
propagate: no
|
||||
handlers:
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
level: INFO
|
||||
formatter: plaintext
|
||||
stream: ext://sys.stdout
|
||||
# file:
|
||||
# class: logging.FileHandler
|
||||
# level: INFO
|
||||
# formatter: json
|
||||
# filename: logs.log
|
||||
formatters:
|
||||
plaintext:
|
||||
format: '[%(asctime)s] %(levelname)s %(message)s'
|
||||
datefmt: '%Y-%m-%d %H:%M:%S'
|
||||
# json:
|
||||
# class: pythonjsonlogger.jsonlogger.JsonFormatter
|
||||
# format: '%(asctime)s %(name)s %(levelname)s %(message)s'
|
||||
root:
|
||||
level: INFO
|
||||
handlers: [console] # , file]
|
31
main.py
Normal file
31
main.py
Normal file
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python3.11
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import uvicorn
|
||||
|
||||
import app.logging
|
||||
import app.settings
|
||||
import app.utils
|
||||
|
||||
app.logging.configure_logging()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
app.utils.display_startup_dialog()
|
||||
uvicorn.run(
|
||||
"app.api.init_api:asgi_app",
|
||||
reload=app.settings.DEBUG,
|
||||
log_level=logging.WARNING,
|
||||
server_header=False,
|
||||
date_header=False,
|
||||
headers=[("bancho-version", app.settings.VERSION)],
|
||||
host=app.settings.APP_HOST,
|
||||
port=app.settings.APP_PORT,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
1953
poetry.lock
generated
Normal file
1953
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
92
pyproject.toml
Normal file
92
pyproject.toml
Normal file
@@ -0,0 +1,92 @@
|
||||
[tool.mypy]
|
||||
plugins = ["pydantic.mypy"]
|
||||
strict = true
|
||||
disallow_untyped_calls = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "tests.*"
|
||||
disable_error_code = ["var-annotated", "has-type"]
|
||||
disallow_untyped_defs = false
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"aiomysql.*",
|
||||
"mitmproxy.*",
|
||||
"py3rijndael.*",
|
||||
"timeago.*",
|
||||
"pytimeparse.*",
|
||||
"cpuinfo.*",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.pydantic-mypy]
|
||||
init_forbid_extra = true
|
||||
init_typed = true
|
||||
warn_requird_dynamic_aliases = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
[tool.isort]
|
||||
add_imports = ["from __future__ import annotations"]
|
||||
force_single_line = true
|
||||
profile = "black"
|
||||
|
||||
[tool.poetry]
|
||||
name = "bancho-py"
|
||||
version = "5.2.2"
|
||||
description = "An osu! server implementation optimized for maintainability in modern python"
|
||||
authors = ["Akatsuki Team"]
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
packages = [{ include = "bancho" }]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
async-timeout = "4.0.3"
|
||||
bcrypt = "4.1.2"
|
||||
datadog = "0.48.0"
|
||||
fastapi = "0.109.2"
|
||||
orjson = "3.9.13"
|
||||
psutil = "5.9.8"
|
||||
python-dotenv = "1.0.1"
|
||||
python-multipart = "0.0.9"
|
||||
requests = "2.31.0"
|
||||
timeago = "1.0.16"
|
||||
uvicorn = "0.27.1"
|
||||
uvloop = { markers = "sys_platform != 'win32'", version = "0.19.0" }
|
||||
winloop = { platform = "win32", version = "0.1.1" }
|
||||
py3rijndael = "0.3.3"
|
||||
pytimeparse = "1.1.8"
|
||||
pydantic = "2.6.1"
|
||||
redis = { extras = ["hiredis"], version = "5.0.1" }
|
||||
sqlalchemy = ">=1.4.42,<1.5"
|
||||
akatsuki-pp-py = "1.0.5"
|
||||
cryptography = "42.0.2"
|
||||
tenacity = "8.2.3"
|
||||
httpx = "0.26.0"
|
||||
py-cpuinfo = "9.0.0"
|
||||
pytest = "8.0.0"
|
||||
pytest-asyncio = "0.23.5"
|
||||
asgi-lifespan = "2.1.0"
|
||||
respx = "0.20.2"
|
||||
tzdata = "2024.1"
|
||||
coverage = "^7.4.1"
|
||||
databases = { version = "^0.8.0", extras = ["mysql"] }
|
||||
python-json-logger = "^2.0.7"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pre-commit = "3.6.1"
|
||||
black = "24.1.1"
|
||||
isort = "5.13.2"
|
||||
autoflake = "2.2.1"
|
||||
types-psutil = "5.9.5.20240205"
|
||||
types-pymysql = "1.1.0.1"
|
||||
types-requests = "2.31.0.20240125"
|
||||
mypy = "1.8.0"
|
||||
types-pyyaml = "^6.0.12.12"
|
||||
sqlalchemy2-stubs = "^0.0.2a38"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
Reference in New Issue
Block a user