Add files via upload

This commit is contained in:
purr
2025-04-04 21:30:31 +09:00
committed by GitHub
parent 5763658177
commit 966e7691a3
90 changed files with 20938 additions and 0 deletions

1
CODEOWNERS Normal file
View File

@@ -0,0 +1 @@
* @cmyui @kingdom5500 @NiceAesth @tsunyoku @7mochi

22
Dockerfile Normal file
View File

@@ -0,0 +1,22 @@
FROM python:3.11-slim

# stream python output straight to the container logs.
ENV PYTHONUNBUFFERED=1

WORKDIR /srv/root

# build-time AND runtime system dependencies in a single layer;
# cleaning the apt lists in the same RUN keeps the layer (and image) small.
# (previously the mysql/redis client tools were installed in a second
# `apt update && apt install` layer that never removed /var/lib/apt/lists.)
RUN apt update && apt install --no-install-recommends -y \
    git curl build-essential=12.9 \
    default-mysql-client redis-tools \
    && rm -rf /var/lib/apt/lists/*

# install python dependencies first so they're cached across code changes.
COPY pyproject.toml poetry.lock ./
RUN pip install -U pip poetry
RUN poetry config virtualenvs.create false
RUN poetry install --no-root

# NOTE: done last to avoid re-run of previous steps
COPY . .

ENTRYPOINT [ "scripts/start_server.sh" ]

47
Makefile Normal file
View File

@@ -0,0 +1,47 @@
#!/usr/bin/env make

# build the bancho:latest docker image (fixing .dbdata perms first, if present).
build:
	if [ -d ".dbdata" ]; then sudo chmod -R 755 .dbdata; fi
	docker build -t bancho:latest .

# run the full stack attached to the terminal.
run:
	docker compose up bancho mysql redis

# run the full stack detached (in the background).
run-bg:
	docker compose up -d bancho mysql redis

# run the caddy reverse proxy using the project's config.
run-caddy:
	caddy run --envfile .env --config ext/Caddyfile

# number of log lines to tail by default (override: `make logs last=100`).
last?=1
logs:
	docker compose logs -f bancho mysql redis --tail ${last}

# open a shell inside the project's poetry virtualenv.
shell:
	poetry shell

# bring up the docker test environment and run the test suite in it.
test:
	docker compose -f docker-compose.test.yml up -d bancho-test mysql-test redis-test
	docker compose -f docker-compose.test.yml exec -T bancho-test /srv/root/scripts/run-tests.sh

# run all pre-commit linters over the repository.
lint:
	poetry run pre-commit run --all-files

# run mypy static type checking.
type-check:
	poetry run mypy .

# install runtime dependencies into an in-project virtualenv.
install:
	POETRY_VIRTUALENVS_IN_PROJECT=1 poetry install --no-root

# install runtime + dev dependencies and the pre-commit hook.
install-dev:
	POETRY_VIRTUALENVS_IN_PROJECT=1 poetry install --no-root --with dev
	poetry run pre-commit install

# remove the project's virtualenv.
uninstall:
	poetry env remove python

# To bump the version number run `make bump version=<major/minor/patch>`
# (DO NOT USE IF YOU DON'T KNOW WHAT YOU'RE DOING)
# https://python-poetry.org/docs/cli/#version
bump:
	poetry version $(version)

23
README_CN.md Normal file
View File

@@ -0,0 +1,23 @@
# bancho.py - 中文文档
[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/osuAkatsuki/bancho.py/master.svg)](https://results.pre-commit.ci/latest/github/osuAkatsuki/bancho.py/master)
[![Discord](https://discordapp.com/api/guilds/748687781605408908/widget.png?style=shield)](https://discord.gg/ShEQgUx)
The English version: [[English]](https://github.com/osuAkatsuki/bancho.py/blob/master/README.md)
这是中文翻译哦~由 [hedgehog-qd](https://github.com/hedgehog-qd) 在根据原英语文档部署成功后翻译的。这里
我根据我当时遇到的问题补充了一些提示,如有错误请指正,谢谢!
bancho.py 是一个还在被不断维护的osu!后端项目,不论你的水平如何,都
可以去使用他来开一个自己的osu!私服!
这个项目最初是由 [Akatsuki](https://akatsuki.pw/) 团队开发的,我们的目标是创建一个非常容易
维护并且功能很丰富的osu!私服的服务端!
注意:bancho.py 是一个后端,当你跟着下面的步骤部署完成后你可以正常登录
并游玩。这个项目自带api但是没有前端(就是网页),前端的话你也可以去看
他们团队开发的前端项目。
api文档(英语)<https://github.com/JKBGL/gulag-api-docs>
前端(guweb)<https://github.com/Varkaria/guweb>

14
README_DE.MD Normal file
View File

@@ -0,0 +1,14 @@
# bancho.py
[![Python 3.11+](https://img.shields.io/badge/python-3.11+-blue.svg)](https://www.python.org/downloads/)
[![Code Stil: schwarz](https://img.shields.io/badge/Code%20Stil-Schwarz-black)](https://github.com/ambv/black)
[![pre-commit.ci Status](https://results.pre-commit.ci/badge/github/osuAkatsuki/bancho.py/master.svg)](https://results.pre-commit.ci/latest/github/osuAkatsuki/bancho.py/master)
[![Discord](https://discordapp.com/api/guilds/748687781605408908/widget.png?style=shield)](https://discord.gg/ShEQgUx)
bancho.py ist eine in Arbeit befindliche osu!-Server-Implementierung für
Entwickler aller Erfahrungsstufen, die daran interessiert sind, ihre eigene(n)
private(n) osu!-Server-Instanz(en) zu hosten.
Das Projekt wird hauptsächlich vom [Akatsuki](https://akatsuki.pw/)-Team entwickelt,
und unser Ziel ist es, die am einfachsten zu wartende, zuverlässigste und
funktionsreichste osu!-Server-Implementierung auf dem Markt zu schaffen.

13
app/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
# type: ignore
# isort: dont-add-imports
from . import api
from . import bg_loops
from . import commands
from . import constants
from . import discord
from . import logging
from . import objects
from . import packets
from . import state
from . import utils

27
app/_typing.py Normal file
View File

@@ -0,0 +1,27 @@
from __future__ import annotations
from ipaddress import IPv4Address
from ipaddress import IPv6Address
from typing import Any
from typing import TypeVar
T = TypeVar("T")
IPAddress = IPv4Address | IPv6Address
class _UnsetSentinel:
def __repr__(self) -> str:
return "Unset"
def __copy__(self: T) -> T:
return self
def __reduce__(self) -> str:
return "Unset"
def __deepcopy__(self: T, _: Any) -> T:
return self
UNSET = _UnsetSentinel()

0
app/adapters/__init__.py Normal file
View File

164
app/adapters/database.py Normal file
View File

@@ -0,0 +1,164 @@
from __future__ import annotations
from typing import Any
from typing import cast
from databases import Database as _Database
from databases.core import Transaction
from sqlalchemy.dialects.mysql.mysqldb import MySQLDialect_mysqldb
from sqlalchemy.sql.compiler import Compiled
from sqlalchemy.sql.expression import ClauseElement
from app import settings
from app.logging import log
from app.timer import Timer
class MySQLDialect(MySQLDialect_mysqldb):
    # force ":name"-style bind parameters so compiled SQLAlchemy queries
    # match the paramstyle expected by the `databases` driver wrapper below.
    default_paramstyle = "named"


DIALECT = MySQLDialect()

# type aliases used throughout this database wrapper.
MySQLRow = dict[str, Any]
MySQLParams = dict[str, Any] | None
MySQLQuery = ClauseElement | str
class Database:
    """Thin async wrapper around `databases.Database`.

    Compiles SQLAlchemy clause elements with our MySQL dialect before
    execution, and (in DEBUG builds) logs every query with its timing.
    """

    def __init__(self, url: str) -> None:
        self._database = _Database(url)

    async def connect(self) -> None:
        await self._database.connect()

    async def disconnect(self) -> None:
        await self._database.disconnect()

    def _compile(self, clause_element: ClauseElement) -> tuple[str, MySQLParams]:
        """Render a SQLAlchemy clause element to raw SQL text + bind params."""
        compiled: Compiled = clause_element.compile(
            dialect=DIALECT,
            compile_kwargs={"render_postcompile": True},
        )
        return str(compiled), compiled.params

    def _log_query(self, query: str, params: Any, timer: Timer) -> None:
        """Log an executed query and its elapsed time (DEBUG builds only).

        Centralizes the logging stanza previously duplicated across every
        fetch/execute method.
        """
        if not settings.DEBUG:
            return

        time_elapsed = timer.elapsed()
        log(
            f"Executed SQL query: {query} {params} in {time_elapsed * 1000:.2f} msec.",
            extra={
                "query": query,
                "params": params,
                "time_elapsed": time_elapsed,
            },
        )

    async def fetch_one(
        self,
        query: MySQLQuery,
        params: MySQLParams = None,
    ) -> MySQLRow | None:
        """Fetch a single row as a dict, or None if no row matched."""
        if isinstance(query, ClauseElement):
            query, params = self._compile(query)

        with Timer() as timer:
            row = await self._database.fetch_one(query, params)

        self._log_query(query, params, timer)
        return dict(row._mapping) if row is not None else None

    async def fetch_all(
        self,
        query: MySQLQuery,
        params: MySQLParams = None,
    ) -> list[MySQLRow]:
        """Fetch all matching rows as a list of dicts."""
        if isinstance(query, ClauseElement):
            query, params = self._compile(query)

        with Timer() as timer:
            rows = await self._database.fetch_all(query, params)

        self._log_query(query, params, timer)
        return [dict(row._mapping) for row in rows]

    async def fetch_val(
        self,
        query: MySQLQuery,
        params: MySQLParams = None,
        column: Any = 0,
    ) -> Any:
        """Fetch a single value (by column index/name) from the first row."""
        if isinstance(query, ClauseElement):
            query, params = self._compile(query)

        with Timer() as timer:
            val = await self._database.fetch_val(query, params, column)

        self._log_query(query, params, timer)
        return val

    async def execute(self, query: MySQLQuery, params: MySQLParams = None) -> int:
        """Execute a statement, returning the driver's record id."""
        if isinstance(query, ClauseElement):
            query, params = self._compile(query)

        with Timer() as timer:
            rec_id = await self._database.execute(query, params)

        self._log_query(query, params, timer)
        return cast(int, rec_id)

    # NOTE: this accepts str since current execute_many uses are not using alchemy.
    # alchemy does execute_many in a single query so this method will be unneeded once raw SQL is not in use.
    async def execute_many(self, query: str, params: list[MySQLParams]) -> None:
        """Execute one statement for each params mapping in `params`."""
        if isinstance(query, ClauseElement):
            query, _ = self._compile(query)

        with Timer() as timer:
            await self._database.execute_many(query, params)

        self._log_query(query, params, timer)

    def transaction(
        self,
        *,
        force_rollback: bool = False,
        **kwargs: Any,
    ) -> Transaction:
        """Begin a database transaction (optionally always rolled back)."""
        return self._database.transaction(force_rollback=force_rollback, **kwargs)

16
app/api/__init__.py Normal file
View File

@@ -0,0 +1,16 @@
# type: ignore
# isort: dont-add-imports
from fastapi import APIRouter
from .v1 import apiv1_router
from .v2 import apiv2_router
api_router = APIRouter()
api_router.include_router(apiv1_router)
api_router.include_router(apiv2_router)
from . import domains
from . import init_api
from . import middlewares

View File

@@ -0,0 +1,5 @@
# isort: dont-add-imports
from . import cho
from . import map
from . import osu

2227
app/api/domains/cho.py Normal file

File diff suppressed because it is too large Load Diff

22
app/api/domains/map.py Normal file
View File

@@ -0,0 +1,22 @@
"""bmap: static beatmap info (thumbnails, previews, etc.)"""
from __future__ import annotations
from fastapi import APIRouter
from fastapi import status
from fastapi.requests import Request
from fastapi.responses import RedirectResponse
# import app.settings
router = APIRouter(tags=["Beatmaps"])
# forward any unmatched request to osu!
# eventually if we do bmap submission, we'll need this.
@router.get("/{file_path:path}")
async def everything(request: Request) -> RedirectResponse:
    """Permanently redirect any request on this host to osu!'s b.ppy.sh."""
    target_url = f"https://b.ppy.sh{request['path']}"
    return RedirectResponse(
        url=target_url,
        status_code=status.HTTP_301_MOVED_PERMANENTLY,
    )

1786
app/api/domains/osu.py Normal file

File diff suppressed because it is too large Load Diff

196
app/api/init_api.py Normal file
View File

@@ -0,0 +1,196 @@
# #!/usr/bin/env python3.11
from __future__ import annotations
import asyncio
import io
import pprint
import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any
import starlette.routing
from fastapi import FastAPI
from fastapi import status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.openapi.utils import get_openapi
from fastapi.requests import Request
from fastapi.responses import ORJSONResponse
from fastapi.responses import Response
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import ClientDisconnect
import app.bg_loops
import app.settings
import app.state
import app.utils
from app.api import api_router # type: ignore[attr-defined]
from app.api import domains
from app.api import middlewares
from app.logging import Ansi
from app.logging import log
from app.objects import collections
class BanchoAPI(FastAPI):
    """FastAPI subclass whose OpenAPI schema also covers routes that were
    mounted through host-based sub-applications (`app.host(...)`)."""

    def openapi(self) -> dict[str, Any]:
        """Build (and cache) the OpenAPI schema, merging in host-mounted routes."""
        if not self.openapi_schema:
            routes = self.routes
            starlette_hosts = [
                host
                for host in super().routes
                if isinstance(host, starlette.routing.Host)
            ]

            # XXX:HACK fastapi will not show documentation for routes
            # added through use sub applications using the Host class
            # (e.g. app.host('other.domain', app2))
            for host in starlette_hosts:
                for route in host.routes:
                    if route not in routes:
                        routes.append(route)

            self.openapi_schema = get_openapi(
                title=self.title,
                version=self.version,
                openapi_version=self.openapi_version,
                description=self.description,
                terms_of_service=self.terms_of_service,
                contact=self.contact,
                license_info=self.license_info,
                routes=routes,
                tags=self.openapi_tags,
                servers=self.servers,
            )

        return self.openapi_schema
@asynccontextmanager
async def lifespan(asgi_app: BanchoAPI) -> AsyncIterator[None]:
    """Application lifespan: bring up all services before serving requests,
    then shut them down gracefully afterwards (code after `yield`)."""
    # ensure unicode (e.g. player names) prints cleanly to the console.
    if isinstance(sys.stdout, io.TextIOWrapper):
        sys.stdout.reconfigure(encoding="utf-8")

    app.utils.ensure_persistent_volumes_are_available()

    app.state.loop = asyncio.get_running_loop()

    if app.utils.is_running_as_admin():
        log(
            "Running the server with root privileges is not recommended.",
            Ansi.LYELLOW,
        )

    # bring up external service connections.
    await app.state.services.database.connect()
    await app.state.services.redis.initialize()

    # optional datadog metrics reporting.
    if app.state.services.datadog is not None:
        app.state.services.datadog.start(
            flush_in_thread=True,
            flush_interval=15,
        )
        app.state.services.datadog.gauge("bancho.online_players", 0)

    app.state.services.ip_resolver = app.state.services.IPResolver()

    await app.state.services.run_sql_migrations()

    # warm in-memory caches, then start recurring background jobs.
    await collections.initialize_ram_caches()
    await app.bg_loops.initialize_housekeeping_tasks()

    log("Startup process complete.", Ansi.LGREEN)
    log(
        f"Listening @ {app.settings.APP_HOST}:{app.settings.APP_PORT}",
        Ansi.LMAGENTA,
    )

    yield

    # we want to attempt to gracefully finish any ongoing connections
    # and shut down any of the housekeeping tasks running in the background.
    await app.state.sessions.cancel_housekeeping_tasks()

    # shutdown services
    await app.state.services.http_client.aclose()
    await app.state.services.database.disconnect()
    await app.state.services.redis.aclose()

    if app.state.services.datadog is not None:
        app.state.services.datadog.stop()
        app.state.services.datadog.flush()
def init_exception_handlers(asgi_app: BanchoAPI) -> None:
    """Register custom exception handlers on the application."""

    @asgi_app.exception_handler(RequestValidationError)
    async def handle_validation_error(
        request: Request,
        exc: RequestValidationError,
    ) -> Response:
        """Wrapper around 422 validation errors to print out info for devs."""
        log(f"Validation error on {request.url}", Ansi.LRED)
        pprint.pprint(exc.errors())

        return ORJSONResponse(
            content={"detail": jsonable_encoder(exc.errors())},
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        )
def init_middlewares(asgi_app: BanchoAPI) -> None:
    """Initialize our app's middleware stack."""
    asgi_app.add_middleware(middlewares.MetricsMiddleware)

    @asgi_app.middleware("http")
    async def http_middleware(
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """Swallow client-disconnect errors that would otherwise spam the log."""
        # if an osu! client is waiting on leaderboard data
        # and switches to another leaderboard, it will cancel
        # the previous request midway, resulting in a large
        # error in the console. this is to catch that :)
        try:
            return await call_next(request)
        except ClientDisconnect:
            # client disconnected from the server
            # while we were reading the body.
            return Response("Client disconnected while reading request.")
        except RuntimeError as exc:
            if exc.args[0] == "No response returned.":
                # client disconnected from the server
                # while we were sending the response.
                return Response("Client returned empty response.")

            # unrelated issue, raise normally
            raise exc
def init_routes(asgi_app: BanchoAPI) -> None:
    """Initialize our app's route endpoints."""
    for domain in ("ppy.sh", app.settings.DOMAIN):
        # map each hostname to the router it should serve.
        host_routers = {
            f"{subdomain}.{domain}": domains.cho.router
            for subdomain in ("c", "ce", "c4", "c5", "c6")
        }
        host_routers[f"osu.{domain}"] = domains.osu.router
        host_routers[f"b.{domain}"] = domains.map.router

        # bancho.py's developer-facing api
        host_routers[f"api.{domain}"] = api_router

        for hostname, router in host_routers.items():
            asgi_app.host(hostname, router)
def init_api() -> BanchoAPI:
    """Create & initialize our app."""
    api = BanchoAPI(lifespan=lifespan)

    # order matters: middleware first, then handlers, then routes.
    for initialize in (init_middlewares, init_exception_handlers, init_routes):
        initialize(api)

    return api


asgi_app = init_api()

37
app/api/middlewares.py Normal file
View File

@@ -0,0 +1,37 @@
from __future__ import annotations
import time
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from app.logging import Ansi
from app.logging import log
from app.logging import magnitude_fmt_time
class MetricsMiddleware(BaseHTTPMiddleware):
    """Times each request, logs it with a status-colored line, and attaches
    a `process-time` response header (milliseconds)."""

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        started_at = time.perf_counter_ns()
        response = await call_next(request)
        time_elapsed = time.perf_counter_ns() - started_at

        # red for 4xx/5xx, green otherwise.
        colour = Ansi.LRED if response.status_code >= 400 else Ansi.LGREEN

        url = f"{request.headers['host']}{request['path']}"
        log(
            f"[{request.method}] {response.status_code} {url}{Ansi.RESET!r} | {Ansi.LBLUE!r}Request took: {magnitude_fmt_time(time_elapsed)}",
            colour,
        )

        # nanoseconds -> milliseconds for the response header.
        response.headers["process-time"] = str(round(time_elapsed) / 1e6)
        return response

10
app/api/v1/__init__.py Normal file
View File

@@ -0,0 +1,10 @@
# type: ignore
# isort: dont-add-imports
from fastapi import APIRouter
from .api import router
apiv1_router = APIRouter(tags=["API v1"], prefix="/v1")
apiv1_router.include_router(router)

1080
app/api/v1/api.py Normal file

File diff suppressed because it is too large Load Diff

15
app/api/v2/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
# isort: dont-add-imports
from fastapi import APIRouter
from . import clans
from . import maps
from . import players
from . import scores
apiv2_router = APIRouter(tags=["API v2"], prefix="/v2")
apiv2_router.include_router(clans.router)
apiv2_router.include_router(maps.router)
apiv2_router.include_router(players.router)
apiv2_router.include_router(scores.router)

50
app/api/v2/clans.py Normal file
View File

@@ -0,0 +1,50 @@
"""bancho.py's v2 apis for interacting with clans"""
from __future__ import annotations
from fastapi import APIRouter
from fastapi import status
from fastapi.param_functions import Query
from app.api.v2.common import responses
from app.api.v2.common.responses import Failure
from app.api.v2.common.responses import Success
from app.api.v2.models.clans import Clan
from app.repositories import clans as clans_repo
router = APIRouter()
@router.get("/clans")
async def get_clans(
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
) -> Success[list[Clan]] | Failure:
    """Paginated listing of all clans."""
    clan_rows = await clans_repo.fetch_many(page=page, page_size=page_size)
    clan_models = [Clan.from_mapping(row) for row in clan_rows]

    return responses.success(
        content=clan_models,
        meta={
            "total": await clans_repo.fetch_count(),
            "page": page,
            "page_size": page_size,
        },
    )
@router.get("/clans/{clan_id}")
async def get_clan(clan_id: int) -> Success[Clan] | Failure:
    """Fetch a single clan by id."""
    clan_row = await clans_repo.fetch_one(id=clan_id)
    if clan_row is not None:
        return responses.success(Clan.from_mapping(clan_row))

    return responses.failure(
        message="Clan not found.",
        status_code=status.HTTP_404_NOT_FOUND,
    )

29
app/api/v2/common/json.py Normal file
View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from typing import Any
import orjson
from fastapi.responses import JSONResponse
from pydantic import BaseModel
def _default_processor(data: Any) -> Any:
    """Recursively convert pydantic models (and containers holding them)
    into plain JSON-serializable python structures for orjson."""
    if isinstance(data, BaseModel):
        # pydantic v2: model_dump() supersedes the deprecated .dict();
        # the rest of the codebase already uses the v2 API
        # (ConfigDict, model_fields).
        return _default_processor(data.model_dump())
    elif isinstance(data, dict):
        return {k: _default_processor(v) for k, v in data.items()}
    elif isinstance(data, list):
        return [_default_processor(v) for v in data]
    else:
        return data
def dumps(data: Any) -> bytes:
    """Serialize `data` to JSON bytes with orjson, delegating
    non-native types (e.g. pydantic models) to `_default_processor`."""
    return orjson.dumps(data, default=_default_processor)
class ORJSONResponse(JSONResponse):
    """A JSONResponse whose body is rendered by our orjson-based `dumps`."""

    media_type = "application/json"

    def render(self, content: Any) -> bytes:
        return dumps(content)

View File

@@ -0,0 +1,47 @@
from __future__ import annotations
from typing import Any
from typing import Generic
from typing import Literal
from typing import TypeVar
from typing import cast
from pydantic import BaseModel
from app.api.v2.common import json
T = TypeVar("T")
class Success(BaseModel, Generic[T]):
status: Literal["success"]
data: T
meta: dict[str, Any]
def success(
    content: T,
    status_code: int = 200,
    headers: dict[str, Any] | None = None,
    meta: dict[str, Any] | None = None,
) -> Success[T]:
    """Wrap `content` in the standard success envelope."""
    body = {
        "status": "success",
        "data": content,
        "meta": meta if meta is not None else {},
    }
    # XXX:HACK to make typing work
    return cast(Success[T], json.ORJSONResponse(body, status_code, headers))
class Failure(BaseModel):
status: Literal["error"]
error: str
def failure(
    message: str,
    status_code: int = 400,
    headers: dict[str, Any] | None = None,
) -> Failure:
    """Wrap `message` in the standard error envelope."""
    body = {"status": "error", "error": message}
    # XXX:HACK to make typing work
    return cast(Failure, json.ORJSONResponse(body, status_code, headers))

76
app/api/v2/maps.py Normal file
View File

@@ -0,0 +1,76 @@
"""bancho.py's v2 apis for interacting with maps"""
from __future__ import annotations
from fastapi import APIRouter
from fastapi import status
from fastapi.param_functions import Query
from app.api.v2.common import responses
from app.api.v2.common.responses import Failure
from app.api.v2.common.responses import Success
from app.api.v2.models.maps import Map
from app.repositories import maps as maps_repo
router = APIRouter()
@router.get("/maps")
async def get_maps(
    set_id: int | None = None,
    server: str | None = None,
    # NOTE(review): this query param shadows the imported fastapi `status`
    # module inside this function body; harmless here since the module
    # isn't referenced, but be careful when editing.
    status: int | None = None,
    artist: str | None = None,
    creator: str | None = None,
    filename: str | None = None,
    mode: int | None = None,
    frozen: bool | None = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
) -> Success[list[Map]] | Failure:
    """Paginated, filterable listing of beatmaps."""
    maps = await maps_repo.fetch_many(
        server=server,
        set_id=set_id,
        status=status,
        artist=artist,
        creator=creator,
        filename=filename,
        mode=mode,
        frozen=frozen,
        page=page,
        page_size=page_size,
    )
    # total count uses the same filters (minus pagination) for the meta block.
    total_maps = await maps_repo.fetch_count(
        server=server,
        set_id=set_id,
        status=status,
        artist=artist,
        creator=creator,
        filename=filename,
        mode=mode,
        frozen=frozen,
    )

    response = [Map.from_mapping(rec) for rec in maps]

    return responses.success(
        content=response,
        meta={
            "total": total_maps,
            "page": page,
            "page_size": page_size,
        },
    )
@router.get("/maps/{map_id}")
async def get_map(map_id: int) -> Success[Map] | Failure:
    """Fetch a single beatmap by id."""
    map_row = await maps_repo.fetch_one(id=map_id)
    if map_row is not None:
        return responses.success(Map.from_mapping(map_row))

    return responses.failure(
        message="Map not found.",
        status_code=status.HTTP_404_NOT_FOUND,
    )

View File

@@ -0,0 +1,18 @@
# isort: dont-add-imports
from collections.abc import Mapping
from typing import Any
from typing import TypeVar
from pydantic import BaseModel as _pydantic_BaseModel
from pydantic import ConfigDict
T = TypeVar("T", bound="BaseModel")
class BaseModel(_pydantic_BaseModel):
    """Project-wide pydantic base model for the v2 api."""

    # strip leading/trailing whitespace from all string fields on validation.
    model_config = ConfigDict(str_strip_whitespace=True)

    @classmethod
    def from_mapping(cls: type[T], mapping: Mapping[str, Any]) -> T:
        """Build the model from a mapping (e.g. a DB row), keeping only
        keys that are declared model fields.

        NOTE(review): raises KeyError if the mapping is missing any
        declared field — callers must supply complete rows.
        """
        return cls(**{k: mapping[k] for k in cls.model_fields})

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from datetime import datetime
from . import BaseModel
# input models
# output models
class Clan(BaseModel):
    """API v2 output model for a clan."""

    id: int
    name: str
    tag: str
    owner: int  # presumably the owner's user id — confirm against schema
    created_at: datetime

36
app/api/v2/models/maps.py Normal file
View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from datetime import datetime
from . import BaseModel
# input models
# output models
class Map(BaseModel):
    """API v2 output model for a beatmap."""

    id: int
    server: str
    set_id: int
    status: int
    md5: str
    artist: str
    title: str
    version: str
    creator: str
    filename: str
    last_update: datetime
    total_length: int
    max_combo: int
    frozen: bool
    plays: int
    passes: int
    mode: int
    # difficulty attributes; presumably osu!'s cs/ar/od/hp + star rating — confirm
    bpm: float
    cs: float
    ar: float
    od: float
    hp: float
    diff: float

View File

@@ -0,0 +1,60 @@
from __future__ import annotations
from . import BaseModel
# input models
# output models
class Player(BaseModel):
id: int
name: str
safe_name: str
priv: int
country: str
silence_end: int
donor_end: int
creation_time: int
latest_activity: int
clan_id: int
clan_priv: int
preferred_mode: int
play_style: int
custom_badge_name: str | None
custom_badge_icon: str | None
userpage_content: str | None
class PlayerStatus(BaseModel):
login_time: int
action: int
info_text: str
mode: int
mods: int
beatmap_id: int
class PlayerStats(BaseModel):
id: int
mode: int
tscore: int
rscore: int
pp: float
plays: int
playtime: int
acc: float
max_combo: int
total_hits: int
replay_views: int
xh_count: int
x_count: int
sh_count: int
s_count: int
a_count: int

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from datetime import datetime
from . import BaseModel
# input models
# output models
class Score(BaseModel):
    """API v2 output model for a submitted score."""

    id: int
    map_md5: str  # md5 of the beatmap the score was set on
    userid: int
    score: int
    pp: float
    acc: float
    max_combo: int
    mods: int
    # hit judgement counts
    n300: int
    n100: int
    n50: int
    nmiss: int
    nkatu: int
    grade: str
    status: int
    mode: int
    play_time: datetime
    time_elapsed: int
    perfect: bool

137
app/api/v2/players.py Normal file
View File

@@ -0,0 +1,137 @@
"""bancho.py's v2 apis for interacting with players"""
from __future__ import annotations
from fastapi import APIRouter
from fastapi import status
from fastapi.param_functions import Query
import app.state.sessions
from app.api.v2.common import responses
from app.api.v2.common.responses import Failure
from app.api.v2.common.responses import Success
from app.api.v2.models.players import Player
from app.api.v2.models.players import PlayerStats
from app.api.v2.models.players import PlayerStatus
from app.repositories import stats as stats_repo
from app.repositories import users as users_repo
router = APIRouter()
@router.get("/players")
async def get_players(
priv: int | None = None,
country: str | None = None,
clan_id: int | None = None,
clan_priv: int | None = None,
preferred_mode: int | None = None,
play_style: int | None = None,
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=100),
) -> Success[list[Player]] | Failure:
players = await users_repo.fetch_many(
priv=priv,
country=country,
clan_id=clan_id,
clan_priv=clan_priv,
preferred_mode=preferred_mode,
play_style=play_style,
page=page,
page_size=page_size,
)
total_players = await users_repo.fetch_count(
priv=priv,
country=country,
clan_id=clan_id,
clan_priv=clan_priv,
preferred_mode=preferred_mode,
play_style=play_style,
)
response = [Player.from_mapping(rec) for rec in players]
return responses.success(
content=response,
meta={
"total": total_players,
"page": page,
"page_size": page_size,
},
)
@router.get("/players/{player_id}")
async def get_player(player_id: int) -> Success[Player] | Failure:
data = await users_repo.fetch_one(id=player_id)
if data is None:
return responses.failure(
message="Player not found.",
status_code=status.HTTP_404_NOT_FOUND,
)
response = Player.from_mapping(data)
return responses.success(response)
@router.get("/players/{player_id}/status")
async def get_player_status(player_id: int) -> Success[PlayerStatus] | Failure:
player = app.state.sessions.players.get(id=player_id)
if not player:
return responses.failure(
message="Player status not found.",
status_code=status.HTTP_404_NOT_FOUND,
)
response = PlayerStatus(
login_time=int(player.login_time),
action=int(player.status.action),
info_text=player.status.info_text,
mode=int(player.status.mode),
mods=int(player.status.mods),
beatmap_id=player.status.map_id,
)
return responses.success(response)
@router.get("/players/{player_id}/stats/{mode}")
async def get_player_mode_stats(
player_id: int,
mode: int,
) -> Success[PlayerStats] | Failure:
data = await stats_repo.fetch_one(player_id, mode)
if data is None:
return responses.failure(
message="Player stats not found.",
status_code=status.HTTP_404_NOT_FOUND,
)
response = PlayerStats.from_mapping(data)
return responses.success(response)
@router.get("/players/{player_id}/stats")
async def get_player_stats(
player_id: int,
page: int = Query(1, ge=1),
page_size: int = Query(50, ge=1, le=100),
) -> Success[list[PlayerStats]] | Failure:
data = await stats_repo.fetch_many(
player_id=player_id,
page=page,
page_size=page_size,
)
total_stats = await stats_repo.fetch_count(
player_id=player_id,
)
response = [PlayerStats.from_mapping(rec) for rec in data]
return responses.success(
response,
meta={
"total": total_stats,
"page": page,
"page_size": page_size,
},
)

67
app/api/v2/scores.py Normal file
View File

@@ -0,0 +1,67 @@
"""bancho.py's v2 apis for interacting with scores"""
from __future__ import annotations
from fastapi import APIRouter
from fastapi import status
from fastapi.param_functions import Query
from app.api.v2.common import responses
from app.api.v2.common.responses import Failure
from app.api.v2.common.responses import Success
from app.api.v2.models.scores import Score
from app.repositories import scores as scores_repo
router = APIRouter()
@router.get("/scores")
async def get_all_scores(
    map_md5: str | None = None,
    mods: int | None = None,
    # NOTE(review): this query param shadows the imported fastapi `status`
    # module inside this function body; harmless here since the module
    # isn't referenced, but be careful when editing.
    status: int | None = None,
    mode: int | None = None,
    user_id: int | None = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1, le=100),
) -> Success[list[Score]] | Failure:
    """Paginated, filterable listing of scores."""
    scores = await scores_repo.fetch_many(
        map_md5=map_md5,
        mods=mods,
        status=status,
        mode=mode,
        user_id=user_id,
        page=page,
        page_size=page_size,
    )
    # total count uses the same filters (minus pagination) for the meta block.
    total_scores = await scores_repo.fetch_count(
        map_md5=map_md5,
        mods=mods,
        status=status,
        mode=mode,
        user_id=user_id,
    )

    response = [Score.from_mapping(rec) for rec in scores]

    return responses.success(
        content=response,
        meta={
            "total": total_scores,
            "page": page,
            "page_size": page_size,
        },
    )
@router.get("/scores/{score_id}")
async def get_score(score_id: int) -> Success[Score] | Failure:
    """Fetch a single score by id."""
    score_row = await scores_repo.fetch_one(id=score_id)
    if score_row is not None:
        return responses.success(Score.from_mapping(score_row))

    return responses.failure(
        message="Score not found.",
        status_code=status.HTTP_404_NOT_FOUND,
    )

89
app/bg_loops.py Normal file
View File

@@ -0,0 +1,89 @@
from __future__ import annotations
import asyncio
import time
import app.packets
import app.settings
import app.state
from app.constants.privileges import Privileges
from app.logging import Ansi
from app.logging import log
OSU_CLIENT_MIN_PING_INTERVAL = 300000 // 1000 # defined by osu!
async def initialize_housekeeping_tasks() -> None:
    """Create tasks for each housekeeping tasks."""
    log("Initializing housekeeping tasks.", Ansi.LCYAN)

    loop = asyncio.get_running_loop()

    # keep task handles in a set so they can be cancelled on shutdown.
    app.state.sessions.housekeeping_tasks.update(
        {
            loop.create_task(task)
            for task in (
                _remove_expired_donation_privileges(interval=30 * 60),
                _update_bot_status(interval=5 * 60),
                _disconnect_ghosts(interval=OSU_CLIENT_MIN_PING_INTERVAL // 3),
            )
        },
    )
async def _remove_expired_donation_privileges(interval: int) -> None:
    """Remove donation privileges from users with expired sessions.

    Runs once immediately on startup, then every `interval` seconds.
    """
    while True:
        if app.settings.DEBUG:
            log("Removing expired donation privileges.", Ansi.LMAGENTA)

        expired_donors = await app.state.services.database.fetch_all(
            "SELECT id FROM users "
            "WHERE donor_end <= UNIX_TIMESTAMP() "
            "AND priv & :donor_priv",
            {"donor_priv": Privileges.DONATOR.value},
        )

        for expired_donor in expired_donors:
            player = await app.state.sessions.players.from_cache_or_sql(
                id=expired_donor["id"],
            )

            # NOTE(review): a failed lookup here crashes this whole
            # housekeeping task via AssertionError — confirm users
            # returned by the query are always resolvable.
            assert player is not None

            # TODO: perhaps make a `revoke_donor` method?
            await player.remove_privs(Privileges.DONATOR)
            player.donor_end = 0
            await app.state.services.database.execute(
                "UPDATE users SET donor_end = 0 WHERE id = :id",
                {"id": player.id},
            )

            if player.is_online:
                player.enqueue(
                    app.packets.notification("Your supporter status has expired."),
                )

            log(f"{player}'s supporter status has expired.", Ansi.LMAGENTA)

        await asyncio.sleep(interval)
async def _disconnect_ghosts(interval: int) -> None:
    """Actively disconnect users above the
    disconnection time threshold on the osu! server."""
    while True:
        await asyncio.sleep(interval)
        current_time = time.time()

        for player in app.state.sessions.players:
            # log out anyone who hasn't sent data within the client's
            # minimum ping interval (this loop runs at 1/3 of it).
            if current_time - player.last_recv_time > OSU_CLIENT_MIN_PING_INTERVAL:
                log(f"Auto-dced {player}.", Ansi.LMAGENTA)
                player.logout()
async def _update_bot_status(interval: int) -> None:
    """Re roll the bot status, every `interval`."""
    while True:
        await asyncio.sleep(interval)
        # clearing the cache forces the next caller to rebuild
        # the bot stats packet with a freshly-rolled status.
        app.packets.bot_stats.cache_clear()

2533
app/commands.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,8 @@
# type: ignore
# isort: dont-add-imports
from . import clientflags
from . import gamemodes
from . import mods
from . import privileges
from . import regexes

View File

@@ -0,0 +1,68 @@
from __future__ import annotations
from enum import IntFlag
from enum import unique
from app.utils import escape_enum
from app.utils import pymysql_encode
@unique
@pymysql_encode(escape_enum)
class ClientFlags(IntFlag):
    """osu! anticheat <= 2016 (unsure of age)"""

    # NOTE: many of these flags are quite outdated and/or
    # broken and are even known to false positive quite often.
    # they can be helpful; just take them with a grain of salt.
    CLEAN = 0  # no flags sent

    # NOTE: bit 1 << 0 appears unused.

    # flags for timing errors or desync.
    SPEED_HACK_DETECTED = 1 << 1

    # this is to be ignored by server implementations. osu! team trolling hard
    INCORRECT_MOD_VALUE = 1 << 2

    MULTIPLE_OSU_CLIENTS = 1 << 3
    CHECKSUM_FAILURE = 1 << 4
    FLASHLIGHT_CHECKSUM_INCORRECT = 1 << 5

    # these are only used on the osu!bancho official server.
    OSU_EXECUTABLE_CHECKSUM = 1 << 6
    MISSING_PROCESSES_IN_LIST = 1 << 7  # also deprecated as of 2018

    # flags for either:
    # 1. pixels that should be outside the visible radius
    # (and thus black) being brighter than they should be.
    # 2. from an internal alpha value being incorrect.
    FLASHLIGHT_IMAGE_HACK = 1 << 8

    SPINNER_HACK = 1 << 9
    TRANSPARENT_WINDOW = 1 << 10

    # (mania) flags for consistently low press intervals.
    FAST_PRESS = 1 << 11

    # from my experience, pretty decent
    # for detecting autobotted scores.
    RAW_MOUSE_DISCREPANCY = 1 << 12
    RAW_KEYBOARD_DISCREPANCY = 1 << 13
@unique
@pymysql_encode(escape_enum)
class LastFMFlags(IntFlag):
    """osu! anticheat 2019"""

    # XXX: the aqn flags were fixed within hours of the osu!
    # update, and vanilla hq is not so widely used anymore.

    # NOTE: bit values continue where ClientFlags left off (1 << 14 onward).
    RUN_WITH_LD_FLAG = 1 << 14
    CONSOLE_OPEN = 1 << 15
    EXTRA_THREADS = 1 << 16
    HQ_ASSEMBLY = 1 << 17
    HQ_FILE = 1 << 18
    REGISTRY_EDITS = 1 << 19
    SDL2_LIBRARY = 1 << 20
    OPENSSL_LIBRARY = 1 << 21
    AQN_MENU_SAMPLE = 1 << 22

View File

@@ -0,0 +1,75 @@
from __future__ import annotations
import functools
from enum import IntEnum
from enum import unique
from app.constants.mods import Mods
from app.utils import escape_enum
from app.utils import pymysql_encode
# indexed by GameMode value; textual form used by GameMode.__repr__.
GAMEMODE_REPR_LIST = (
    "vn!std",
    "vn!taiko",
    "vn!catch",
    "vn!mania",
    "rx!std",
    "rx!taiko",
    "rx!catch",
    "rx!mania",  # unused
    "ap!std",
    "ap!taiko",  # unused
    "ap!catch",  # unused
    "ap!mania",  # unused
)
@unique
@pymysql_encode(escape_enum)
class GameMode(IntEnum):
    """A server-side gamemode: {vanilla,relax,autopilot} x {std,taiko,catch,mania}."""

    VANILLA_OSU = 0
    VANILLA_TAIKO = 1
    VANILLA_CATCH = 2
    VANILLA_MANIA = 3

    RELAX_OSU = 4
    RELAX_TAIKO = 5
    RELAX_CATCH = 6
    RELAX_MANIA = 7  # unused

    AUTOPILOT_OSU = 8
    AUTOPILOT_TAIKO = 9  # unused
    AUTOPILOT_CATCH = 10  # unused
    AUTOPILOT_MANIA = 11  # unused

    @classmethod
    def from_params(cls, mode_vn: int, mods: Mods) -> GameMode:
        """Resolve a server gamemode from a vanilla mode (0-3) and mods."""
        if mods & Mods.AUTOPILOT:
            offset = 8
        elif mods & Mods.RELAX:
            offset = 4
        else:
            offset = 0
        return cls(mode_vn + offset)

    @classmethod
    @functools.cache
    def valid_gamemodes(cls) -> list[GameMode]:
        """All gamemodes actually in use on the server."""
        unused = (
            cls.RELAX_MANIA,
            cls.AUTOPILOT_TAIKO,
            cls.AUTOPILOT_CATCH,
            cls.AUTOPILOT_MANIA,
        )
        return [mode for mode in cls if mode not in unused]

    @property
    def as_vanilla(self) -> int:
        """The vanilla (0-3) mode corresponding to this gamemode."""
        return self.value % 4

    def __repr__(self) -> str:
        return GAMEMODE_REPR_LIST[self.value]

296
app/constants/mods.py Normal file
View File

@@ -0,0 +1,296 @@
from __future__ import annotations
import functools
from enum import IntFlag
from enum import unique
from app.utils import escape_enum
from app.utils import pymysql_encode
@unique
@pymysql_encode(escape_enum)
class Mods(IntFlag):
    """osu!'s gameplay modifiers, represented as a bitwise flag set."""

    NOMOD = 0
    NOFAIL = 1 << 0
    EASY = 1 << 1
    TOUCHSCREEN = 1 << 2  # old: 'NOVIDEO'
    HIDDEN = 1 << 3
    HARDROCK = 1 << 4
    SUDDENDEATH = 1 << 5
    DOUBLETIME = 1 << 6
    RELAX = 1 << 7
    HALFTIME = 1 << 8
    NIGHTCORE = 1 << 9
    FLASHLIGHT = 1 << 10
    AUTOPLAY = 1 << 11
    SPUNOUT = 1 << 12
    AUTOPILOT = 1 << 13
    PERFECT = 1 << 14
    KEY4 = 1 << 15
    KEY5 = 1 << 16
    KEY6 = 1 << 17
    KEY7 = 1 << 18
    KEY8 = 1 << 19
    FADEIN = 1 << 20
    RANDOM = 1 << 21
    CINEMA = 1 << 22
    TARGET = 1 << 23
    KEY9 = 1 << 24
    KEYCOOP = 1 << 25
    KEY1 = 1 << 26
    KEY3 = 1 << 27
    KEY2 = 1 << 28
    SCOREV2 = 1 << 29
    MIRROR = 1 << 30

    @functools.cache
    def __repr__(self) -> str:
        """Short acronym string, e.g. 'HDDT'; 'NM' when no mods are set."""
        if self.value == Mods.NOMOD:
            return "NM"

        mod_str = []
        _dict = mod2modstr_dict  # global
        for mod in Mods:
            if self.value & mod:
                mod_str.append(_dict[mod])
        return "".join(mod_str)

    def filter_invalid_combos(self, mode_vn: int) -> Mods:
        """Remove any invalid mod combinations.

        `mode_vn` is the vanilla gamemode (0=std, 1=taiko, 2=catch, 3=mania).
        NOTE: the order of the steps below matters; each step operates
        on the result of the previous one.
        """
        # 1. mode-inspecific mod conflictions
        _dtnc = self & (Mods.DOUBLETIME | Mods.NIGHTCORE)
        if _dtnc == (Mods.DOUBLETIME | Mods.NIGHTCORE):
            self &= ~Mods.DOUBLETIME  # DTNC
        elif _dtnc and self & Mods.HALFTIME:
            self &= ~Mods.HALFTIME  # (DT|NC)HT
        if self & Mods.EASY and self & Mods.HARDROCK:
            self &= ~Mods.HARDROCK  # EZHR
        if self & (Mods.NOFAIL | Mods.RELAX | Mods.AUTOPILOT):
            if self & Mods.SUDDENDEATH:
                self &= ~Mods.SUDDENDEATH  # (NF|RX|AP)SD
            if self & Mods.PERFECT:
                self &= ~Mods.PERFECT  # (NF|RX|AP)PF
        if self & (Mods.RELAX | Mods.AUTOPILOT):
            if self & Mods.NOFAIL:
                self &= ~Mods.NOFAIL  # (RX|AP)NF
        if self & Mods.PERFECT and self & Mods.SUDDENDEATH:
            self &= ~Mods.SUDDENDEATH  # PFSD
        # 2. remove mode-unique mods from incorrect gamemodes
        if mode_vn != 0:  # osu! specific
            self &= ~OSU_SPECIFIC_MODS
        # ctb & taiko have no unique mods
        if mode_vn != 3:  # mania specific
            self &= ~MANIA_SPECIFIC_MODS
        # 3. mode-specific mod conflictions
        if mode_vn == 0:
            if self & Mods.AUTOPILOT:
                if self & (Mods.SPUNOUT | Mods.RELAX):
                    self &= ~Mods.AUTOPILOT  # (SO|RX)AP
        if mode_vn == 3:
            self &= ~Mods.RELAX  # rx is std/taiko/ctb common
            if self & Mods.HIDDEN and self & Mods.FADEIN:
                self &= ~Mods.FADEIN  # HDFI
        # 4 remove multiple keymods
        keymods_used = self & KEY_MODS
        if bin(keymods_used).count("1") > 1:
            # keep only the first
            first_keymod = None
            for mod in KEY_MODS:
                if keymods_used & mod:
                    first_keymod = mod
                    break
            assert first_keymod is not None
            # remove all but the first keymod.
            self &= ~(keymods_used & ~first_keymod)
        return self

    @classmethod
    @functools.lru_cache(maxsize=64)
    def from_modstr(cls, s: str) -> Mods:
        """Parse mods from an acronym string, e.g. 'HDDTRX'.

        Unknown 2-character chunks are silently ignored.
        """
        # from fmt: `HDDTRX`
        mods = cls.NOMOD
        _dict = modstr2mod_dict  # global
        # split into 2 character chunks
        mod_strs = [s[idx : idx + 2].upper() for idx in range(0, len(s), 2)]
        # find matching mods
        for mod in mod_strs:
            if mod not in _dict:
                continue
            mods |= _dict[mod]
        return mods

    @classmethod
    @functools.lru_cache(maxsize=64)
    def from_np(cls, s: str, mode_vn: int) -> Mods:
        """Parse mods from an osu! /np message, e.g. '+Hidden +DoubleTime'."""
        mods = cls.NOMOD
        _dict = npstr2mod_dict  # global
        for mod in s.split(" "):
            if mod not in _dict:
                continue
            mods |= _dict[mod]
        # NOTE: for fetching from /np, we automatically
        # call cls.filter_invalid_combos as we assume
        # the input string is from user input.
        return mods.filter_invalid_combos(mode_vn)
# 2-letter acronym -> Mods flag; used by `Mods.from_modstr`.
modstr2mod_dict = {
    "NF": Mods.NOFAIL,
    "EZ": Mods.EASY,
    "TD": Mods.TOUCHSCREEN,
    "HD": Mods.HIDDEN,
    "HR": Mods.HARDROCK,
    "SD": Mods.SUDDENDEATH,
    "DT": Mods.DOUBLETIME,
    "RX": Mods.RELAX,
    "HT": Mods.HALFTIME,
    "NC": Mods.NIGHTCORE,
    "FL": Mods.FLASHLIGHT,
    "AU": Mods.AUTOPLAY,
    "SO": Mods.SPUNOUT,
    "AP": Mods.AUTOPILOT,
    "PF": Mods.PERFECT,
    "FI": Mods.FADEIN,
    "RN": Mods.RANDOM,
    "CN": Mods.CINEMA,
    "TP": Mods.TARGET,
    "V2": Mods.SCOREV2,
    "MR": Mods.MIRROR,
    "1K": Mods.KEY1,
    "2K": Mods.KEY2,
    "3K": Mods.KEY3,
    "4K": Mods.KEY4,
    "5K": Mods.KEY5,
    "6K": Mods.KEY6,
    "7K": Mods.KEY7,
    "8K": Mods.KEY8,
    "9K": Mods.KEY9,
    "CO": Mods.KEYCOOP,
}

# osu! /np message fragment -> Mods flag; used by `Mods.from_np`.
npstr2mod_dict = {
    "-NoFail": Mods.NOFAIL,
    "-Easy": Mods.EASY,
    "+Hidden": Mods.HIDDEN,
    "+HardRock": Mods.HARDROCK,
    "+SuddenDeath": Mods.SUDDENDEATH,
    "+DoubleTime": Mods.DOUBLETIME,
    "~Relax~": Mods.RELAX,
    "-HalfTime": Mods.HALFTIME,
    "+Nightcore": Mods.NIGHTCORE,
    "+Flashlight": Mods.FLASHLIGHT,
    "|Autoplay|": Mods.AUTOPLAY,
    "-SpunOut": Mods.SPUNOUT,
    "~Autopilot~": Mods.AUTOPILOT,
    "+Perfect": Mods.PERFECT,
    "|Cinema|": Mods.CINEMA,
    "~Target~": Mods.TARGET,
    # perhaps could modify regex
    # to only allow these once,
    # and only at the end of str?
    "|1K|": Mods.KEY1,
    "|2K|": Mods.KEY2,
    "|3K|": Mods.KEY3,
    "|4K|": Mods.KEY4,
    "|5K|": Mods.KEY5,
    "|6K|": Mods.KEY6,
    "|7K|": Mods.KEY7,
    "|8K|": Mods.KEY8,
    "|9K|": Mods.KEY9,
    # XXX: kinda mood that there's no way
    # to tell K1-K4 co-op from /np, but
    # scores won't submit or anything, so
    # it's not ultimately a problem.
    "|10K|": Mods.KEY5 | Mods.KEYCOOP,
    "|12K|": Mods.KEY6 | Mods.KEYCOOP,
    "|14K|": Mods.KEY7 | Mods.KEYCOOP,
    "|16K|": Mods.KEY8 | Mods.KEYCOOP,
    "|18K|": Mods.KEY9 | Mods.KEYCOOP,
}

# Mods flag -> 2-letter acronym; used by `Mods.__repr__`.
mod2modstr_dict = {
    Mods.NOFAIL: "NF",
    Mods.EASY: "EZ",
    Mods.TOUCHSCREEN: "TD",
    Mods.HIDDEN: "HD",
    Mods.HARDROCK: "HR",
    Mods.SUDDENDEATH: "SD",
    Mods.DOUBLETIME: "DT",
    Mods.RELAX: "RX",
    Mods.HALFTIME: "HT",
    Mods.NIGHTCORE: "NC",
    Mods.FLASHLIGHT: "FL",
    Mods.AUTOPLAY: "AU",
    Mods.SPUNOUT: "SO",
    Mods.AUTOPILOT: "AP",
    Mods.PERFECT: "PF",
    Mods.FADEIN: "FI",
    Mods.RANDOM: "RN",
    Mods.CINEMA: "CN",
    Mods.TARGET: "TP",
    Mods.SCOREV2: "V2",
    Mods.MIRROR: "MR",
    Mods.KEY1: "1K",
    Mods.KEY2: "2K",
    Mods.KEY3: "3K",
    Mods.KEY4: "4K",
    Mods.KEY5: "5K",
    Mods.KEY6: "6K",
    Mods.KEY7: "7K",
    Mods.KEY8: "8K",
    Mods.KEY9: "9K",
    Mods.KEYCOOP: "CO",
}

# composite of every mania key mod.
KEY_MODS = (
    Mods.KEY1
    | Mods.KEY2
    | Mods.KEY3
    | Mods.KEY4
    | Mods.KEY5
    | Mods.KEY6
    | Mods.KEY7
    | Mods.KEY8
    | Mods.KEY9
)

# FREE_MOD_ALLOWED = (
#     Mods.NOFAIL | Mods.EASY | Mods.HIDDEN | Mods.HARDROCK |
#     Mods.SUDDENDEATH | Mods.FLASHLIGHT | Mods.FADEIN |
#     Mods.RELAX | Mods.AUTOPILOT | Mods.SPUNOUT | KEY_MODS
# )

SCORE_INCREASE_MODS = (
    Mods.HIDDEN | Mods.HARDROCK | Mods.FADEIN | Mods.DOUBLETIME | Mods.FLASHLIGHT
)
SPEED_CHANGING_MODS = Mods.DOUBLETIME | Mods.NIGHTCORE | Mods.HALFTIME

# mods restricted to a single mode; stripped from other modes
# by `Mods.filter_invalid_combos`.
OSU_SPECIFIC_MODS = Mods.AUTOPILOT | Mods.SPUNOUT | Mods.TARGET
# taiko & catch have no specific mods
MANIA_SPECIFIC_MODS = Mods.MIRROR | Mods.RANDOM | Mods.FADEIN | KEY_MODS

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
from enum import IntEnum
from enum import IntFlag
from enum import unique
from app.utils import escape_enum
from app.utils import pymysql_encode
@unique
@pymysql_encode(escape_enum)
class Privileges(IntFlag):
    """Server side user privileges."""

    # privileges intended for all normal players.
    UNRESTRICTED = 1 << 0  # is an unbanned player.
    VERIFIED = 1 << 1  # has logged in to the server in-game.

    # has bypass to low-ceiling anticheat measures (trusted).
    WHITELISTED = 1 << 2

    # donation tiers, receives some extra benefits.
    SUPPORTER = 1 << 4
    PREMIUM = 1 << 5

    # notable users, receives some extra benefits.
    ALUMNI = 1 << 7

    # staff permissions, able to manage server app.state.
    TOURNEY_MANAGER = 1 << 10  # able to manage match state without host.
    NOMINATOR = 1 << 11  # able to manage maps ranked status.
    MODERATOR = 1 << 12  # able to manage users (level 1).
    ADMINISTRATOR = 1 << 13  # able to manage users (level 2).
    DEVELOPER = 1 << 14  # able to manage full server app.state.

    # composite masks (NOTE: bits 3, 6, 8-9 are unused).
    DONATOR = SUPPORTER | PREMIUM
    STAFF = MODERATOR | ADMINISTRATOR | DEVELOPER
@unique
@pymysql_encode(escape_enum)
class ClientPrivileges(IntFlag):
    """Client side user privileges."""

    PLAYER = 1 << 0
    MODERATOR = 1 << 1
    SUPPORTER = 1 << 2
    OWNER = 1 << 3
    DEVELOPER = 1 << 4
    TOURNAMENT = 1 << 5  # NOTE: not used in communications with osu! client
@unique
@pymysql_encode(escape_enum)
class ClanPrivileges(IntEnum):
    """A class to represent a clan members privs."""

    # ordered by value: Member < Officer < Owner.
    Member = 1
    Officer = 2
    Owner = 3

23
app/constants/regexes.py Normal file
View File

@@ -0,0 +1,23 @@
from __future__ import annotations
import re
# osu! client version strings, e.g. "b20200811.2cuttingedge".
OSU_VERSION = re.compile(
    r"^b(?P<date>\d{8})(?:\.(?P<revision>\d))?"
    r"(?P<stream>beta|cuttingedge|dev|tourney)?$",
)

# usernames: 2-15 chars of word chars, spaces, square brackets, or hyphens.
USERNAME = re.compile(r"^[\w \[\]-]{2,15}$")

# a lenient local@domain.tld shape (not full RFC 5322 validation).
EMAIL = re.compile(r"^[^@\s]{1,200}@[^@\s\.]{1,30}(?:\.[^@\.\s]{2,24})+$")

# tournament lobby titles, e.g. "OWC2023: (Canada) vs. (Germany)".
TOURNEY_MATCHNAME = re.compile(
    r"^(?P<name>[a-zA-Z0-9_ ]+): "
    r"\((?P<T1>[a-zA-Z0-9_ ]+)\)"
    r" vs\.? "
    r"\((?P<T2>[a-zA-Z0-9_ ]+)\)$",
    flags=re.IGNORECASE,
)

# mappool picks, e.g. "NM1" -> groups ("NM", "1").
MAPPOOL_PICK = re.compile(r"^([a-zA-Z]+)([0-9]+)$")

# best-of values, e.g. "bo9" or "9".
BEST_OF = re.compile(r"^(?:bo)?(\d{1,2})$")

173
app/discord.py Normal file
View File

@@ -0,0 +1,173 @@
"""Functionality related to Discord interactivity."""
from __future__ import annotations
from typing import Any
from tenacity import retry
from tenacity import stop_after_attempt
from tenacity import wait_exponential
from app.state import services
class Footer:
    """The footer portion of a Discord embed."""

    def __init__(self, text: str, **kwargs: Any) -> None:
        self.text = text
        # remaining fields are optional and default to None
        for attr in ("icon_url", "proxy_icon_url"):
            setattr(self, attr, kwargs.get(attr))
class Image:
    """A full-size image attached to a Discord embed."""

    def __init__(self, **kwargs: Any) -> None:
        # all fields are optional and default to None
        for attr in ("url", "proxy_url", "height", "width"):
            setattr(self, attr, kwargs.get(attr))
class Thumbnail:
    """A small thumbnail image attached to a Discord embed."""

    def __init__(self, **kwargs: Any) -> None:
        # all fields are optional and default to None
        for attr in ("url", "proxy_url", "height", "width"):
            setattr(self, attr, kwargs.get(attr))
class Video:
    """A video attached to a Discord embed."""

    def __init__(self, **kwargs: Any) -> None:
        # all fields are optional and default to None
        for attr in ("url", "height", "width"):
            setattr(self, attr, kwargs.get(attr))
class Provider:
    """The provider (source service) of a Discord embed."""

    def __init__(self, **kwargs: str) -> None:
        # both fields are optional and default to None
        for attr in ("url", "name"):
            setattr(self, attr, kwargs.get(attr))
class Author:
    """The author section of a Discord embed."""

    def __init__(self, **kwargs: str) -> None:
        # all fields are optional and default to None
        for attr in ("name", "url", "icon_url", "proxy_icon_url"):
            setattr(self, attr, kwargs.get(attr))
class Field:
    """A single name/value field within a Discord embed."""

    def __init__(self, name: str, value: str, inline: bool = False) -> None:
        self.name, self.value, self.inline = name, value, inline
class Embed:
    """A Discord embed, built up from its optional components."""

    def __init__(self, **kwargs: Any) -> None:
        # simple scalar attributes
        self.title = kwargs.get("title")
        self.type = kwargs.get("type")
        self.description = kwargs.get("description")
        self.url = kwargs.get("url")
        self.timestamp = kwargs.get("timestamp")  # datetime
        self.color = kwargs.get("color", 0x000000)

        # structured sub-objects
        self.footer: Footer | None = kwargs.get("footer")
        self.image: Image | None = kwargs.get("image")
        self.thumbnail: Thumbnail | None = kwargs.get("thumbnail")
        self.video: Video | None = kwargs.get("video")
        self.provider: Provider | None = kwargs.get("provider")
        self.author: Author | None = kwargs.get("author")
        self.fields: list[Field] = kwargs.get("fields", [])

    def set_footer(self, **kwargs: Any) -> None:
        """Replace the embed's footer."""
        self.footer = Footer(**kwargs)

    def set_image(self, **kwargs: Any) -> None:
        """Replace the embed's image."""
        self.image = Image(**kwargs)

    def set_thumbnail(self, **kwargs: Any) -> None:
        """Replace the embed's thumbnail."""
        self.thumbnail = Thumbnail(**kwargs)

    def set_video(self, **kwargs: Any) -> None:
        """Replace the embed's video."""
        self.video = Video(**kwargs)

    def set_provider(self, **kwargs: Any) -> None:
        """Replace the embed's provider."""
        self.provider = Provider(**kwargs)

    def set_author(self, **kwargs: Any) -> None:
        """Replace the embed's author."""
        self.author = Author(**kwargs)

    def add_field(self, name: str, value: str, inline: bool = False) -> None:
        """Append a new field to the embed."""
        self.fields.append(Field(name, value, inline))
class Webhook:
    """A class to represent a single-use Discord webhook."""

    def __init__(self, url: str, **kwargs: Any) -> None:
        self.url = url
        # message-level fields; all optional
        self.content = kwargs.get("content")
        self.username = kwargs.get("username")
        self.avatar_url = kwargs.get("avatar_url")
        self.tts = kwargs.get("tts")
        self.file = kwargs.get("file")
        self.embeds = kwargs.get("embeds", [])

    def add_embed(self, embed: Embed) -> None:
        """Attach an additional embed to the webhook message."""
        self.embeds.append(embed)

    @property
    def json(self) -> Any:
        """Serialize the webhook into a Discord-API-compatible payload dict.

        Raises if the webhook has no content/file/embeds, or if the
        content exceeds Discord's 2000-character limit.
        """
        if not any([self.content, self.file, self.embeds]):
            raise Exception(
                "Webhook must contain at least one " "of (content, file, embeds).",
            )
        if self.content and len(self.content) > 2000:
            raise Exception("Webhook content must be under " "2000 characters.")
        payload: dict[str, Any] = {"embeds": []}
        # only include message-level values which are actually set.
        for key in ("content", "username", "avatar_url", "tts", "file"):
            val = getattr(self, key)
            if val is not None:
                payload[key] = val
        for embed in self.embeds:
            embed_payload = {}
            # simple params
            for key in ("title", "type", "description", "url", "timestamp", "color"):
                val = getattr(embed, key)
                if val is not None:
                    embed_payload[key] = val
            # class params, must turn into dict
            for key in ("footer", "image", "thumbnail", "video", "provider", "author"):
                val = getattr(embed, key)
                if val is not None:
                    embed_payload[key] = val.__dict__
            if embed.fields:
                embed_payload["fields"] = [f.__dict__ for f in embed.fields]
            payload["embeds"].append(embed_payload)
        return payload

    @retry(
        stop=stop_after_attempt(10),
        wait=wait_exponential(multiplier=1, min=4, max=10),
    )
    async def post(self) -> None:
        """Post the webhook in JSON format."""
        # TODO: if `self.file is not None`, then we should
        # use multipart/form-data instead of json payload.
        headers = {"Content-Type": "application/json"}
        response = await services.http_client.post(
            self.url,
            json=self.json,
            headers=headers,
        )
        response.raise_for_status()

59
app/encryption.py Normal file
View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from base64 import b64decode
from base64 import b64encode
from py3rijndael import Pkcs7Padding
from py3rijndael import RijndaelCbc
def encrypt_score_aes_data(
    # to encode
    score_data: list[str],
    client_hash: str,
    # used for encoding
    iv_b64: bytes,
    osu_version: str,
) -> tuple[bytes, bytes]:
    """Encrypt score data with the osu! client's Rijndael scheme,
    returning base64-encoded (score_data_b64, client_hash_b64)."""
    # TODO: perhaps this should return TypedDict?

    # the cipher key is derived from the submitting client's version
    cipher = RijndaelCbc(
        key=f"osu!-scoreburgr---------{osu_version}".encode(),
        iv=b64decode(iv_b64),
        padding=Pkcs7Padding(32),
        block_size=32,
    )

    # score fields are joined with ':' before encryption
    encrypted_score = cipher.encrypt(":".join(score_data).encode())
    encrypted_hash = cipher.encrypt(client_hash.encode())
    return b64encode(encrypted_score), b64encode(encrypted_hash)
def decrypt_score_aes_data(
    # to decode
    score_data_b64: bytes,
    client_hash_b64: bytes,
    # used for decoding
    iv_b64: bytes,
    osu_version: str,
) -> tuple[list[str], str]:
    """Decrypt base64'ed score data, returning the colon-delimited
    score fields as a list along with the decoded client hash."""
    # TODO: perhaps this should return TypedDict?

    # the cipher key is derived from the submitting client's version
    cipher = RijndaelCbc(
        key=f"osu!-scoreburgr---------{osu_version}".encode(),
        iv=b64decode(iv_b64),
        padding=Pkcs7Padding(32),
        block_size=32,
    )

    # score data is delimited by colons (:).
    decrypted_fields = cipher.decrypt(b64decode(score_data_b64)).decode().split(":")
    decrypted_hash = cipher.decrypt(b64decode(client_hash_b64)).decode()
    return decrypted_fields, decrypted_hash

102
app/logging.py Normal file
View File

@@ -0,0 +1,102 @@
from __future__ import annotations
import datetime
import logging.config
import re
from collections.abc import Mapping
from enum import IntEnum
from zoneinfo import ZoneInfo
import yaml
from app import settings
def configure_logging() -> None:
    """Configure the stdlib logging module from `logging.yaml`."""
    with open("logging.yaml") as f:
        # yaml.safe_load accepts a stream directly
        logging.config.dictConfig(yaml.safe_load(f))
class Ansi(IntEnum):
    """ANSI SGR colour codes used for terminal output."""

    # Default colours
    BLACK = 30
    RED = 31
    GREEN = 32
    YELLOW = 33
    BLUE = 34
    MAGENTA = 35
    CYAN = 36
    WHITE = 37

    # Light colours
    GRAY = 90
    LRED = 91
    LGREEN = 92
    LYELLOW = 93
    LBLUE = 94
    LMAGENTA = 95
    LCYAN = 96
    LWHITE = 97

    RESET = 0

    def __repr__(self) -> str:
        # the terminal escape sequence for this colour, e.g. '\x1b[31m'
        return f"\x1b[{self.value}m"
def get_timestamp(full: bool = False, tz: ZoneInfo | None = None) -> str:
fmt = "%d/%m/%Y %I:%M:%S%p" if full else "%I:%M:%S%p"
return f"{datetime.datetime.now(tz=tz):{fmt}}"
# matches ANSI CSI escape sequences (colours, cursor movement, etc.)
ANSI_ESCAPE_REGEX = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")


def escape_ansi(line: str) -> str:
    """Return `line` with all ANSI escape sequences removed."""
    return ANSI_ESCAPE_REGEX.sub("", line)
ROOT_LOGGER = logging.getLogger()


def log(
    msg: str,
    start_color: Ansi | None = None,
    extra: Mapping[str, object] | None = None,
) -> None:
    """\
    A thin wrapper around the stdlib logging module to handle mostly
    backwards-compatibility for colours during our migration to the
    standard library logging module.
    """
    # TODO: decouple colors from the base logging function; move it to
    # be a formatter-specific concern such that we can log without color.

    # map the legacy colour hints onto stdlib log levels.
    if start_color is Ansi.LYELLOW:
        log_level = logging.WARNING
    elif start_color is Ansi.LRED:
        log_level = logging.ERROR
    else:
        log_level = logging.INFO

    if settings.LOG_WITH_COLORS:
        color_prefix = f"{start_color!r}" if start_color is not None else ""
        color_suffix = f"{Ansi.RESET!r}" if start_color is not None else ""
    else:
        # colours disabled: also strip any ANSI codes embedded in the message.
        msg = escape_ansi(msg)
        color_prefix = color_suffix = ""

    ROOT_LOGGER.log(log_level, f"{color_prefix}{msg}{color_suffix}", extra=extra)
TIME_ORDER_SUFFIXES = ["nsec", "μsec", "msec", "sec"]
def magnitude_fmt_time(nanosec: int | float) -> str:
suffix = None
for suffix in TIME_ORDER_SUFFIXES:
if nanosec < 1000:
break
nanosec /= 1000
return f"{nanosec:.2f} {suffix}"

11
app/objects/__init__.py Normal file
View File

@@ -0,0 +1,11 @@
# type: ignore
# isort: dont-add-imports
from . import achievement
from . import beatmap
from . import channel
from . import collections
from . import match
from . import models
from . import player
from . import score

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.objects.score import Score
class Achievement:
    """A class to represent a single osu! achievement."""

    def __init__(
        self,
        id: int,
        file: str,
        name: str,
        desc: str,
        cond: Callable[[Score, int], bool],  # (score, mode) -> unlocked
    ) -> None:
        self.id = id
        self.file = file
        self.name = name
        self.desc = desc
        self.cond = cond

    def __repr__(self) -> str:
        # serialized as 'file+name+desc'
        return "+".join((self.file, self.name, self.desc))

996
app/objects/beatmap.py Normal file
View File

@@ -0,0 +1,996 @@
from __future__ import annotations
import functools
import hashlib
from collections import defaultdict
from collections.abc import Mapping
from datetime import datetime
from datetime import timedelta
from enum import IntEnum
from enum import unique
from pathlib import Path
from typing import Any
from typing import TypedDict
import httpx
from tenacity import retry
from tenacity.stop import stop_after_attempt
import app.settings
import app.state
import app.utils
from app.constants.gamemodes import GameMode
from app.logging import Ansi
from app.logging import log
from app.repositories import maps as maps_repo
from app.utils import escape_enum
from app.utils import pymysql_encode
# from dataclasses import dataclass
# where downloaded .osu files are stored on disk
BEATMAPS_PATH = Path.cwd() / ".data/osu"

# sentinel for maps whose last_update is unknown (unix epoch)
DEFAULT_LAST_UPDATE = datetime(1970, 1, 1)

# translation table removing characters not usable in filenames
IGNORED_BEATMAP_CHARS = dict.fromkeys(map(ord, r':\/*<>?"|'), None)
class BeatmapApiResponse(TypedDict):
    """Result of a get_beatmaps api call: rows (or None) + HTTP status code."""

    data: list[dict[str, Any]] | None
    status_code: int
@retry(reraise=True, stop=stop_after_attempt(3))
async def api_get_beatmaps(**params: Any) -> BeatmapApiResponse:
    """\
    Fetch data from the osu!api with a beatmap's md5.

    Optionally use osu.direct's API if the user has not provided an osu! api key.
    """
    if app.settings.DEBUG:
        log(f"Doing api (getbeatmaps) request {params}", Ansi.LMAGENTA)

    if app.settings.OSU_API_KEY:
        # https://github.com/ppy/osu-api/wiki#apiget_beatmaps
        url = "https://old.ppy.sh/api/get_beatmaps"
        params["k"] = str(app.settings.OSU_API_KEY)
    else:
        # https://osu.direct/doc
        url = "https://osu.direct/api/get_beatmaps"

    response = await app.state.services.http_client.get(url, params=params)
    response_data = response.json()
    if response.status_code == 200 and response_data:  # (data may be [])
        return {"data": response_data, "status_code": response.status_code}

    # non-200 or empty response body
    return {"data": None, "status_code": response.status_code}
@retry(reraise=True, stop=stop_after_attempt(3))
async def api_get_osu_file(beatmap_id: int) -> bytes:
    """Fetch the raw .osu file contents for `beatmap_id` from osu! servers.

    Raises on a non-success HTTP status (retried up to 3 times).
    """
    url = f"https://old.ppy.sh/osu/{beatmap_id}"
    response = await app.state.services.http_client.get(url)
    response.raise_for_status()
    return response.read()
def disk_has_expected_osu_file(
    beatmap_id: int,
    expected_md5: str | None = None,
) -> bool:
    """Whether `{beatmap_id}.osu` exists on disk (and, if `expected_md5`
    is given, whether its contents match that hash)."""
    path = BEATMAPS_PATH / f"{beatmap_id}.osu"
    if not path.exists():
        return False
    if expected_md5 is None:
        return True
    return hashlib.md5(path.read_bytes()).hexdigest() == expected_md5
def write_osu_file_to_disk(beatmap_id: int, data: bytes) -> None:
    """Write (or overwrite) the .osu file for `beatmap_id` to disk."""
    osu_file_path = BEATMAPS_PATH / f"{beatmap_id}.osu"
    osu_file_path.write_bytes(data)
async def ensure_osu_file_is_available(
    beatmap_id: int,
    expected_md5: str | None = None,
) -> bool:
    """\
    Download the .osu file for a beatmap if it's not already present.

    If `expected_md5` is provided, the file will be downloaded if it
    does not match the expected md5 hash -- this is typically used for
    ensuring a file is the latest expected version.

    Returns whether the file is available for use.
    """
    if disk_has_expected_osu_file(beatmap_id, expected_md5):
        return True

    try:
        latest_osu_file = await api_get_osu_file(beatmap_id)
    except httpx.HTTPStatusError:
        # the api responded, but with a failure status code
        return False
    except Exception:
        # network failure, timeout, etc.
        log(f"Failed to fetch osu file for {beatmap_id}", Ansi.LRED)
        return False

    write_osu_file_to_disk(beatmap_id, latest_osu_file)
    return True
# for some ungodly reason, different values are used to
# represent different ranked statuses all throughout osu!
# This drives me and probably everyone else pretty insane,
# but we have nothing to do but deal with it B).
@unique
@pymysql_encode(escape_enum)
class RankedStatus(IntEnum):
    """Server side osu! beatmap ranked statuses.

    Same as used in osu!'s /web/getscores.php.
    """

    NotSubmitted = -1
    Pending = 0
    UpdateAvailable = 1
    Ranked = 2
    Approved = 3
    Qualified = 4
    Loved = 5

    def __str__(self) -> str:
        """Human-readable display name for this status."""
        return {
            self.NotSubmitted: "Unsubmitted",
            self.Pending: "Unranked",
            self.UpdateAvailable: "Outdated",
            self.Ranked: "Ranked",
            self.Approved: "Approved",
            self.Qualified: "Qualified",
            self.Loved: "Loved",
        }[self]

    @functools.cached_property
    def osu_api(self) -> int:
        """Convert the value to osu!api status."""
        # XXX: only the ones that exist are mapped.
        return {
            self.Pending: 0,
            self.Ranked: 1,
            self.Approved: 2,
            self.Qualified: 3,
            self.Loved: 4,
        }[self]

    @classmethod
    @functools.cache
    def from_osuapi(cls, osuapi_status: int) -> RankedStatus:
        """Convert from osu!api status."""
        # any unrecognized status maps to UpdateAvailable.
        mapping: Mapping[int, RankedStatus] = defaultdict(
            lambda: cls.UpdateAvailable,
            {
                -2: cls.Pending,  # graveyard
                -1: cls.Pending,  # wip
                0: cls.Pending,
                1: cls.Ranked,
                2: cls.Approved,
                3: cls.Qualified,
                4: cls.Loved,
            },
        )
        return mapping[osuapi_status]

    @classmethod
    @functools.cache
    def from_osudirect(cls, osudirect_status: int) -> RankedStatus:
        """Convert from osu!direct status."""
        # any unrecognized status maps to UpdateAvailable.
        mapping: Mapping[int, RankedStatus] = defaultdict(
            lambda: cls.UpdateAvailable,
            {
                0: cls.Ranked,
                2: cls.Pending,
                3: cls.Qualified,
                # 4: all ranked statuses lol
                5: cls.Pending,  # graveyard
                7: cls.Ranked,  # played before
                8: cls.Loved,
            },
        )
        return mapping[osudirect_status]

    @classmethod
    @functools.cache
    def from_str(cls, status_str: str) -> RankedStatus:
        """Convert from string value."""  # could perhaps have `'unranked': cls.Pending`?
        # any unrecognized status maps to UpdateAvailable.
        mapping: Mapping[str, RankedStatus] = defaultdict(
            lambda: cls.UpdateAvailable,
            {
                "pending": cls.Pending,
                "ranked": cls.Ranked,
                "approved": cls.Approved,
                "qualified": cls.Qualified,
                "loved": cls.Loved,
            },
        )
        return mapping[status_str]
# @dataclass
# class BeatmapInfoRequest:
# filenames: Sequence[str]
# ids: Sequence[int]
# @dataclass
# class BeatmapInfo:
# id: int # i16
# map_id: int # i32
# set_id: int # i32
# thread_id: int # i32
# status: int # u8
# osu_rank: int # u8
# fruits_rank: int # u8
# taiko_rank: int # u8
# mania_rank: int # u8
# map_md5: str
class Beatmap:
"""A class representing an osu! beatmap.
This class provides a high level api which should always be the
preferred method of fetching beatmaps due to its housekeeping.
It will perform caching & invalidation, handle map updates while
minimizing osu!api requests, and always use the most efficient
method available to fetch the beatmap's information, while
maintaining a low overhead.
The only methods you should need are:
await Beatmap.from_md5(md5: str, set_id: int = -1) -> Beatmap | None
await Beatmap.from_bid(bid: int) -> Beatmap | None
Properties:
Beatmap.full -> str # Artist - Title [Version]
Beatmap.url -> str # https://osu.cmyui.xyz/b/321
Beatmap.embed -> str # [{url} {full}]
Beatmap.has_leaderboard -> bool
Beatmap.awards_ranked_pp -> bool
Beatmap.as_dict -> dict[str, object]
Lower level API:
Beatmap._from_md5_cache(md5: str, check_updates: bool = True) -> Beatmap | None
Beatmap._from_bid_cache(bid: int, check_updates: bool = True) -> Beatmap | None
Beatmap._from_md5_sql(md5: str) -> Beatmap | None
Beatmap._from_bid_sql(bid: int) -> Beatmap | None
Beatmap._parse_from_osuapi_resp(osuapi_resp: dict[str, object]) -> None
Note that the BeatmapSet class also provides a similar API.
Possibly confusing attributes
-----------
frozen: `bool`
Whether the beatmap's status is to be kept when a newer
version is found in the osu!api.
# XXX: This is set when a map's status is manually changed.
"""
def __init__(
self,
map_set: BeatmapSet,
md5: str = "",
id: int = 0,
set_id: int = 0,
artist: str = "",
title: str = "",
version: str = "",
creator: str = "",
last_update: datetime = DEFAULT_LAST_UPDATE,
total_length: int = 0,
max_combo: int = 0,
status: RankedStatus = RankedStatus.Pending,
frozen: bool = False,
plays: int = 0,
passes: int = 0,
mode: GameMode = GameMode.VANILLA_OSU,
bpm: float = 0.0,
cs: float = 0.0,
od: float = 0.0,
ar: float = 0.0,
hp: float = 0.0,
diff: float = 0.0,
filename: str = "",
) -> None:
self.set = map_set
self.md5 = md5
self.id = id
self.set_id = set_id
self.artist = artist
self.title = title
self.version = version
self.creator = creator
self.last_update = last_update
self.total_length = total_length
self.max_combo = max_combo
self.status = status
self.frozen = frozen
self.plays = plays
self.passes = passes
self.mode = mode
self.bpm = bpm
self.cs = cs
self.od = od
self.ar = ar
self.hp = hp
self.diff = diff
self.filename = filename
def __repr__(self) -> str:
return self.full_name
@property
def full_name(self) -> str:
"""The full osu! formatted name `self`."""
return f"{self.artist} - {self.title} [{self.version}]"
@property
def url(self) -> str:
"""The osu! beatmap url for `self`."""
return f"https://osu.{app.settings.DOMAIN}/b/{self.id}"
@property
def embed(self) -> str:
"""An osu! chat embed to `self`'s osu! beatmap page."""
return f"[{self.url} {self.full_name}]"
@property
def has_leaderboard(self) -> bool:
"""Return whether the map has a ranked leaderboard."""
return self.status in (
RankedStatus.Ranked,
RankedStatus.Approved,
RankedStatus.Loved,
)
@property
def awards_ranked_pp(self) -> bool:
"""Return whether the map's status awards ranked pp for scores."""
return self.status in (RankedStatus.Ranked, RankedStatus.Approved)
@property # perhaps worth caching some of?
def as_dict(self) -> dict[str, object]:
return {
"md5": self.md5,
"id": self.id,
"set_id": self.set_id,
"artist": self.artist,
"title": self.title,
"version": self.version,
"creator": self.creator,
"last_update": self.last_update,
"total_length": self.total_length,
"max_combo": self.max_combo,
"status": self.status,
"plays": self.plays,
"passes": self.passes,
"mode": self.mode,
"bpm": self.bpm,
"cs": self.cs,
"od": self.od,
"ar": self.ar,
"hp": self.hp,
"diff": self.diff,
}
""" High level API """
# There are three levels of storage used for beatmaps,
# the cache (ram), the db (disk), and the osu!api (web).
# Going down this list gets exponentially slower, so
# we always prioritze what's fastest when possible.
# These methods will keep beatmaps reasonably up to
# date and use the fastest storage available, while
# populating the higher levels of the cache with new maps.
@classmethod
async def from_md5(cls, md5: str, set_id: int = -1) -> Beatmap | None:
"""Fetch a map from the cache, database, or osuapi by md5."""
bmap = await cls._from_md5_cache(md5)
if not bmap:
# map not found in cache
# to be efficient, we want to cache the whole set
# at once rather than caching the individual map
if set_id <= 0:
# set id not provided - fetch it from the map md5
rec = await maps_repo.fetch_one(md5=md5)
if rec is not None:
# set found in db
set_id = rec["set_id"]
else:
# set not found in db, try api
api_data = await api_get_beatmaps(h=md5)
if api_data["data"] is None:
return None
api_response = api_data["data"]
set_id = int(api_response[0]["beatmapset_id"])
# fetch (and cache) beatmap set
beatmap_set = await BeatmapSet.from_bsid(set_id)
if beatmap_set is not None:
# the beatmap set has been cached - fetch beatmap from cache
bmap = await cls._from_md5_cache(md5)
# XXX:HACK in this case, BeatmapSet.from_bsid will have
# ensured the map is up to date, so we can just return it
return bmap
if bmap is not None:
if bmap.set._cache_expired():
await bmap.set._update_if_available()
return bmap
@classmethod
async def from_bid(cls, bid: int) -> Beatmap | None:
    """Fetch a map from the cache, database, or osuapi by id."""
    cached = await cls._from_bid_cache(bid)
    if cached is not None:
        # cache hit - make sure the cached set is reasonably
        # fresh before handing the beatmap out.
        if cached.set._cache_expired():
            await cached.set._update_if_available()
        return cached

    # cache miss - to be efficient, we cache the whole set at once
    # rather than the individual map, so resolve the set id first
    # (db lookup, then osu!api as a last resort).
    rec = await maps_repo.fetch_one(id=bid)
    if rec is not None:
        # set found in db
        set_id = rec["set_id"]
    else:
        # set not found in db, try getting via api
        api_data = await api_get_beatmaps(b=bid)
        if api_data["data"] is None:
            return None
        set_id = int(api_data["data"][0]["beatmapset_id"])

    # fetch (and cache) the whole beatmap set
    if await BeatmapSet.from_bsid(set_id) is None:
        return None

    # XXX:HACK BeatmapSet.from_bsid will have ensured the set is up
    # to date, so the cached copy (if any) can be returned directly.
    return await cls._from_bid_cache(bid)
""" Lower level API """
# These functions are meant for internal use under
# all normal circumstances and should only be used
# if you're really modifying bancho.py by adding new
# features, or perhaps optimizing parts of the code.
def _parse_from_osuapi_resp(self, osuapi_resp: dict[str, Any]) -> None:
    """Overwrite this beatmap's fields with data in osu!api format.

    NOTE: `self` is not guaranteed to have any attributes
    initialized when this is called.
    """
    self.md5 = osuapi_resp["file_md5"]
    # self.id = int(osuapi_resp['beatmap_id'])
    self.set_id = int(osuapi_resp["beatmapset_id"])

    self.artist = osuapi_resp["artist"]
    self.title = osuapi_resp["title"]
    self.version = osuapi_resp["version"]
    self.creator = osuapi_resp["creator"]

    self.filename = (
        "{artist} - {title} ({creator}) [{version}].osu".format(**osuapi_resp)
    ).translate(IGNORED_BEATMAP_CHARS)

    # parse the fixed-width "YYYY-MM-DD HH:MM:SS" timestamp by
    # slicing; quite a bit faster than using dt.strptime.
    raw_ts = osuapi_resp["last_update"]
    self.last_update = datetime(
        int(raw_ts[0:4]),  # year
        int(raw_ts[5:7]),  # month
        int(raw_ts[8:10]),  # day
        int(raw_ts[11:13]),  # hour
        int(raw_ts[14:16]),  # minute
        int(raw_ts[17:19]),  # second
    )

    self.total_length = int(osuapi_resp["total_length"])

    max_combo = osuapi_resp["max_combo"]
    self.max_combo = int(max_combo) if max_combo is not None else 0

    # if a map is 'frozen', we keep its status
    # even after an update from the osu!api.
    if not getattr(self, "frozen", False):
        self.status = RankedStatus.from_osuapi(int(osuapi_resp["approved"]))

    self.mode = GameMode(int(osuapi_resp["mode"]))

    bpm = osuapi_resp["bpm"]
    self.bpm = float(bpm) if bpm is not None else 0.0

    self.cs = float(osuapi_resp["diff_size"])
    self.od = float(osuapi_resp["diff_overall"])
    self.ar = float(osuapi_resp["diff_approach"])
    self.hp = float(osuapi_resp["diff_drain"])

    self.diff = float(osuapi_resp["difficultyrating"])
@staticmethod
async def _from_md5_cache(md5: str) -> Beatmap | None:
    """Look up a beatmap in the in-memory cache by md5 hash."""
    # dict.get already defaults to None on a miss.
    return app.state.cache.beatmap.get(md5)
@staticmethod
async def _from_bid_cache(bid: int) -> Beatmap | None:
    """Look up a beatmap in the in-memory cache by beatmap id."""
    # dict.get already defaults to None on a miss.
    return app.state.cache.beatmap.get(bid)
class BeatmapSet:
    """A class to represent an osu! beatmap set.

    Like the Beatmap class, this class provides a high level api
    which should always be the preferred method of fetching beatmaps
    due to its housekeeping. It will perform caching & invalidation,
    handle map updates while minimizing osu!api requests, and always
    use the most efficient method available to fetch the beatmap's
    information, while maintaining a low overhead.

    The only methods you should need are:
      await BeatmapSet.from_bsid(bsid: int) -> BeatmapSet | None
      BeatmapSet.any_beatmaps_have_official_leaderboards() -> bool

    Properties:
      BeatmapSet.url -> str # https://osu.cmyui.xyz/s/123

    Lower level API:
      await BeatmapSet._from_bsid_cache(bsid: int) -> BeatmapSet | None
      await BeatmapSet._from_bsid_sql(bsid: int) -> BeatmapSet | None
      await BeatmapSet._from_bsid_osuapi(bsid: int) -> BeatmapSet | None
      BeatmapSet._cache_expired() -> bool
      await BeatmapSet._update_if_available() -> None
      await BeatmapSet._save_to_sql() -> None
    """

    def __init__(
        self,
        id: int,
        last_osuapi_check: datetime,
        maps: list[Beatmap] | None = None,
    ) -> None:
        # the osu! beatmapset id
        self.id = id
        # all beatmaps (difficulties) belonging to this set
        self.maps = maps or []
        # when we last confirmed this set against the osu!api
        self.last_osuapi_check = last_osuapi_check

    def __repr__(self) -> str:
        # show the de-duplicated "{artist} - {title}" names of the set's maps
        map_names = []
        for bmap in self.maps:
            name = f"{bmap.artist} - {bmap.title}"
            if name not in map_names:
                map_names.append(name)
        return ", ".join(map_names)

    @property
    def url(self) -> str:
        """The online url for this beatmap set."""
        return f"https://osu.{app.settings.DOMAIN}/s/{self.id}"

    def any_beatmaps_have_official_leaderboards(self) -> bool:
        """Whether any map in the set has a leaderboard on official servers."""
        leaderboard_having_statuses = (
            RankedStatus.Loved,
            RankedStatus.Ranked,
            RankedStatus.Approved,
        )
        return any(bmap.status in leaderboard_having_statuses for bmap in self.maps)

    def _cache_expired(self) -> bool:
        """Whether the cached version of the set is
        expired and needs an update from the osu!api."""
        current_datetime = datetime.now()

        # an empty set always warrants a refresh from the osu!api
        if not self.maps:
            return True

        # the delta between cache invalidations will increase depending
        # on how long it's been since the map was last updated on osu!
        last_map_update = max(bmap.last_update for bmap in self.maps)
        update_delta = current_datetime - last_map_update

        # with a minimum of 2 hours, add 5 hours per year since its update.
        # the formula for this is subject to adjustment in the future.
        check_delta = timedelta(hours=2 + ((5 / 365) * update_delta.days))

        # it's much less likely that a beatmapset who has beatmaps with
        # leaderboards on official servers will be updated.
        if self.any_beatmaps_have_official_leaderboards():
            check_delta *= 4

        # we'll cache for an absolute maximum of 1 day.
        check_delta = min(check_delta, timedelta(days=1))

        return current_datetime > (self.last_osuapi_check + check_delta)

    async def _update_if_available(self) -> None:
        """Fetch the newest data from the api, check for differences
        and propagate any update into our cache & database."""
        try:
            api_data = await api_get_beatmaps(s=self.id)
        except (httpx.TransportError, httpx.DecodingError):
            # NOTE: TransportError is directly caused by the API being unavailable
            # NOTE: DecodingError is caused by the API returning HTML and
            # normally happens when CF protection is enabled while
            # osu! recovers from a DDOS attack
            # we do not want to delete the beatmap in this case, so we simply return
            # but do not set the last check, as we would like to retry these ASAP
            return

        if api_data["data"] is not None:
            api_response = api_data["data"]

            # index our current state and the api's state by beatmap id
            old_maps = {bmap.id: bmap for bmap in self.maps}
            new_maps = {int(api_map["beatmap_id"]): api_map for api_map in api_response}

            self.last_osuapi_check = datetime.now()

            # delete maps from old_maps where old.id not in new_maps
            # update maps from old_maps where old.md5 != new.md5
            # add maps to old_maps where new.id not in old_maps
            updated_maps: list[Beatmap] = []
            map_md5s_to_delete: set[str] = set()

            # temp value for building the new beatmap
            bmap: Beatmap

            # find maps in our current state that've been deleted, or need updates
            for old_id, old_map in old_maps.items():
                if old_id not in new_maps:
                    # delete map from old_maps
                    map_md5s_to_delete.add(old_map.md5)
                else:
                    new_map = new_maps[old_id]
                    new_ranked_status = RankedStatus.from_osuapi(
                        int(new_map["approved"]),
                    )
                    # either the file content or ranked status changed
                    if (
                        old_map.md5 != new_map["file_md5"]
                        or old_map.status != new_ranked_status
                    ):
                        # update map from old_maps
                        bmap = old_maps[old_id]
                        bmap._parse_from_osuapi_resp(new_map)
                        updated_maps.append(bmap)
                    else:
                        # map is the same, make no changes
                        updated_maps.append(old_map)  # (this line is _maybe_ needed?)

            # find maps that aren't in our current state, and add them
            for new_id, new_map in new_maps.items():
                if new_id not in old_maps:
                    # new map we don't have locally, add it
                    bmap = Beatmap.__new__(Beatmap)
                    bmap.id = new_id
                    bmap._parse_from_osuapi_resp(new_map)
                    # (some implementation-specific stuff not given by api)
                    bmap.frozen = False
                    bmap.passes = 0
                    bmap.plays = 0
                    bmap.set = self
                    updated_maps.append(bmap)

            # save changes to cache
            self.maps = updated_maps

            # save changes to sql
            if map_md5s_to_delete:
                # delete maps
                await app.state.services.database.execute(
                    "DELETE FROM maps WHERE md5 IN :map_md5s",
                    {"map_md5s": map_md5s_to_delete},
                )
                # delete scores on the maps
                await app.state.services.database.execute(
                    "DELETE FROM scores WHERE map_md5 IN :map_md5s",
                    {"map_md5s": map_md5s_to_delete},
                )

            # update last_osuapi_check
            await app.state.services.database.execute(
                "REPLACE INTO mapsets "
                "(id, server, last_osuapi_check) "
                "VALUES (:id, :server, :last_osuapi_check)",
                {
                    "id": self.id,
                    "server": "osu!",
                    "last_osuapi_check": self.last_osuapi_check,
                },
            )

            # update maps in sql
            await self._save_to_sql()
        elif api_data["status_code"] in (404, 200):
            # NOTE: 200 can return an empty array of beatmaps,
            # so we still delete in this case if the beatmap data is None
            # TODO: a couple of open questions here:
            # - should we delete the beatmap from the database if it's not in the osu!api?
            # - are 404 and 200 the only cases where we should delete the beatmap?
            if self.maps:
                map_md5s_to_delete = {bmap.md5 for bmap in self.maps}

                # delete maps
                await app.state.services.database.execute(
                    "DELETE FROM maps WHERE md5 IN :map_md5s",
                    {"map_md5s": map_md5s_to_delete},
                )
                # delete scores on the maps
                await app.state.services.database.execute(
                    "DELETE FROM scores WHERE map_md5 IN :map_md5s",
                    {"map_md5s": map_md5s_to_delete},
                )

            # delete set
            await app.state.services.database.execute(
                "DELETE FROM mapsets WHERE id = :set_id",
                {"set_id": self.id},
            )

    async def _save_to_sql(self) -> None:
        """Save the object's attributes into the database."""
        # REPLACE keeps this idempotent: rows are inserted or overwritten by md5.
        await app.state.services.database.execute_many(
            "REPLACE INTO maps ("
            "md5, id, server, set_id, "
            "artist, title, version, creator, "
            "filename, last_update, total_length, "
            "max_combo, status, frozen, "
            "plays, passes, mode, bpm, "
            "cs, od, ar, hp, diff"
            ") VALUES ("
            ":md5, :id, :server, :set_id, "
            ":artist, :title, :version, :creator, "
            ":filename, :last_update, :total_length, "
            ":max_combo, :status, :frozen, "
            ":plays, :passes, :mode, :bpm, "
            ":cs, :od, :ar, :hp, :diff"
            ")",
            [
                {
                    "md5": bmap.md5,
                    "id": bmap.id,
                    "server": "osu!",
                    "set_id": bmap.set_id,
                    "artist": bmap.artist,
                    "title": bmap.title,
                    "version": bmap.version,
                    "creator": bmap.creator,
                    "filename": bmap.filename,
                    "last_update": bmap.last_update,
                    "total_length": bmap.total_length,
                    "max_combo": bmap.max_combo,
                    "status": bmap.status,
                    "frozen": bmap.frozen,
                    "plays": bmap.plays,
                    "passes": bmap.passes,
                    "mode": bmap.mode,
                    "bpm": bmap.bpm,
                    "cs": bmap.cs,
                    "od": bmap.od,
                    "ar": bmap.ar,
                    "hp": bmap.hp,
                    "diff": bmap.diff,
                }
                for bmap in self.maps
            ],
        )

    @staticmethod
    async def _from_bsid_cache(bsid: int) -> BeatmapSet | None:
        """Fetch a mapset from the cache by set id."""
        return app.state.cache.beatmapset.get(bsid, None)

    @classmethod
    async def _from_bsid_sql(cls, bsid: int) -> BeatmapSet | None:
        """Fetch a mapset from the database by set id."""
        last_osuapi_check = await app.state.services.database.fetch_val(
            "SELECT last_osuapi_check FROM mapsets WHERE id = :set_id",
            {"set_id": bsid},
            column=0,  # last_osuapi_check
        )
        # no mapset row means the set has never been stored locally
        if last_osuapi_check is None:
            return None

        bmap_set = cls(id=bsid, last_osuapi_check=last_osuapi_check)

        # hydrate each beatmap row into a Beatmap attached to this set
        for row in await maps_repo.fetch_many(set_id=bsid):
            bmap = Beatmap(
                md5=row["md5"],
                id=row["id"],
                set_id=row["set_id"],
                artist=row["artist"],
                title=row["title"],
                version=row["version"],
                creator=row["creator"],
                last_update=row["last_update"],
                total_length=row["total_length"],
                max_combo=row["max_combo"],
                status=RankedStatus(row["status"]),
                frozen=row["frozen"],
                plays=row["plays"],
                passes=row["passes"],
                mode=GameMode(row["mode"]),
                bpm=row["bpm"],
                cs=row["cs"],
                od=row["od"],
                ar=row["ar"],
                hp=row["hp"],
                diff=row["diff"],
                filename=row["filename"],
                map_set=bmap_set,
            )

            # XXX: tempfix for bancho.py <v3.4.1,
            # where filenames weren't stored.
            if not bmap.filename:
                bmap.filename = (
                    ("{artist} - {title} ({creator}) [{version}].osu")
                    .format(
                        artist=row["artist"],
                        title=row["title"],
                        creator=row["creator"],
                        version=row["version"],
                    )
                    .translate(IGNORED_BEATMAP_CHARS)
                )
                # backfill the reconstructed filename into the db
                await maps_repo.partial_update(bmap.id, filename=bmap.filename)

            bmap_set.maps.append(bmap)

        return bmap_set

    @classmethod
    async def _from_bsid_osuapi(cls, bsid: int) -> BeatmapSet | None:
        """Fetch a mapset from the osu!api by set id."""
        api_data = await api_get_beatmaps(s=bsid)
        if api_data["data"] is not None:
            api_response = api_data["data"]

            self = cls(id=bsid, last_osuapi_check=datetime.now())

            # XXX: pre-mapset bancho.py support
            # select all current beatmaps
            # that're frozen in the db
            res = await app.state.services.database.fetch_all(
                "SELECT id, status FROM maps WHERE set_id = :set_id AND frozen = 1",
                {"set_id": bsid},
            )

            current_maps = {row["id"]: row["status"] for row in res}

            for api_bmap in api_response:
                # newer version available for this map
                bmap: Beatmap = Beatmap.__new__(Beatmap)
                bmap.id = int(api_bmap["beatmap_id"])

                if bmap.id in current_maps:
                    # map is currently frozen, keep its status.
                    bmap.status = RankedStatus(current_maps[bmap.id])
                    bmap.frozen = True
                else:
                    bmap.frozen = False

                bmap._parse_from_osuapi_resp(api_bmap)

                # (some implementation-specific stuff not given by api)
                bmap.passes = 0
                bmap.plays = 0

                bmap.set = self
                self.maps.append(bmap)

            # persist the set row, then all of its maps
            await app.state.services.database.execute(
                "REPLACE INTO mapsets "
                "(id, server, last_osuapi_check) "
                "VALUES (:id, :server, :last_osuapi_check)",
                {
                    "id": self.id,
                    "server": "osu!",
                    "last_osuapi_check": self.last_osuapi_check,
                },
            )

            await self._save_to_sql()
            return self

        return None

    @classmethod
    async def from_bsid(cls, bsid: int) -> BeatmapSet | None:
        """Cache all maps in a set from the osuapi, optionally
        returning beatmaps by their md5 or id."""
        # storage tiers, fastest first: cache -> sql -> osu!api
        bmap_set = await cls._from_bsid_cache(bsid)
        did_api_request = False

        if not bmap_set:
            bmap_set = await cls._from_bsid_sql(bsid)

            if not bmap_set:
                bmap_set = await cls._from_bsid_osuapi(bsid)

                if not bmap_set:
                    return None

                did_api_request = True

        # TODO: this can be done less often for certain types of maps,
        # such as ones that're ranked on bancho and won't be updated,
        # and perhaps ones that haven't been updated in a long time.
        if not did_api_request and bmap_set._cache_expired():
            await bmap_set._update_if_available()

        # cache the beatmap set, and beatmaps
        # to be efficient in future requests
        cache_beatmap_set(bmap_set)

        return bmap_set
def cache_beatmap(beatmap: Beatmap) -> None:
    """Add the beatmap to the cache, keyed by both md5 and id."""
    cache = app.state.cache.beatmap
    cache[beatmap.md5] = beatmap
    cache[beatmap.id] = beatmap
def cache_beatmap_set(beatmap_set: BeatmapSet) -> None:
    """Add the beatmap set, and each of its beatmaps, to the cache."""
    app.state.cache.beatmapset[beatmap_set.id] = beatmap_set

    for bmap in beatmap_set.maps:
        cache_beatmap(bmap)

138
app/objects/channel.py Normal file
View File

@@ -0,0 +1,138 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import app.packets
import app.state
from app.constants.privileges import Privileges
if TYPE_CHECKING:
from app.objects.player import Player
class Channel:
    """An osu! chat channel.

    Possibly confusing attributes
    -----------
    _name: `str`
        A name string of the channel.
        The cls.`name` property wraps handling for '#multiplayer' and
        '#spectator' when communicating with the osu! client; only use
        this attr when you need the channel's true name; otherwise you
        should use the `name` property described below.
    instance: `bool`
        Instanced channels are deleted when all players have left;
        this is useful for things like multiplayer, spectator, etc.
    """

    def __init__(
        self,
        name: str,
        topic: str,
        read_priv: Privileges = Privileges.UNRESTRICTED,
        write_priv: Privileges = Privileges.UNRESTRICTED,
        auto_join: bool = True,
        instance: bool = False,
    ) -> None:
        # TODO: think of better names than `_name` and `name`
        self._name = name  # 'real' name ('#{multi/spec}_{id}')

        # the name shown to osu! clients; multiplayer/spectator
        # channels are normalized to their generic client-side names.
        if self._name.startswith("#spec_"):
            self.name = "#spectator"
        elif self._name.startswith("#multi_"):
            self.name = "#multiplayer"
        else:
            self.name = self._name

        self.topic = topic
        self.read_priv = read_priv
        self.write_priv = write_priv
        self.auto_join = auto_join
        self.instance = instance

        # players currently joined to the channel
        self.players: list[Player] = []

    def __repr__(self) -> str:
        return f"<{self._name}>"

    def __contains__(self, player: Player) -> bool:
        return player in self.players

    # XXX: should this be cached differently?

    def can_read(self, priv: Privileges) -> bool:
        """Whether `priv` is sufficient to read messages in this channel."""
        # a falsy read_priv means the channel is readable by anyone
        if not self.read_priv:
            return True

        return priv & self.read_priv != 0

    def can_write(self, priv: Privileges) -> bool:
        """Whether `priv` is sufficient to send messages to this channel."""
        # a falsy write_priv means the channel is writable by anyone
        if not self.write_priv:
            return True

        return priv & self.write_priv != 0

    def send(self, msg: str, sender: Player, to_self: bool = False) -> None:
        """Enqueue `msg` to all appropriate clients from `sender`."""
        data = app.packets.send_message(
            sender=sender.name,
            msg=msg,
            recipient=self.name,
            sender_id=sender.id,
        )

        for player in self.players:
            # don't relay to anyone who has blocked the sender;
            # skip the sender themselves unless `to_self` is set.
            if sender.id not in player.blocks and (to_self or player.id != sender.id):
                player.enqueue(data)

    def send_bot(self, msg: str) -> None:
        """Enqueue `msg` to all connected clients from bot."""
        bot = app.state.sessions.bot

        msg_len = len(msg)

        # replace overly-long messages which would
        # crash osu! clients with a short notice.
        if msg_len >= 31979:  # TODO ??????????
            msg = f"message would have crashed games ({msg_len} chars)"

        self.enqueue(
            app.packets.send_message(
                sender=bot.name,
                msg=msg,
                recipient=self.name,
                sender_id=bot.id,
            ),
        )

    def send_selective(
        self,
        msg: str,
        sender: Player,
        recipients: set[Player],
    ) -> None:
        """Enqueue `sender`'s `msg` to `recipients` (only those in channel)."""
        for player in recipients:
            if player in self:
                player.send(msg, sender=sender, chan=self)

    def append(self, player: Player) -> None:
        """Add `player` to the channel's players."""
        self.players.append(player)

    def remove(self, player: Player) -> None:
        """Remove `player` from the channel's players."""
        self.players.remove(player)

        if not self.players and self.instance:
            # if it's an instance channel and this
            # is the last member leaving, just remove
            # the channel from the global list.
            app.state.sessions.channels.remove(self)

    def enqueue(self, data: bytes, immune: Sequence[int] = ()) -> None:
        """Enqueue `data` to all connected clients whose id is not in `immune`.

        NOTE: the default was previously a mutable `[]`; an immutable
        tuple avoids the shared-mutable-default-argument pitfall.
        """
        for player in self.players:
            if player.id not in immune:
                player.enqueue(data)

314
app/objects/collections.py Normal file
View File

@@ -0,0 +1,314 @@
from __future__ import annotations
from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import Sequence
from typing import Any
import databases.core
import app.settings
import app.state
import app.utils
from app.constants.privileges import ClanPrivileges
from app.constants.privileges import Privileges
from app.logging import Ansi
from app.logging import log
from app.objects.channel import Channel
from app.objects.match import Match
from app.objects.player import Player
from app.repositories import channels as channels_repo
from app.repositories import clans as clans_repo
from app.repositories import users as users_repo
from app.utils import make_safe_name
class Channels(list[Channel]):
    """The currently active chat channels on the server."""

    def __iter__(self) -> Iterator[Channel]:
        return super().__iter__()

    def __contains__(self, o: object) -> bool:
        """Check whether internal list contains `o`."""
        # Allow string to be passed to compare vs. name.
        if isinstance(o, str):
            return o in (chan.name for chan in self)
        else:
            return super().__contains__(o)

    def __repr__(self) -> str:
        # XXX: we use the "real" name, aka
        # #multi_1 instead of #multiplayer
        # #spect_1 instead of #spectator.
        return f'[{", ".join(c._name for c in self)}]'

    def get_by_name(self, name: str) -> Channel | None:
        """Get a channel from the list by (real) `name`."""
        for channel in self:
            if channel._name == name:
                return channel

        return None

    def append(self, channel: Channel) -> None:
        """Append `channel` to the list."""
        super().append(channel)

        if app.settings.DEBUG:
            log(f"{channel} added to channels list.")

    def extend(self, channels: Iterable[Channel]) -> None:
        """Extend the list with `channels`."""
        # materialize first: if a generator is passed, super().extend()
        # would exhaust it and the debug log below would then print an
        # exhausted-generator repr instead of the channels added.
        channels = list(channels)
        super().extend(channels)

        if app.settings.DEBUG:
            log(f"{channels} added to channels list.")

    def remove(self, channel: Channel) -> None:
        """Remove `channel` from the list."""
        super().remove(channel)

        if app.settings.DEBUG:
            log(f"{channel} removed from channels list.")

    async def prepare(self) -> None:
        """Fetch data from sql & return; preparing to run the server."""
        log("Fetching channels from sql.", Ansi.LCYAN)
        for row in await channels_repo.fetch_many():
            self.append(
                Channel(
                    name=row["name"],
                    topic=row["topic"],
                    read_priv=Privileges(row["read_priv"]),
                    write_priv=Privileges(row["write_priv"]),
                    auto_join=row["auto_join"] == 1,
                ),
            )
class Matches(list[Match | None]):
    """The currently active multiplayer matches on the server."""

    def __init__(self) -> None:
        MAX_MATCHES = 64  # TODO: refactor this out of existence
        super().__init__([None] * MAX_MATCHES)

    def __iter__(self) -> Iterator[Match | None]:
        return super().__iter__()

    def __repr__(self) -> str:
        active_names = [match.name for match in self if match]
        return f'[{", ".join(active_names)}]'

    def get_free(self) -> int | None:
        """Return the first free match id from `self`."""
        return next(
            (idx for idx, match in enumerate(self) if match is None),
            None,
        )

    def remove(self, match: Match | None) -> None:
        """Remove `match` from the list (by identity), freeing its slot."""
        for idx in range(len(self)):
            if self[idx] is match:
                self[idx] = None
                break

        if app.settings.DEBUG:
            log(f"{match} removed from matches list.")
class Players(list[Player]):
    """The currently active players on the server."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

    def __iter__(self) -> Iterator[Player]:
        return super().__iter__()

    def __contains__(self, player: object) -> bool:
        # allow us to either pass in the player
        # obj, or the player name as a string.
        if isinstance(player, str):
            return player in (player.name for player in self)
        else:
            return super().__contains__(player)

    def __repr__(self) -> str:
        return f'[{", ".join(map(repr, self))}]'

    @property
    def ids(self) -> set[int]:
        """Return a set of the current ids in the list."""
        return {p.id for p in self}

    @property
    def staff(self) -> set[Player]:
        """Return a set of the current staff online."""
        return {p for p in self if p.priv & Privileges.STAFF}

    @property
    def restricted(self) -> set[Player]:
        """Return a set of the current restricted players."""
        return {p for p in self if not p.priv & Privileges.UNRESTRICTED}

    @property
    def unrestricted(self) -> set[Player]:
        """Return a set of the current unrestricted players."""
        return {p for p in self if p.priv & Privileges.UNRESTRICTED}

    def enqueue(self, data: bytes, immune: Sequence[Player] = []) -> None:
        """Enqueue `data` to all players, except for those in `immune`."""
        # NOTE(review): mutable default `[]` is never mutated here so it's
        # benign, but an immutable tuple default would be safer.
        for player in self:
            if player not in immune:
                player.enqueue(data)

    def get(
        self,
        token: str | None = None,
        id: int | None = None,
        name: str | None = None,
    ) -> Player | None:
        """Get a player by token, id, or name from cache."""
        # only the first provided criterion (in token/id/name order)
        # is considered for each player.
        for player in self:
            if token is not None:
                if player.token == token:
                    return player
            elif id is not None:
                if player.id == id:
                    return player
            elif name is not None:
                if player.safe_name == make_safe_name(name):
                    return player

        return None

    async def get_sql(
        self,
        id: int | None = None,
        name: str | None = None,
    ) -> Player | None:
        """Get a player by token, id, or name from sql."""
        # try to get from sql.
        player = await users_repo.fetch_one(
            id=id,
            name=name,
            fetch_all_fields=True,
        )
        if player is None:
            return None

        # clan_id == 0 is treated as "not in a clan"
        clan_id: int | None = None
        clan_priv: ClanPrivileges | None = None
        if player["clan_id"] != 0:
            clan_id = player["clan_id"]
            clan_priv = ClanPrivileges(player["clan_priv"])

        return Player(
            id=player["id"],
            name=player["name"],
            priv=Privileges(player["priv"]),
            pw_bcrypt=player["pw_bcrypt"].encode(),
            token=Player.generate_token(),
            clan_id=clan_id,
            clan_priv=clan_priv,
            # NOTE(review): lat/long are zeroed for sql-loaded players;
            # presumably real geolocation only happens at login — confirm.
            geoloc={
                "latitude": 0.0,
                "longitude": 0.0,
                "country": {
                    "acronym": player["country"],
                    "numeric": app.state.services.country_codes[player["country"]],
                },
            },
            silence_end=player["silence_end"],
            donor_end=player["donor_end"],
            api_key=player["api_key"],
        )

    async def from_cache_or_sql(
        self,
        id: int | None = None,
        name: str | None = None,
    ) -> Player | None:
        """Try to get player from cache, or sql as fallback."""
        player = self.get(id=id, name=name)
        if player is not None:
            return player
        player = await self.get_sql(id=id, name=name)
        if player is not None:
            return player

        return None

    async def from_login(
        self,
        name: str,
        pw_md5: str,
        sql: bool = False,
    ) -> Player | None:
        """Return a player with a given name & pw_md5, from cache or sql."""
        player = self.get(name=name)
        if not player:
            if not sql:
                # skip sql lookup unless explicitly requested
                return None

            player = await self.get_sql(name=name)
            if not player:
                return None

        assert player.pw_bcrypt is not None

        # NOTE(review): this assumes the bcrypt cache already holds an
        # entry for this player's hash; a plain dict would raise KeyError
        # on a miss here — confirm the cache type/population elsewhere.
        if app.state.cache.bcrypt[player.pw_bcrypt] == pw_md5.encode():
            return player

        return None

    def append(self, player: Player) -> None:
        """Append `p` to the list."""
        if player in self:
            if app.settings.DEBUG:
                log(f"{player} double-added to global player list?")
            return

        super().append(player)

    def remove(self, player: Player) -> None:
        """Remove `p` from the list."""
        if player not in self:
            if app.settings.DEBUG:
                log(f"{player} removed from player list when not online?")
            return

        super().remove(player)
async def initialize_ram_caches() -> None:
    """Setup & cache the global collections before listening for connections."""
    # fetch channels, clans and pools from db
    await app.state.sessions.channels.prepare()

    # the bot account is expected to exist with user id 1
    bot = await users_repo.fetch_one(id=1)
    if bot is None:
        raise RuntimeError("Bot account not found in database.")

    # create bot & add it to online players
    app.state.sessions.bot = Player(
        id=1,
        name=bot["name"],
        priv=Privileges.UNRESTRICTED,
        pw_bcrypt=None,  # the bot has no password; it can never log in directly
        token=Player.generate_token(),
        login_time=float(0x7FFFFFFF),  # (never auto-dc)
        is_bot_client=True,
    )
    app.state.sessions.players.append(app.state.sessions.bot)

    # static api keys
    # maps each user's api key -> their user id for fast auth lookups
    app.state.sessions.api_keys = {
        row["api_key"]: row["id"]
        for row in await app.state.services.database.fetch_all(
            "SELECT id, api_key FROM users WHERE api_key IS NOT NULL",
        )
    }

552
app/objects/match.py Normal file
View File

@@ -0,0 +1,552 @@
from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Sequence
from datetime import datetime as datetime
from datetime import timedelta as timedelta
from enum import IntEnum
from enum import unique
from typing import TYPE_CHECKING
from typing import TypedDict
import app.packets
import app.settings
import app.state
from app.constants import regexes
from app.constants.gamemodes import GameMode
from app.constants.mods import Mods
from app.objects.beatmap import Beatmap
from app.repositories.tourney_pools import TourneyPool
from app.utils import escape_enum
from app.utils import pymysql_encode
if TYPE_CHECKING:
from asyncio import TimerHandle
from app.objects.channel import Channel
from app.objects.player import Player
# maximum allowed length of a multiplayer match name, in characters
# NOTE(review): enforcement happens elsewhere (presumably on match
# creation/rename) — confirm against callers.
MAX_MATCH_NAME_LENGTH = 50
@unique
@pymysql_encode(escape_enum)
class SlotStatus(IntEnum):
    """Bitflag-style status values for a multiplayer match slot."""

    open = 1
    locked = 2
    not_ready = 4
    ready = 8
    no_map = 16
    playing = 32
    complete = 64
    quit = 128

    # has_player = not_ready | ready | no_map | playing | complete
@unique
@pymysql_encode(escape_enum)
class MatchTeams(IntEnum):
    """The team a player may be assigned to in a multiplayer match."""

    neutral = 0
    blue = 1
    red = 2
"""
# implemented by osu! and send between client/server,
# quite frequently even, but seems useless??
@unique
@pymysql_encode(escape_enum)
class MatchTypes(IntEnum):
standard = 0
powerplay = 1 # literally no idea what this is for
"""
@unique
@pymysql_encode(escape_enum)
class MatchWinConditions(IntEnum):
    """The scoring metric used to decide a multiplayer match's winner."""

    score = 0
    accuracy = 1
    combo = 2
    scorev2 = 3
@unique
@pymysql_encode(escape_enum)
class MatchTeamTypes(IntEnum):
    """The team configuration of a multiplayer match."""

    head_to_head = 0
    tag_coop = 1
    team_vs = 2
    tag_team_vs = 3
class Slot:
    """An individual player slot in an osu! multiplayer match."""

    def __init__(self) -> None:
        # declare the occupant attribute for type-checkers, then
        # delegate to reset() for the default (open & empty) state.
        self.player: Player | None = None
        self.reset()

    def empty(self) -> bool:
        """Whether no player occupies this slot."""
        return self.player is None

    def copy_from(self, other: Slot) -> None:
        """Copy the occupancy-related state of `other` into this slot."""
        self.player, self.status = other.player, other.status
        self.team, self.mods = other.team, other.mods

    def reset(self, new_status: SlotStatus = SlotStatus.open) -> None:
        """Clear the slot back to an unoccupied state with `new_status`."""
        self.player = None
        self.status = new_status
        self.team = MatchTeams.neutral
        self.mods = Mods.NOMOD
        self.loaded = False
        self.skipped = False
class StartingTimers(TypedDict):
    """Timer state tracked while a match-start countdown is in progress."""

    start: TimerHandle  # the timer that fires the actual match start
    alerts: list[TimerHandle]  # timers for the chat countdown alerts
    time: float  # when the match will start — presumably an event-loop
    # timestamp; TODO confirm against where it's set
class Match:
"""\
An osu! multiplayer match.
Possibly confusing attributes
-----------
_refs: set[`Player`]
A set of players who have access to mp commands in the match.
These can be used with the !mp <addref/rmref/listref> commands.
slots: list[`Slot`]
A list of 16 `Slot` objects representing the match's slots.
starting: dict[str, `TimerHandle`] | None
Used when the match is started with !mp start <seconds>.
It stores both the starting timer, and the chat alert timers.
seed: `int`
The seed used for osu!mania's random mod.
use_pp_scoring: `bool`
Whether pp should be used as a win condition override during scrims.
"""
def __init__(
self,
id: int,
name: str,
password: str,
has_public_history: bool,
map_name: str,
map_id: int,
map_md5: str,
host_id: int,
mode: GameMode,
mods: Mods,
win_condition: MatchWinConditions,
team_type: MatchTeamTypes,
freemods: bool,
seed: int,
chat_channel: Channel,
) -> None:
self.id = id
self.name = name
self.passwd = password
self.has_public_history = has_public_history
self.host_id = host_id
self._refs: set[Player] = set()
self.map_id = map_id
self.map_md5 = map_md5
self.map_name = map_name
self.prev_map_id = 0 # previously chosen map
self.mods = mods
self.mode = mode
self.freemods = freemods
self.chat = chat_channel
self.slots = [Slot() for _ in range(16)]
# self.type = MatchTypes.standard
self.team_type = team_type
self.win_condition = win_condition
self.in_progress = False
self.starting: StartingTimers | None = None
self.seed = seed # used for mania random mod
self.tourney_pool: TourneyPool | None = None
# scrimmage stuff
self.is_scrimming = False
self.match_points: dict[MatchTeams | Player, int] = defaultdict(int)
self.bans: set[tuple[Mods, int]] = set()
self.winners: list[Player | MatchTeams | None] = [] # none for tie
self.winning_pts = 0
self.use_pp_scoring = False # only for scrims
self.tourney_clients: set[int] = set() # player ids
@property
def host(self) -> Player:
player = app.state.sessions.players.get(id=self.host_id)
if player is None:
raise ValueError(
f"Host with id {self.host_id} not found for match {self!r}",
)
return player
@property
def url(self) -> str:
"""The match's invitation url."""
return f"osump://{self.id}/{self.passwd}"
@property
def map_url(self) -> str:
"""The osu! beatmap url for `self`'s map."""
return f"https://osu.{app.settings.DOMAIN}/b/{self.map_id}"
@property
def embed(self) -> str:
"""An osu! chat embed for `self`."""
return f"[{self.url} {self.name}]"
@property
def map_embed(self) -> str:
"""An osu! chat embed for `self`'s map."""
return f"[{self.map_url} {self.map_name}]"
@property
def refs(self) -> set[Player]:
"""Return all players with referee permissions."""
refs = self._refs
if self.host is not None:
refs.add(self.host)
return refs
def __repr__(self) -> str:
return f"<{self.name} ({self.id})>"
def get_slot(self, player: Player) -> Slot | None:
"""Return the slot containing a given player."""
for s in self.slots:
if player is s.player:
return s
return None
def get_slot_id(self, player: Player) -> int | None:
"""Return the slot index containing a given player."""
for idx, s in enumerate(self.slots):
if player is s.player:
return idx
return None
def get_free(self) -> int | None:
"""Return the first unoccupied slot in multi, if any."""
for idx, s in enumerate(self.slots):
if s.status == SlotStatus.open:
return idx
return None
def get_host_slot(self) -> Slot | None:
"""Return the slot containing the host."""
for s in self.slots:
if s.player is not None and s.player is self.host:
return s
return None
def copy(self, m: Match) -> None:
"""Fully copy the data of another match obj."""
self.map_id = m.map_id
self.map_md5 = m.map_md5
self.map_name = m.map_name
self.freemods = m.freemods
self.mode = m.mode
self.team_type = m.team_type
self.win_condition = m.win_condition
self.mods = m.mods
self.name = m.name
def enqueue(
self,
data: bytes,
lobby: bool = True,
immune: Sequence[int] = [],
) -> None:
"""Add data to be sent to all clients in the match."""
self.chat.enqueue(data, immune)
lchan = app.state.sessions.channels.get_by_name("#lobby")
if lobby and lchan and lchan.players:
lchan.enqueue(data)
def enqueue_state(self, lobby: bool = True) -> None:
"""Enqueue `self`'s state to players in the match & lobby."""
# TODO: hmm this is pretty bad, writes twice
# send password only to users currently in the match.
self.chat.enqueue(app.packets.update_match(self, send_pw=True))
lchan = app.state.sessions.channels.get_by_name("#lobby")
if lobby and lchan and lchan.players:
lchan.enqueue(app.packets.update_match(self, send_pw=False))
def unready_players(self, expected: SlotStatus = SlotStatus.ready) -> None:
"""Unready any players in the `expected` state."""
for s in self.slots:
if s.status == expected:
s.status = SlotStatus.not_ready
def reset_players_loaded_status(self) -> None:
"""Reset all players' loaded status."""
for s in self.slots:
s.loaded = False
s.skipped = False
def start(self) -> None:
"""Start the match for all ready players with the map."""
no_map: list[int] = []
for s in self.slots:
# start each player who has the map.
if s.player is not None:
if s.status != SlotStatus.no_map:
s.status = SlotStatus.playing
else:
no_map.append(s.player.id)
self.in_progress = True
self.enqueue(app.packets.match_start(self), immune=no_map, lobby=False)
self.enqueue_state()
def reset_scrim(self) -> None:
"""Reset the current scrim's winning points & bans."""
self.match_points.clear()
self.winners.clear()
self.bans.clear()
    async def await_submissions(
        self,
        was_playing: Sequence[Slot],
    ) -> tuple[dict[MatchTeams | Player, int], Sequence[Player]]:
        """Await score submissions from all players in completed state.

        Polls each player's most recent score every 0.5s for up to 10s
        total (shared across all players, not per player). Returns the
        accumulated scores keyed by player (FFA) or team, plus a list
        of players who never submitted in time.
        """
        scores: dict[MatchTeams | Player, int] = defaultdict(int)
        didnt_submit: list[Player] = []
        time_waited = 0.0  # allow up to 10s (total, not per player)

        ffa = self.team_type in (MatchTeamTypes.head_to_head, MatchTeamTypes.tag_coop)

        if self.use_pp_scoring:
            win_cond = "pp"
        else:
            # attribute name on the recent Score object, indexed by
            # win condition; index 3 (scorev2) also reads "score".
            win_cond = ("score", "acc", "max_combo", "score")[self.win_condition]

        bmap = await Beatmap.from_md5(self.map_md5)

        if not bmap:
            # map isn't submitted
            return {}, ()

        for s in was_playing:
            # continue trying to fetch each player's
            # scores until they've all been submitted.
            while True:
                assert s.player is not None
                rc_score = s.player.recent_score
                # only accept a score recent enough to have been set
                # during this play (map length + time waited so far).
                max_age = datetime.now() - timedelta(
                    seconds=bmap.total_length + time_waited + 0.5,
                )

                if (
                    rc_score
                    and rc_score.bmap
                    and rc_score.bmap.md5 == self.map_md5
                    and rc_score.server_time > max_age
                ):
                    # score found, add to our scores dict if != 0.
                    score: int = getattr(rc_score, win_cond)
                    if score:
                        key: MatchTeams | Player = s.player if ffa else s.team
                        scores[key] += score

                    break

                # wait 0.5s and try again
                await asyncio.sleep(0.5)
                time_waited += 0.5

                if time_waited > 10:
                    # inform the match this user didn't
                    # submit a score in time, and skip them.
                    didnt_submit.append(s.player)
                    break

        # all scores retrieved, update the match.
        return scores, didnt_submit
    async def update_matchpoints(self, was_playing: Sequence[Slot]) -> None:
        """\
        Determine the winner from `scores`, increment & inform players.

        This automatically works with the match settings (such as
        win condition, teams, & co-op) to determine the appropriate
        winner, and will use any team names included in the match name,
        along with the match name (fmt: OWC2020: (Team1) vs. (Team2)).

        For the examples, we'll use accuracy as a win condition.

        Teams, match title: `OWC2015: (United States) vs. (China)`.
          United States takes the point! (293.32% vs 292.12%)
          Total Score: United States | 7 - 2 | China
          United States takes the match, finishing OWC2015 with a score of 7 - 2!

        FFA, the top <=3 players will be listed for the total score.
          Justice takes the point! (94.32% [Match avg. 91.22%])
          Total Score: Justice - 3 | cmyui - 2 | FrostiDrinks - 2
          Justice takes the match, finishing with a score of 4 - 2!
        """
        scores, didnt_submit = await self.await_submissions(was_playing)

        # call out any players who timed out on submission.
        for player in didnt_submit:
            self.chat.send_bot(f"{player} didn't submit a score (timeout: 10s).")

        if not scores:
            self.chat.send_bot("Scores could not be calculated.")
            return None

        ffa = self.team_type in (
            MatchTeamTypes.head_to_head,
            MatchTeamTypes.tag_coop,
        )

        # all scores are equal, it was a tie.
        if len(scores) != 1 and len(set(scores.values())) == 1:
            self.winners.append(None)
            self.chat.send_bot("The point has ended in a tie!")
            return None

        # Find the winner & increment their matchpoints.
        winner: Player | MatchTeams = max(scores, key=lambda k: scores[k])
        self.winners.append(winner)
        self.match_points[winner] += 1

        msg: list[str] = []

        def add_suffix(score: int | float) -> str | int | float:
            # format a raw score value according to the win condition.
            if self.use_pp_scoring:
                return f"{score:.2f}pp"
            elif self.win_condition == MatchWinConditions.accuracy:
                return f"{score:.2f}%"
            elif self.win_condition == MatchWinConditions.combo:
                return f"{score}x"
            else:
                return str(score)

        if ffa:
            from app.objects.player import Player

            assert isinstance(winner, Player)

            msg.append(
                f"{winner.name} takes the point! ({add_suffix(scores[winner])} "
                f"[Match avg. {add_suffix(sum(scores.values()) / len(scores))}])",
            )

            wmp = self.match_points[winner]

            # check if match point #1 has enough points to win.
            if self.winning_pts and wmp == self.winning_pts:
                # we have a champion, announce & reset our match.
                self.is_scrimming = False
                self.reset_scrim()
                # NOTE(review): reset_scrim() already clears bans,
                # so this second clear is redundant.
                self.bans.clear()

                m = f"{winner.name} takes the match! Congratulations!"
            else:
                # no winner, just announce the match points so far.
                # for ffa, we'll only announce the top <=3 players.
                # NOTE(review): sorted() here is ascending (worst first)
                # and the list is never truncated to 3 entries despite
                # the comment above -- verify whether this is intended.
                m_points = sorted(self.match_points.items(), key=lambda x: x[1])
                m = f"Total Score: {' | '.join([f'{k.name} - {v}' for k, v in m_points])}"

            msg.append(m)
            del m

        else:  # teams
            assert isinstance(winner, MatchTeams)

            # extract team names from the match title if it follows
            # the tourney naming convention, else fall back to Blue/Red.
            r_match = regexes.TOURNEY_MATCHNAME.match(self.name)
            if r_match:
                match_name = r_match["name"]
                team_names = {
                    MatchTeams.blue: r_match["T1"],
                    MatchTeams.red: r_match["T2"],
                }
            else:
                match_name = self.name
                team_names = {MatchTeams.blue: "Blue", MatchTeams.red: "Red"}

            # teams are binary, so we have a loser.
            if winner is MatchTeams.blue:
                loser = MatchTeams.red
            else:
                loser = MatchTeams.blue

            # from match name if available, else blue/red.
            wname = team_names[winner]
            lname = team_names[loser]

            # scores from the recent play
            # (according to win condition)
            ws = add_suffix(scores[winner])
            ls = add_suffix(scores[loser])

            # total win/loss score in the match.
            wmp = self.match_points[winner]
            lmp = self.match_points[loser]

            # announce the score for the most recent play.
            msg.append(f"{wname} takes the point! ({ws} vs. {ls})")

            # check if the winner has enough match points to win the match.
            if self.winning_pts and wmp == self.winning_pts:
                # we have a champion, announce & reset our match.
                self.is_scrimming = False
                self.reset_scrim()

                msg.append(
                    f"{wname} takes the match, finishing {match_name} "
                    f"with a score of {wmp} - {lmp}! Congratulations!",
                )
            else:
                # no winner, just announce the match points so far.
                msg.append(f"Total Score: {wname} | {wmp} - {lmp} | {lname}")

        if didnt_submit:
            self.chat.send_bot(
                "If you'd like to perform a rematch, "
                "please use the `!mp rematch` command.",
            )

        for line in msg:
            self.chat.send_bot(line)

8
app/objects/models.py Normal file
View File

@@ -0,0 +1,8 @@
from __future__ import annotations
from pydantic import BaseModel
class OsuBeatmapRequestForm(BaseModel):
    # Request body for the osu! client's beatmap info request.
    # NOTE(review): field names are capitalized -- presumably to match
    # the client's payload keys exactly; do not rename.
    Filenames: list[str]
    Ids: list[int]

1017
app/objects/player.py Normal file

File diff suppressed because it is too large Load Diff

453
app/objects/score.py Normal file
View File

@@ -0,0 +1,453 @@
from __future__ import annotations
import functools
import hashlib
from datetime import datetime
from enum import IntEnum
from enum import unique
from pathlib import Path
from typing import TYPE_CHECKING
import app.state
import app.usecases.performance
import app.utils
from app.constants.clientflags import ClientFlags
from app.constants.gamemodes import GameMode
from app.constants.mods import Mods
from app.objects.beatmap import Beatmap
from app.repositories import scores as scores_repo
from app.usecases.performance import ScoreParams
from app.utils import escape_enum
from app.utils import pymysql_encode
if TYPE_CHECKING:
from app.objects.player import Player
BEATMAPS_PATH = Path.cwd() / ".data/osu"
@unique
class Grade(IntEnum):
    """Letter grade of a score.

    NOTE: these are ordered opposite to osu!'s internal order so that
    the <> comparison operators rank grades naturally.
    """

    N = 0
    F = 1
    D = 2
    C = 3
    B = 4
    A = 5
    S = 6  # S
    SH = 7  # HD S
    X = 8  # SS
    XH = 9  # HD SS

    @classmethod
    @functools.cache
    def from_str(cls, s: str) -> Grade:
        """Parse a grade from its string form (case-insensitive)."""
        lookup = {member.name.lower(): member for member in cls}
        return lookup[s.lower()]

    def __format__(self, format_spec: str) -> str:
        if format_spec == "stats_column":
            # e.g. Grade.XH -> "xh_count" (per-grade stats column name)
            return f"{self.name.lower()}_count"
        else:
            raise ValueError(f"Invalid format specifier {format_spec}")
@unique
@pymysql_encode(escape_enum)
class SubmissionStatus(IntEnum):
    """Submission state of a score (failed / submitted / personal best)."""

    # TODO: make a system more like bancho's?
    FAILED = 0
    SUBMITTED = 1
    BEST = 2

    def __repr__(self) -> str:
        # "Failed" / "Submitted" / "Best" -- member name, title-cased.
        return self.name.title()
class Score:
    """\
    Server side representation of an osu! score; any gamemode.

    Possibly confusing attributes
    -----------
    bmap: `Beatmap | None`
        A beatmap obj representing the osu map.
    player: `Player | None`
        A player obj of the player who submitted the score.
    grade: `Grade`
        The letter grade in the score.
    rank: `int`
        The leaderboard placement of the score.
    perfect: `bool`
        Whether the score is a full-combo.
    time_elapsed: `int`
        The total elapsed time of the play (in milliseconds).
    client_flags: `int`
        osu!'s old anticheat flags.
    prev_best: `Score | None`
        The previous best score before this play was submitted.
        NOTE: just because a score has a `prev_best` attribute does
        not mean the score is our best score on the map! the `status`
        value will always be accurate for any score.
    """

    def __init__(self) -> None:
        # TODO: check whether the remaining Optional's should be
        self.id: int | None = None
        self.bmap: Beatmap | None = None
        self.player: Player | None = None

        self.mode: GameMode
        self.mods: Mods

        self.pp: float
        self.sr: float
        self.score: int
        self.max_combo: int
        self.acc: float

        # TODO: perhaps abstract these differently
        # since they're mode dependant? feels weird..
        self.n300: int
        self.n100: int  # n150 for taiko
        self.n50: int
        self.nmiss: int
        self.ngeki: int
        self.nkatu: int

        self.grade: Grade

        self.passed: bool
        self.perfect: bool
        self.status: SubmissionStatus

        self.client_time: datetime
        self.server_time: datetime
        self.time_elapsed: int

        self.client_flags: ClientFlags
        self.client_checksum: str

        self.rank: int | None = None
        self.prev_best: Score | None = None

    def __repr__(self) -> str:
        # TODO: i really need to clean up my reprs
        try:
            assert self.bmap is not None
            return (
                f"<{self.acc:.2f}% {self.max_combo}x {self.nmiss}M "
                f"#{self.rank} on {self.bmap.full_name} for {self.pp:,.2f}pp>"
            )
        except:
            # NOTE(review): bare `except` swallows *all* exceptions
            # (incl. KeyboardInterrupt); `except Exception` would be safer.
            return super().__repr__()

    """Classmethods to fetch a score object from various data types."""

    @classmethod
    async def from_sql(cls, score_id: int) -> Score | None:
        """Create a score object from sql using its scoreid."""
        rec = await scores_repo.fetch_one(score_id)
        if rec is None:
            return None

        s = cls()

        s.id = rec["id"]
        s.bmap = await Beatmap.from_md5(rec["map_md5"])
        s.player = await app.state.sessions.players.from_cache_or_sql(id=rec["userid"])

        s.sr = 0.0  # TODO

        s.pp = rec["pp"]
        s.score = rec["score"]
        s.max_combo = rec["max_combo"]
        s.mods = Mods(rec["mods"])
        s.acc = rec["acc"]
        s.n300 = rec["n300"]
        s.n100 = rec["n100"]
        s.n50 = rec["n50"]
        s.nmiss = rec["nmiss"]
        s.ngeki = rec["ngeki"]
        s.nkatu = rec["nkatu"]
        s.grade = Grade.from_str(rec["grade"])
        s.perfect = rec["perfect"] == 1
        s.status = SubmissionStatus(rec["status"])
        s.passed = s.status != SubmissionStatus.FAILED
        s.mode = GameMode(rec["mode"])
        s.server_time = rec["play_time"]
        s.time_elapsed = rec["time_elapsed"]
        s.client_flags = ClientFlags(rec["client_flags"])
        s.client_checksum = rec["online_checksum"]

        # the leaderboard placement can only be computed
        # when the beatmap is known to the server.
        if s.bmap:
            s.rank = await s.calculate_placement()

        return s

    @classmethod
    def from_submission(cls, data: list[str]) -> Score:
        """Create a score object from an osu! submission string."""
        s = cls()

        """ parse the following format
        # 0  online_checksum
        # 1  n300
        # 2  n100
        # 3  n50
        # 4  ngeki
        # 5  nkatu
        # 6  nmiss
        # 7  score
        # 8  max_combo
        # 9  perfect
        # 10 grade
        # 11 mods
        # 12 passed
        # 13 gamemode
        # 14 play_time # yyMMddHHmmss
        # 15 osu_version + (" " * client_flags)
        """

        s.client_checksum = data[0]
        s.n300 = int(data[1])
        s.n100 = int(data[2])
        s.n50 = int(data[3])
        s.ngeki = int(data[4])
        s.nkatu = int(data[5])
        s.nmiss = int(data[6])
        s.score = int(data[7])
        s.max_combo = int(data[8])
        s.perfect = data[9] == "True"
        s.grade = Grade.from_str(data[10])
        s.mods = Mods(int(data[11]))
        s.passed = data[12] == "True"
        s.mode = GameMode.from_params(int(data[13]), s.mods)
        s.client_time = datetime.strptime(data[14], "%y%m%d%H%M%S")
        # client flags are encoded as trailing spaces on the version
        # string; bit 4 is masked off (& ~4) before storing.
        s.client_flags = ClientFlags(data[15].count(" ") & ~4)

        s.server_time = datetime.now()

        return s

    def compute_online_checksum(
        self,
        osu_version: str,
        osu_client_hash: str,
        storyboard_checksum: str,
    ) -> str:
        """Validate the online checksum of the score.

        Reproduces the osu! client's md5 over a fixed salted template of
        the score's fields -- the template string must match the client
        byte-for-byte; do not modify it.
        """
        assert self.player is not None
        assert self.bmap is not None

        return hashlib.md5(
            "chickenmcnuggets{0}o15{1}{2}smustard{3}{4}uu{5}{6}{7}{8}{9}{10}{11}Q{12}{13}{15}{14:%y%m%d%H%M%S}{16}{17}".format(
                self.n100 + self.n300,
                self.n50,
                self.ngeki,
                self.nkatu,
                self.nmiss,
                self.bmap.md5,
                self.max_combo,
                self.perfect,
                self.player.name,
                self.score,
                self.grade.name,
                int(self.mods),
                self.passed,
                self.mode.as_vanilla,
                self.client_time,
                osu_version,  # 20210520
                osu_client_hash,
                storyboard_checksum,
                # yyMMddHHmmss
            ).encode(),
        ).hexdigest()

    """Methods to calculate internal data for a score."""

    async def calculate_placement(self) -> int:
        """Calculate the score's 1-based leaderboard placement."""
        assert self.bmap is not None

        # modes at/above RELAX_OSU rank by pp, others by score.
        if self.mode >= GameMode.RELAX_OSU:
            scoring_metric = "pp"
            score = self.pp
        else:
            scoring_metric = "score"
            score = self.score

        # count best (status = 2) scores by unranked-visible (priv & 1)
        # users that beat ours on this map & mode.
        num_better_scores: int | None = await app.state.services.database.fetch_val(
            "SELECT COUNT(*) AS c FROM scores s "
            "INNER JOIN users u ON u.id = s.userid "
            "WHERE s.map_md5 = :map_md5 AND s.mode = :mode "
            "AND s.status = 2 AND u.priv & 1 "
            f"AND s.{scoring_metric} > :score",
            {
                "map_md5": self.bmap.md5,
                "mode": self.mode,
                "score": score,
            },
            column=0,  # COUNT(*)
        )

        assert num_better_scores is not None
        return num_better_scores + 1

    def calculate_performance(self, beatmap_id: int) -> tuple[float, float]:
        """Calculate PP and star rating for our score.

        Reads the beatmap's .osu file from the local BEATMAPS_PATH cache.
        Returns (pp, stars).
        """
        mode_vn = self.mode.as_vanilla

        score_args = ScoreParams(
            mode=mode_vn,
            mods=int(self.mods),
            combo=self.max_combo,
            ngeki=self.ngeki,
            n300=self.n300,
            nkatu=self.nkatu,
            n100=self.n100,
            n50=self.n50,
            nmiss=self.nmiss,
        )

        result = app.usecases.performance.calculate_performances(
            osu_file_path=str(BEATMAPS_PATH / f"{beatmap_id}.osu"),
            scores=[score_args],
        )

        return result[0]["performance"]["pp"], result[0]["difficulty"]["stars"]

    async def calculate_status(self) -> None:
        """Calculate the submission status of a submitted score.

        Compares against the player's current best on the same map/mode
        and sets `self.status` (and demotes the previous best) by pp.
        """
        assert self.player is not None
        assert self.bmap is not None

        recs = await scores_repo.fetch_many(
            user_id=self.player.id,
            map_md5=self.bmap.md5,
            mode=self.mode,
            status=SubmissionStatus.BEST,
        )

        if recs:
            rec = recs[0]

            # we have a score on the map.
            # save it as our previous best score.
            self.prev_best = await Score.from_sql(rec["id"])
            assert self.prev_best is not None

            # if our new score is better, update
            # both of our score's submission statuses.
            # NOTE: this will be updated in sql later on in submission
            if self.pp > rec["pp"]:
                self.status = SubmissionStatus.BEST
                self.prev_best.status = SubmissionStatus.SUBMITTED
            else:
                self.status = SubmissionStatus.SUBMITTED
        else:
            # this is our first score on the map.
            self.status = SubmissionStatus.BEST

    def calculate_accuracy(self) -> float:
        """Calculate the accuracy of our score (0.0 - 100.0)."""
        mode_vn = self.mode.as_vanilla

        if mode_vn == 0:  # osu!
            total = self.n300 + self.n100 + self.n50 + self.nmiss

            if total == 0:
                return 0.0

            return (
                100.0
                * ((self.n300 * 300.0) + (self.n100 * 100.0) + (self.n50 * 50.0))
                / (total * 300.0)
            )

        elif mode_vn == 1:  # osu!taiko
            total = self.n300 + self.n100 + self.nmiss

            if total == 0:
                return 0.0

            # n100 counts as a half-hit in taiko.
            return 100.0 * ((self.n100 * 0.5) + self.n300) / total

        elif mode_vn == 2:  # osu!catch
            total = self.n300 + self.n100 + self.n50 + self.nkatu + self.nmiss

            if total == 0:
                return 0.0

            return 100.0 * (self.n300 + self.n100 + self.n50) / total

        elif mode_vn == 3:  # osu!mania
            total = (
                self.n300 + self.n100 + self.n50 + self.ngeki + self.nkatu + self.nmiss
            )

            if total == 0:
                return 0.0

            if self.mods & Mods.SCOREV2:
                # ScoreV2 weights geki (MAX) at 305 instead of 300.
                return (
                    100.0
                    * (
                        (self.n50 * 50.0)
                        + (self.n100 * 100.0)
                        + (self.nkatu * 200.0)
                        + (self.n300 * 300.0)
                        + (self.ngeki * 305.0)
                    )
                    / (total * 305.0)
                )

            return (
                100.0
                * (
                    (self.n50 * 50.0)
                    + (self.n100 * 100.0)
                    + (self.nkatu * 200.0)
                    + ((self.n300 + self.ngeki) * 300.0)
                )
                / (total * 300.0)
            )
        else:
            raise Exception(f"Invalid vanilla mode {mode_vn}")

    """ Methods for updating a score. """

    async def increment_replay_views(self) -> None:
        # TODO: move replay views to be per-score rather than per-user
        assert self.player is not None

        # TODO: apparently cached stats don't store replay views?
        # need to refactor that to be able to use stats_repo here
        await app.state.services.database.execute(
            f"UPDATE stats "
            "SET replay_views = replay_views + 1 "
            "WHERE id = :user_id AND mode = :mode",
            {"user_id": self.player.id, "mode": self.mode},
        )

1289
app/packets.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from sqlalchemy.orm import DeclarativeMeta
from sqlalchemy.orm import registry
# single registry shared by every table definition in app/repositories.
mapper_registry = registry()


class Base(metaclass=DeclarativeMeta):
    """Declarative base class for all repository table definitions."""

    __abstract__ = True  # no table is created for Base itself

    # wire every subclass up to the shared registry (equivalent to
    # sqlalchemy's declarative_base(), spelled out explicitly).
    registry = mapper_registry
    metadata = mapper_registry.metadata

    __init__ = mapper_registry.constructor

View File

@@ -0,0 +1,173 @@
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
from typing import TypedDict
from typing import cast
import app.state.services
from app._typing import UNSET
from app._typing import _UnsetSentinel
from app.repositories import Base
if TYPE_CHECKING:
from app.objects.score import Score
from sqlalchemy import Column
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
from sqlalchemy import update
class AchievementsTable(Base):
    # ORM mapping for the `achievements` table.
    __tablename__ = "achievements"

    id = Column("id", Integer, primary_key=True, nullable=False, autoincrement=True)
    # presumably the achievement's client-side asset filename -- confirm.
    file = Column("file", String(128), nullable=False)
    name = Column("name", String(128, collation="utf8"), nullable=False)
    desc = Column("desc", String(256, collation="utf8"), nullable=False)
    # python expression source; eval'd by the repository functions below
    # into `lambda score, mode_vn: <cond>`.
    cond = Column("cond", String(64), nullable=False)

    __table_args__ = (
        Index("achievements_desc_uindex", desc, unique=True),
        Index("achievements_file_uindex", file, unique=True),
        Index("achievements_name_uindex", name, unique=True),
    )
# default column set fetched by this repository's read queries.
READ_PARAMS = (
    AchievementsTable.id,
    AchievementsTable.file,
    AchievementsTable.name,
    AchievementsTable.desc,
    AchievementsTable.cond,
)
class Achievement(TypedDict):
    # shape of an achievement row as returned by this repository;
    # `cond` is the DB string already eval'd into a predicate callable.
    id: int
    file: str
    name: str
    desc: str
    cond: Callable[[Score, int], bool]
async def create(
    file: str,
    name: str,
    desc: str,
    cond: str,
) -> Achievement:
    """Create a new achievement.

    Args:
        file: asset filename for the achievement.
        name: unique display name.
        desc: unique description.
        cond: python expression source for the unlock condition,
            evaluated with `score` and `mode_vn` in scope.

    Returns the freshly inserted row with `cond` compiled to a callable.
    """
    insert_stmt = insert(AchievementsTable).values(
        file=file,
        name=name,
        desc=desc,
        cond=cond,
    )
    rec_id = await app.state.services.database.execute(insert_stmt)

    select_stmt = select(*READ_PARAMS).where(AchievementsTable.id == rec_id)
    achievement = await app.state.services.database.fetch_one(select_stmt)
    assert achievement is not None

    # SECURITY NOTE(review): eval() executes arbitrary code from the DB;
    # this is only safe if achievement rows are strictly admin-authored.
    achievement["cond"] = eval(f'lambda score, mode_vn: {achievement["cond"]}')
    return cast(Achievement, achievement)
async def fetch_one(
    id: int | None = None,
    name: str | None = None,
) -> Achievement | None:
    """Fetch a single achievement by id and/or name."""
    if id is None and name is None:
        raise ValueError("Must provide at least one parameter.")

    query = select(*READ_PARAMS)
    if id is not None:
        query = query.where(AchievementsTable.id == id)
    if name is not None:
        query = query.where(AchievementsTable.name == name)

    rec = await app.state.services.database.fetch_one(query)
    if rec is None:
        return None

    # compile the stored condition string into a predicate callable.
    rec["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}')
    return cast(Achievement, rec)
async def fetch_count() -> int:
    """Fetch the number of achievements."""
    query = select(func.count().label("count")).select_from(AchievementsTable)
    row = await app.state.services.database.fetch_one(query)
    assert row is not None
    return cast(int, row["count"])
async def fetch_many(
    page: int | None = None,
    page_size: int | None = None,
) -> list[Achievement]:
    """Fetch an (optionally paginated) list of achievements."""
    query = select(*READ_PARAMS)
    if page is not None and page_size is not None:
        query = query.limit(page_size).offset((page - 1) * page_size)

    recs = await app.state.services.database.fetch_all(query)
    # compile each stored condition string into a predicate callable.
    for rec in recs:
        rec["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}')
    return cast(list[Achievement], recs)
async def partial_update(
    id: int,
    file: str | _UnsetSentinel = UNSET,
    name: str | _UnsetSentinel = UNSET,
    desc: str | _UnsetSentinel = UNSET,
    cond: str | _UnsetSentinel = UNSET,
) -> Achievement | None:
    """Update an existing achievement; unset fields are left untouched."""
    changes = {
        col: val
        for col, val in (
            ("file", file),
            ("name", name),
            ("desc", desc),
            ("cond", cond),
        )
        if not isinstance(val, _UnsetSentinel)
    }

    update_stmt = update(AchievementsTable).where(AchievementsTable.id == id)
    if changes:
        update_stmt = update_stmt.values(**changes)
    await app.state.services.database.execute(update_stmt)

    query = select(*READ_PARAMS).where(AchievementsTable.id == id)
    rec = await app.state.services.database.fetch_one(query)
    if rec is None:
        return None

    rec["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}')
    return cast(Achievement, rec)
async def delete_one(
    id: int,
) -> Achievement | None:
    """Delete an existing achievement, returning the deleted row."""
    query = select(*READ_PARAMS).where(AchievementsTable.id == id)
    rec = await app.state.services.database.fetch_one(query)
    if rec is None:
        return None

    await app.state.services.database.execute(
        delete(AchievementsTable).where(AchievementsTable.id == id),
    )

    rec["cond"] = eval(f'lambda score, mode_vn: {rec["cond"]}')
    return cast(Achievement, rec)

View File

@@ -0,0 +1,184 @@
from __future__ import annotations
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.mysql import TINYINT
import app.state.services
from app._typing import UNSET
from app._typing import _UnsetSentinel
from app.repositories import Base
class ChannelsTable(Base):
    # ORM mapping for the `channels` table (chat channels).
    __tablename__ = "channels"

    id = Column("id", Integer, primary_key=True, nullable=False, autoincrement=True)
    name = Column("name", String(32), nullable=False)
    topic = Column("topic", String(256), nullable=False)
    # privilege bitmasks required to read/write the channel.
    read_priv = Column("read_priv", Integer, nullable=False, server_default="1")
    write_priv = Column("write_priv", Integer, nullable=False, server_default="2")
    # whether players are joined to the channel automatically.
    auto_join = Column("auto_join", TINYINT(1), nullable=False, server_default="0")

    __table_args__ = (
        Index("channels_name_uindex", name, unique=True),
        Index("channels_auto_join_index", auto_join),
    )
# default column set fetched by this repository's read queries.
READ_PARAMS = (
    ChannelsTable.id,
    ChannelsTable.name,
    ChannelsTable.topic,
    ChannelsTable.read_priv,
    ChannelsTable.write_priv,
    ChannelsTable.auto_join,
)
class Channel(TypedDict):
    # shape of a channel row as returned by this repository.
    id: int
    name: str
    topic: str
    read_priv: int
    write_priv: int
    auto_join: bool
async def create(
    name: str,
    topic: str,
    read_priv: int,
    write_priv: int,
    auto_join: bool,
) -> Channel:
    """Create a new channel, returning the inserted row."""
    stmt = insert(ChannelsTable).values(
        name=name,
        topic=topic,
        read_priv=read_priv,
        write_priv=write_priv,
        auto_join=auto_join,
    )
    new_id = await app.state.services.database.execute(stmt)

    query = select(*READ_PARAMS).where(ChannelsTable.id == new_id)
    rec = await app.state.services.database.fetch_one(query)
    assert rec is not None
    return cast(Channel, rec)
async def fetch_one(
    id: int | None = None,
    name: str | None = None,
) -> Channel | None:
    """Fetch a single channel by id and/or name."""
    if id is None and name is None:
        raise ValueError("Must provide at least one parameter.")

    query = select(*READ_PARAMS)
    if id is not None:
        query = query.where(ChannelsTable.id == id)
    if name is not None:
        query = query.where(ChannelsTable.name == name)

    rec = await app.state.services.database.fetch_one(query)
    return cast(Channel | None, rec)
async def fetch_count(
    read_priv: int | None = None,
    write_priv: int | None = None,
    auto_join: bool | None = None,
) -> int:
    """Count channels matching the given filters."""
    if read_priv is None and write_priv is None and auto_join is None:
        raise ValueError("Must provide at least one parameter.")

    query = select(func.count().label("count")).select_from(ChannelsTable)
    if read_priv is not None:
        query = query.where(ChannelsTable.read_priv == read_priv)
    if write_priv is not None:
        query = query.where(ChannelsTable.write_priv == write_priv)
    if auto_join is not None:
        query = query.where(ChannelsTable.auto_join == auto_join)

    row = await app.state.services.database.fetch_one(query)
    assert row is not None
    return cast(int, row["count"])
async def fetch_many(
    read_priv: int | None = None,
    write_priv: int | None = None,
    auto_join: bool | None = None,
    page: int | None = None,
    page_size: int | None = None,
) -> list[Channel]:
    """Fetch multiple channels from the database, optionally paginated."""
    query = select(*READ_PARAMS)
    if read_priv is not None:
        query = query.where(ChannelsTable.read_priv == read_priv)
    if write_priv is not None:
        query = query.where(ChannelsTable.write_priv == write_priv)
    if auto_join is not None:
        query = query.where(ChannelsTable.auto_join == auto_join)

    if page is not None and page_size is not None:
        query = query.limit(page_size).offset((page - 1) * page_size)

    recs = await app.state.services.database.fetch_all(query)
    return cast(list[Channel], recs)
async def partial_update(
    name: str,
    topic: str | _UnsetSentinel = UNSET,
    read_priv: int | _UnsetSentinel = UNSET,
    write_priv: int | _UnsetSentinel = UNSET,
    auto_join: bool | _UnsetSentinel = UNSET,
) -> Channel | None:
    """Update a channel (addressed by name); unset fields are untouched."""
    changes = {
        col: val
        for col, val in (
            ("topic", topic),
            ("read_priv", read_priv),
            ("write_priv", write_priv),
            ("auto_join", auto_join),
        )
        if not isinstance(val, _UnsetSentinel)
    }

    update_stmt = update(ChannelsTable).where(ChannelsTable.name == name)
    if changes:
        update_stmt = update_stmt.values(**changes)
    await app.state.services.database.execute(update_stmt)

    query = select(*READ_PARAMS).where(ChannelsTable.name == name)
    rec = await app.state.services.database.fetch_one(query)
    return cast(Channel | None, rec)
async def delete_one(
    name: str,
) -> Channel | None:
    """Delete a channel from the database, returning the deleted row."""
    query = select(*READ_PARAMS).where(ChannelsTable.name == name)
    rec = await app.state.services.database.fetch_one(query)
    if rec is None:
        return None

    await app.state.services.database.execute(
        delete(ChannelsTable).where(ChannelsTable.name == name),
    )
    return cast(Channel | None, rec)

156
app/repositories/clans.py Normal file
View File

@@ -0,0 +1,156 @@
from __future__ import annotations
from datetime import datetime
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
from sqlalchemy import update
import app.state.services
from app._typing import UNSET
from app._typing import _UnsetSentinel
from app.repositories import Base
class ClansTable(Base):
    # ORM mapping for the `clans` table.
    __tablename__ = "clans"

    id = Column("id", Integer, primary_key=True, nullable=False, autoincrement=True)
    name = Column("name", String(16, collation="utf8"), nullable=False)
    tag = Column("tag", String(6, collation="utf8"), nullable=False)
    owner = Column("owner", Integer, nullable=False)  # owning user's id
    created_at = Column("created_at", DateTime, nullable=False)

    __table_args__ = (
        # NOTE(review): named "uindex" but declared unique=False --
        # verify whether duplicate clan names are actually intended.
        Index("clans_name_uindex", name, unique=False),
        Index("clans_owner_uindex", owner, unique=True),
        Index("clans_tag_uindex", tag, unique=True),
    )
# default column set fetched by this repository's read queries.
READ_PARAMS = (
    ClansTable.id,
    ClansTable.name,
    ClansTable.tag,
    ClansTable.owner,
    ClansTable.created_at,
)
class Clan(TypedDict):
    # shape of a clan row as returned by this repository.
    id: int
    name: str
    tag: str
    owner: int
    created_at: datetime
async def create(
    name: str,
    tag: str,
    owner: int,
) -> Clan:
    """Create a new clan in the database (created_at set server-side)."""
    stmt = insert(ClansTable).values(
        name=name,
        tag=tag,
        owner=owner,
        created_at=func.now(),
    )
    new_id = await app.state.services.database.execute(stmt)

    query = select(*READ_PARAMS).where(ClansTable.id == new_id)
    rec = await app.state.services.database.fetch_one(query)
    assert rec is not None
    return cast(Clan, rec)
async def fetch_one(
    id: int | None = None,
    name: str | None = None,
    tag: str | None = None,
    owner: int | None = None,
) -> Clan | None:
    """Fetch a single clan by any combination of id/name/tag/owner."""
    if id is None and name is None and tag is None and owner is None:
        raise ValueError("Must provide at least one parameter.")

    query = select(*READ_PARAMS)
    if id is not None:
        query = query.where(ClansTable.id == id)
    if name is not None:
        query = query.where(ClansTable.name == name)
    if tag is not None:
        query = query.where(ClansTable.tag == tag)
    if owner is not None:
        query = query.where(ClansTable.owner == owner)

    rec = await app.state.services.database.fetch_one(query)
    return cast(Clan | None, rec)
async def fetch_count() -> int:
    """Fetch the number of clans in the database."""
    query = select(func.count().label("count")).select_from(ClansTable)
    row = await app.state.services.database.fetch_one(query)
    assert row is not None
    return cast(int, row["count"])
async def fetch_many(
    page: int | None = None,
    page_size: int | None = None,
) -> list[Clan]:
    """Fetch many clans from the database, optionally paginated."""
    query = select(*READ_PARAMS)
    if page is not None and page_size is not None:
        query = query.limit(page_size).offset((page - 1) * page_size)

    recs = await app.state.services.database.fetch_all(query)
    return cast(list[Clan], recs)
async def partial_update(
    id: int,
    name: str | _UnsetSentinel = UNSET,
    tag: str | _UnsetSentinel = UNSET,
    owner: int | _UnsetSentinel = UNSET,
) -> Clan | None:
    """Update only the provided fields of a clan; returns the updated record."""
    # collect only the fields the caller actually provided
    new_values: dict[str, object] = {}
    if not isinstance(name, _UnsetSentinel):
        new_values["name"] = name
    if not isinstance(tag, _UnsetSentinel):
        new_values["tag"] = tag
    if not isinstance(owner, _UnsetSentinel):
        new_values["owner"] = owner

    update_stmt = update(ClansTable).where(ClansTable.id == id)
    if new_values:
        update_stmt = update_stmt.values(**new_values)
    await app.state.services.database.execute(update_stmt)

    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(ClansTable.id == id),
    )
    return cast(Clan | None, row)
async def delete_one(id: int) -> Clan | None:
    """Delete a clan by id, returning the record that was removed (if any)."""
    existing = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(ClansTable.id == id),
    )
    if existing is None:
        return None
    await app.state.services.database.execute(
        delete(ClansTable).where(ClansTable.id == id),
    )
    return cast(Clan, existing)

View File

@@ -0,0 +1,133 @@
from __future__ import annotations
from datetime import datetime
from typing import TypedDict
from typing import cast
from sqlalchemy import CHAR
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Integer
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.dialects.mysql import Insert as MysqlInsert
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.sql import ColumnElement
from sqlalchemy.types import Boolean
import app.state.services
from app.repositories import Base
from app.repositories.users import UsersTable
class ClientHashesTable(Base):
    """Hardware/installation hashes reported by osu! clients on login.

    The composite primary key means one row per unique (user, hardware)
    combination; `occurrences` counts repeat logins with the same hashes.
    """

    __tablename__ = "client_hashes"

    userid = Column("userid", Integer, nullable=False, primary_key=True)
    osupath = Column("osupath", CHAR(32), nullable=False, primary_key=True)
    adapters = Column("adapters", CHAR(32), nullable=False, primary_key=True)
    uninstall_id = Column("uninstall_id", CHAR(32), nullable=False, primary_key=True)
    disk_serial = Column("disk_serial", CHAR(32), nullable=False, primary_key=True)
    latest_time = Column("latest_time", DateTime, nullable=False)
    occurrences = Column("occurrences", Integer, nullable=False, server_default="0")
# Columns selected for every read of the client_hashes table.
READ_PARAMS = (
    ClientHashesTable.userid,
    ClientHashesTable.osupath,
    ClientHashesTable.adapters,
    ClientHashesTable.uninstall_id,
    ClientHashesTable.disk_serial,
    ClientHashesTable.latest_time,
    ClientHashesTable.occurrences,
)
class ClientHash(TypedDict):
    """A client hash row as read back from the database."""

    userid: int
    osupath: str
    adapters: str
    uninstall_id: str
    disk_serial: str
    latest_time: datetime
    occurrences: int
class ClientHashWithPlayer(ClientHash):
    """A client hash joined with the owning user's name and privileges."""

    name: str
    priv: int
async def create(
    userid: int,
    osupath: str,
    adapters: str,
    uninstall_id: str,
    disk_serial: str,
) -> ClientHash:
    """Upsert a client hash entry, bumping `occurrences` on conflict."""
    upsert_stmt: MysqlInsert = mysql_insert(ClientHashesTable).values(
        userid=userid,
        osupath=osupath,
        adapters=adapters,
        uninstall_id=uninstall_id,
        disk_serial=disk_serial,
        latest_time=func.now(),
        occurrences=1,
    )
    # on a repeat login with identical hashes, refresh the timestamp
    # and increment the occurrence counter instead of inserting
    upsert_stmt = upsert_stmt.on_duplicate_key_update(
        latest_time=func.now(),
        occurrences=ClientHashesTable.occurrences + 1,
    )
    await app.state.services.database.execute(upsert_stmt)

    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS)
        .where(ClientHashesTable.userid == userid)
        .where(ClientHashesTable.osupath == osupath)
        .where(ClientHashesTable.adapters == adapters)
        .where(ClientHashesTable.uninstall_id == uninstall_id)
        .where(ClientHashesTable.disk_serial == disk_serial),
    )
    assert row is not None
    return cast(ClientHash, row)
async def fetch_any_hardware_matches_for_user(
    userid: int,
    running_under_wine: bool,
    adapters: str,
    uninstall_id: str,
    disk_serial: str | None = None,
) -> list[ClientHashWithPlayer]:
    """\
    Fetch a list of matching hardware addresses where any of
    `adapters`, `uninstall_id` or `disk_serial` match other users
    from the database.
    """
    stmt = (
        select(*READ_PARAMS, UsersTable.name, UsersTable.priv)
        .join(UsersTable, ClientHashesTable.userid == UsersTable.id)
        .where(ClientHashesTable.userid != userid)
    )
    if running_under_wine:
        # under wine only the uninstall id is meaningful, so match it alone
        stmt = stmt.where(ClientHashesTable.uninstall_id == uninstall_id)
    else:
        # disk serial is optional within the OR
        hardware_filters: list[ColumnElement[Boolean]] = [
            ClientHashesTable.adapters == adapters,
            ClientHashesTable.uninstall_id == uninstall_id,
        ]
        if disk_serial is not None:
            hardware_filters.append(ClientHashesTable.disk_serial == disk_serial)
        stmt = stmt.where(or_(*hardware_filters))

    rows = await app.state.services.database.fetch_all(stmt)
    return cast(list[ClientHashWithPlayer], rows)

View File

@@ -0,0 +1,125 @@
from __future__ import annotations
from enum import StrEnum
from typing import TypedDict
from typing import cast
from sqlalchemy import CHAR
from sqlalchemy import Column
from sqlalchemy import Enum
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import and_
from sqlalchemy import insert
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.dialects.mysql import FLOAT
import app.state.services
from app.repositories import Base
from app.repositories.users import UsersTable
class TargetType(StrEnum):
    """The kind of object an in-game comment is attached to."""

    REPLAY = "replay"
    BEATMAP = "map"
    SONG = "song"
class CommentsTable(Base):
    """In-game comments users leave on replays, beatmaps and songs."""

    __tablename__ = "comments"

    id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
    # BUG FIX: this column previously had no SQL type (sqlalchemy would give
    # it NullType); it stores score/map/set ids, which are integers.
    target_id = Column("target_id", Integer, nullable=False)
    target_type = Column(Enum(TargetType, name="target_type"), nullable=False)
    userid = Column("userid", Integer, nullable=False)
    time = Column("time", FLOAT(precision=6, scale=3), nullable=False)
    comment = Column("comment", String(80, collation="utf8"), nullable=False)
    colour = Column("colour", CHAR(6), nullable=True)  # hex RGB, no leading '#'
# Columns selected for every read of the comments table.
READ_PARAMS = (
    CommentsTable.id,
    CommentsTable.target_id,
    CommentsTable.target_type,
    CommentsTable.userid,
    CommentsTable.time,
    CommentsTable.comment,
    CommentsTable.colour,
)
class Comment(TypedDict):
    """A comment row as read back from the database."""

    id: int
    target_id: int
    target_type: TargetType
    userid: int
    time: float  # playback offset the comment is anchored to
    comment: str
    colour: str | None  # hex RGB, None for default colour
async def create(
    target_id: int,
    target_type: TargetType,
    userid: int,
    time: float,
    comment: str,
    colour: str | None,
) -> Comment:
    """Insert a new comment and return the created record."""
    new_id = await app.state.services.database.execute(
        insert(CommentsTable).values(
            target_id=target_id,
            target_type=target_type,
            userid=userid,
            time=time,
            comment=comment,
            colour=colour,
        ),
    )
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(CommentsTable.id == new_id),
    )
    assert row is not None
    return cast(Comment, row)
class CommentWithUserPrivileges(Comment):
    """A comment joined with the author's privilege bits."""

    priv: int
async def fetch_all_relevant_to_replay(
    score_id: int | None = None,
    map_set_id: int | None = None,
    map_id: int | None = None,
) -> list[CommentWithUserPrivileges]:
    """\
    Fetch all comments from the database where any of the following match:
    - `score_id`
    - `map_set_id`
    - `map_id`
    """
    select_stmt = (
        # BUG FIX: READ_PARAMS must be unpacked — every other query in these
        # repositories uses `*READ_PARAMS`; passing the tuple itself as a
        # single argument is not a valid column expression for select().
        select(*READ_PARAMS, UsersTable.priv)
        .join(UsersTable, CommentsTable.userid == UsersTable.id)
        .where(
            or_(
                and_(
                    CommentsTable.target_type == TargetType.REPLAY,
                    CommentsTable.target_id == score_id,
                ),
                and_(
                    CommentsTable.target_type == TargetType.SONG,
                    CommentsTable.target_id == map_set_id,
                ),
                and_(
                    CommentsTable.target_type == TargetType.BEATMAP,
                    CommentsTable.target_id == map_id,
                ),
            ),
        )
    )
    comments = await app.state.services.database.fetch_all(select_stmt)
    return cast(list[CommentWithUserPrivileges], comments)

View File

@@ -0,0 +1,75 @@
from __future__ import annotations
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
import app.state.services
from app.repositories import Base
class FavouritesTable(Base):
    """A user's favourited beatmap sets; one row per (user, set) pair."""

    __tablename__ = "favourites"

    userid = Column("userid", Integer, nullable=False, primary_key=True)
    setid = Column("setid", Integer, nullable=False, primary_key=True)
    created_at = Column("created_at", Integer, nullable=False, server_default="0")  # unix timestamp
# Columns selected for every read of the favourites table.
READ_PARAMS = (
    FavouritesTable.userid,
    FavouritesTable.setid,
    FavouritesTable.created_at,
)
class Favourite(TypedDict):
    """A favourite-mapset row as read back from the database."""

    userid: int
    setid: int
    created_at: int  # unix timestamp
async def create(
    userid: int,
    setid: int,
) -> Favourite:
    """Insert a favourite-mapset row and return the created record."""
    await app.state.services.database.execute(
        insert(FavouritesTable).values(
            userid=userid,
            setid=setid,
            created_at=func.unix_timestamp(),
        ),
    )
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS)
        .where(FavouritesTable.userid == userid)
        .where(FavouritesTable.setid == setid),
    )
    assert row is not None
    return cast(Favourite, row)
async def fetch_all(userid: int) -> list[Favourite]:
    """Fetch every favourite belonging to `userid`."""
    rows = await app.state.services.database.fetch_all(
        select(*READ_PARAMS).where(FavouritesTable.userid == userid),
    )
    return cast(list[Favourite], rows)
async def fetch_one(userid: int, setid: int) -> Favourite | None:
    """Return the favourite for (userid, setid), or None if not favourited."""
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS)
        .where(FavouritesTable.userid == userid)
        .where(FavouritesTable.setid == setid),
    )
    return cast(Favourite | None, row)

View File

@@ -0,0 +1,128 @@
from __future__ import annotations
from datetime import date
from datetime import datetime
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import Date
from sqlalchemy import DateTime
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
import app.state.services
from app.repositories import Base
class IngameLoginsTable(Base):
    """Audit trail of in-game (bancho) login events."""

    __tablename__ = "ingame_logins"

    id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
    userid = Column("userid", Integer, nullable=False)
    ip = Column("ip", String(45), nullable=False)  # 45 chars fits IPv6
    osu_ver = Column("osu_ver", Date, nullable=False)  # client version is date-based
    osu_stream = Column("osu_stream", String(11), nullable=False)
    datetime = Column("datetime", DateTime, nullable=False)
# Columns selected for every read of the ingame_logins table.
READ_PARAMS = (
    IngameLoginsTable.id,
    IngameLoginsTable.userid,
    IngameLoginsTable.ip,
    IngameLoginsTable.osu_ver,
    IngameLoginsTable.osu_stream,
    IngameLoginsTable.datetime,
)
class IngameLogin(TypedDict):
    """A login-event row as read back from the database."""

    id: int
    # annotation fix: the column is Integer and create() receives an int
    userid: int
    ip: str
    osu_ver: date
    osu_stream: str
    datetime: datetime
class InGameLoginUpdateFields(TypedDict, total=False):
    """Optional fields for partially updating a login entry."""

    # annotation fix: user ids are integers throughout this module
    userid: int
    ip: str
    osu_ver: date
    osu_stream: str
async def create(
    user_id: int,
    ip: str,
    osu_ver: date,
    osu_stream: str,
) -> IngameLogin:
    """Insert a new login-event row and return the created record."""
    new_id = await app.state.services.database.execute(
        insert(IngameLoginsTable).values(
            userid=user_id,
            ip=ip,
            osu_ver=osu_ver,
            osu_stream=osu_stream,
            datetime=func.now(),
        ),
    )
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(IngameLoginsTable.id == new_id),
    )
    assert row is not None
    return cast(IngameLogin, row)
async def fetch_one(id: int) -> IngameLogin | None:
    """Fetch a single login event by its row id."""
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(IngameLoginsTable.id == id),
    )
    return cast(IngameLogin | None, row)
async def fetch_count(
    user_id: int | None = None,
    ip: str | None = None,
) -> int:
    """Count login events, optionally filtered by user and/or ip."""
    stmt = select(func.count().label("count")).select_from(IngameLoginsTable)
    if user_id is not None:
        stmt = stmt.where(IngameLoginsTable.userid == user_id)
    if ip is not None:
        stmt = stmt.where(IngameLoginsTable.ip == ip)
    row = await app.state.services.database.fetch_one(stmt)
    assert row is not None
    return cast(int, row["count"])
async def fetch_many(
    user_id: int | None = None,
    ip: str | None = None,
    osu_ver: date | None = None,
    osu_stream: str | None = None,
    page: int | None = None,
    page_size: int | None = None,
) -> list[IngameLogin]:
    """Fetch a list of logins from the database.

    All filters are optional; pagination applies only when both `page`
    and `page_size` are given.
    """
    select_stmt = select(*READ_PARAMS)
    if user_id is not None:
        select_stmt = select_stmt.where(IngameLoginsTable.userid == user_id)
    if ip is not None:
        select_stmt = select_stmt.where(IngameLoginsTable.ip == ip)
    if osu_ver is not None:
        select_stmt = select_stmt.where(IngameLoginsTable.osu_ver == osu_ver)
    if osu_stream is not None:
        select_stmt = select_stmt.where(IngameLoginsTable.osu_stream == osu_stream)
    if page is not None and page_size is not None:
        # BUG FIX: the paginated statement was previously built but never
        # assigned back, so pagination silently had no effect.
        select_stmt = select_stmt.limit(page_size).offset((page - 1) * page_size)
    ingame_logins = await app.state.services.database.fetch_all(select_stmt)
    return cast(list[IngameLogin], ingame_logins)

70
app/repositories/logs.py Normal file
View File

@@ -0,0 +1,70 @@
from __future__ import annotations
from datetime import datetime
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
import app.state.services
from app.repositories import Base
class LogTable(Base):
    """Audit log of moderation/admin actions between two users."""

    __tablename__ = "logs"

    id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
    # attribute is `_from` because `from` is a reserved word in Python;
    # the underlying SQL column is still named "from"
    _from = Column("from", Integer, nullable=False)
    to = Column("to", Integer, nullable=False)
    action = Column("action", String(32), nullable=False)
    msg = Column("msg", String(2048, collation="utf8"), nullable=True)
    # NOTE(review): only onupdate is set (no default); create() supplies
    # `time` explicitly on insert — confirm this is intentional.
    time = Column("time", DateTime, nullable=False, onupdate=func.now())
# Columns selected for every read of the logs table; the `_from`
# attribute is re-labelled back to its SQL name "from".
READ_PARAMS = (
    LogTable.id,
    LogTable._from.label("from"),
    LogTable.to,
    LogTable.action,
    LogTable.msg,
    LogTable.time,
)
class Log(TypedDict):
    """An audit-log row as returned by this repository."""

    id: int
    # NOTE(review): READ_PARAMS labels this column "from", so fetched
    # records actually expose the key "from", not "_from" — confirm
    # against callers before relying on this field name.
    _from: int
    to: int
    action: str
    msg: str | None
    time: datetime
async def create(
    _from: int,
    to: int,
    action: str,
    msg: str,
) -> Log:
    """Insert a new audit-log row and return the created record."""
    # values are passed as a dict keyed by SQL column name, since
    # "from" cannot be used as a python keyword argument
    new_id = await app.state.services.database.execute(
        insert(LogTable).values(
            {
                "from": _from,
                "to": to,
                "action": action,
                "msg": msg,
                "time": func.now(),
            },
        ),
    )
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(LogTable.id == new_id),
    )
    assert row is not None
    return cast(Log, row)

113
app/repositories/mail.py Normal file
View File

@@ -0,0 +1,113 @@
from __future__ import annotations
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.mysql import TINYINT
import app.state.services
from app.repositories import Base
class MailTable(Base):
    """Private messages stored for offline delivery between users."""

    __tablename__ = "mail"

    id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
    from_id = Column("from_id", Integer, nullable=False)
    to_id = Column("to_id", Integer, nullable=False)
    msg = Column("msg", String(2048, collation="utf8"), nullable=False)
    time = Column("time", Integer, nullable=True)  # unix timestamp
    read = Column("read", TINYINT(1), nullable=False, server_default="0")
# Columns selected for every read of the mail table.
READ_PARAMS = (
    MailTable.id,
    MailTable.from_id,
    MailTable.to_id,
    MailTable.msg,
    MailTable.time,
    MailTable.read,
)
class Mail(TypedDict):
    """A mail row as read back from the database."""

    id: int
    from_id: int
    to_id: int
    msg: str
    time: int  # unix timestamp
    read: bool
class MailWithUsernames(Mail):
    """A mail row joined with the sender's and recipient's usernames."""

    from_name: str
    to_name: str
async def create(from_id: int, to_id: int, msg: str) -> Mail:
    """Insert a new mail row and return the created record."""
    new_id = await app.state.services.database.execute(
        insert(MailTable).values(
            from_id=from_id,
            to_id=to_id,
            msg=msg,
            time=func.unix_timestamp(),
        ),
    )
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(MailTable.id == new_id),
    )
    assert row is not None
    return cast(Mail, row)
from app.repositories.users import UsersTable
async def fetch_all_mail_to_user(
    user_id: int,
    read: bool | None = None,
) -> list[MailWithUsernames]:
    """Fetch all of mail to a given target from the database."""
    # resolve sender/recipient usernames with correlated subqueries
    from_name = (
        select(UsersTable.name).where(UsersTable.id == MailTable.from_id)
    ).label("from_name")
    to_name = (
        select(UsersTable.name).where(UsersTable.id == MailTable.to_id)
    ).label("to_name")

    stmt = select(*READ_PARAMS, from_name, to_name).where(MailTable.to_id == user_id)
    if read is not None:
        stmt = stmt.where(MailTable.read == read)

    rows = await app.state.services.database.fetch_all(stmt)
    return cast(list[MailWithUsernames], rows)
async def mark_conversation_as_read(to_id: int, from_id: int) -> list[Mail]:
    """Mark any mail in a user's conversation with another user as read."""
    unread = await app.state.services.database.fetch_all(
        select(*READ_PARAMS).where(
            MailTable.to_id == to_id,
            MailTable.from_id == from_id,
            MailTable.read == False,  # noqa: E712 (sqlalchemy comparison)
        ),
    )
    if not unread:
        return []

    await app.state.services.database.execute(
        update(MailTable)
        .where(MailTable.to_id == to_id)
        .where(MailTable.from_id == from_id)
        .where(MailTable.read == False)  # noqa: E712
        .values(read=True),
    )
    # return the rows as they were before being marked read
    return cast(list[Mail], unread)

View File

@@ -0,0 +1,97 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Integer
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.mysql import TINYINT
import app.state.services
from app.repositories import Base
class MapRequestsTable(Base):
    """Player requests for a beatmap's ranked status to be changed."""

    __tablename__ = "map_requests"

    id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
    map_id = Column("map_id", Integer, nullable=False)
    player_id = Column("player_id", Integer, nullable=False)
    datetime = Column("datetime", DateTime, nullable=False)
    active = Column("active", TINYINT(1), nullable=False)  # 0 once handled
# Columns selected for every read of the map_requests table.
READ_PARAMS = (
    MapRequestsTable.id,
    MapRequestsTable.map_id,
    MapRequestsTable.player_id,
    MapRequestsTable.datetime,
    # BUG FIX: `active` was missing here, but the MapRequest model declares
    # it and callers (fetch_all / mark_batch_as_inactive) expect it in
    # fetched rows.
    MapRequestsTable.active,
)
class MapRequest(TypedDict):
    """A map-status request row as read back from the database."""

    id: int
    map_id: int
    player_id: int
    datetime: datetime
    active: bool
async def create(
    map_id: int,
    player_id: int,
    active: bool,
) -> MapRequest:
    """Insert a new map request and return the created record."""
    new_id = await app.state.services.database.execute(
        insert(MapRequestsTable).values(
            map_id=map_id,
            player_id=player_id,
            datetime=func.now(),
            active=active,
        ),
    )
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(MapRequestsTable.id == new_id),
    )
    assert row is not None
    return cast(MapRequest, row)
async def fetch_all(
    map_id: int | None = None,
    player_id: int | None = None,
    active: bool | None = None,
) -> list[MapRequest]:
    """Fetch map requests, optionally filtered by map, player and/or state."""
    stmt = select(*READ_PARAMS)
    if map_id is not None:
        stmt = stmt.where(MapRequestsTable.map_id == map_id)
    if player_id is not None:
        stmt = stmt.where(MapRequestsTable.player_id == player_id)
    if active is not None:
        stmt = stmt.where(MapRequestsTable.active == active)
    rows = await app.state.services.database.fetch_all(stmt)
    return cast(list[MapRequest], rows)
async def mark_batch_as_inactive(map_ids: list[Any]) -> list[MapRequest]:
    """Mark all requests for the given map ids as inactive, returning them."""
    await app.state.services.database.execute(
        update(MapRequestsTable)
        .where(MapRequestsTable.map_id.in_(map_ids))
        .values(active=False),
    )
    rows = await app.state.services.database.fetch_all(
        select(*READ_PARAMS).where(MapRequestsTable.map_id.in_(map_ids)),
    )
    return cast(list[MapRequest], rows)

370
app/repositories/maps.py Normal file
View File

@@ -0,0 +1,370 @@
from __future__ import annotations
from datetime import datetime
from enum import StrEnum
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Enum
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.mysql import FLOAT
from sqlalchemy.dialects.mysql import TINYINT
import app.state.services
from app._typing import UNSET
from app._typing import _UnsetSentinel
from app.repositories import Base
class MapServer(StrEnum):
    """Origin of a beatmap: official osu! servers or this private server."""

    OSU = "osu!"
    PRIVATE = "private"
class MapsTable(Base):
    """Beatmap metadata cache; keyed by (server, id) so private maps can
    reuse id ranges without colliding with official osu! maps."""

    __tablename__ = "maps"

    server = Column(
        Enum(MapServer, name="server"),
        nullable=False,
        server_default="osu!",
        primary_key=True,
    )
    id = Column(Integer, nullable=False, primary_key=True)
    set_id = Column(Integer, nullable=False)
    status = Column(Integer, nullable=False)  # ranked status
    md5 = Column(String(32), nullable=False)
    artist = Column(String(128, collation="utf8"), nullable=False)
    title = Column(String(128, collation="utf8"), nullable=False)
    version = Column(String(128, collation="utf8"), nullable=False)  # difficulty name
    creator = Column(String(19, collation="utf8"), nullable=False)
    filename = Column(String(256, collation="utf8"), nullable=False)
    last_update = Column(DateTime, nullable=False)
    total_length = Column(Integer, nullable=False)
    max_combo = Column(Integer, nullable=False)
    # when frozen, the map's status is not overwritten by osu!api updates
    frozen = Column(TINYINT(1), nullable=False, server_default="0")
    plays = Column(Integer, nullable=False, server_default="0")
    passes = Column(Integer, nullable=False, server_default="0")
    mode = Column(TINYINT(1), nullable=False, server_default="0")
    bpm = Column(FLOAT(12, 2), nullable=False, server_default="0.00")
    cs = Column(FLOAT(4, 2), nullable=False, server_default="0.00")
    ar = Column(FLOAT(4, 2), nullable=False, server_default="0.00")
    od = Column(FLOAT(4, 2), nullable=False, server_default="0.00")
    hp = Column(FLOAT(4, 2), nullable=False, server_default="0.00")
    diff = Column(FLOAT(6, 3), nullable=False, server_default="0.000")  # star rating

    __table_args__ = (
        Index("maps_set_id_index", "set_id"),
        Index("maps_status_index", "status"),
        Index("maps_filename_index", "filename"),
        Index("maps_plays_index", "plays"),
        Index("maps_mode_index", "mode"),
        Index("maps_frozen_index", "frozen"),
        Index("maps_md5_uindex", "md5", unique=True),
        Index("maps_id_uindex", "id", unique=True),
    )
# Columns selected for every read of the maps table.
READ_PARAMS = (
    MapsTable.id,
    MapsTable.server,
    MapsTable.set_id,
    MapsTable.status,
    MapsTable.md5,
    MapsTable.artist,
    MapsTable.title,
    MapsTable.version,
    MapsTable.creator,
    MapsTable.filename,
    MapsTable.last_update,
    MapsTable.total_length,
    MapsTable.max_combo,
    MapsTable.frozen,
    MapsTable.plays,
    MapsTable.passes,
    MapsTable.mode,
    MapsTable.bpm,
    MapsTable.cs,
    MapsTable.ar,
    MapsTable.od,
    MapsTable.hp,
    MapsTable.diff,
)
class Map(TypedDict):
    """A beatmap row as read back from the database."""

    id: int
    server: str
    set_id: int
    status: int  # ranked status
    md5: str
    artist: str
    title: str
    version: str  # difficulty name
    creator: str
    filename: str
    last_update: datetime
    total_length: int
    max_combo: int
    frozen: bool  # status protected from osu!api updates
    plays: int
    passes: int
    mode: int
    bpm: float
    cs: float
    ar: float
    od: float
    hp: float
    diff: float  # star rating
async def create(
    id: int,
    server: str,
    set_id: int,
    status: int,
    md5: str,
    artist: str,
    title: str,
    version: str,
    creator: str,
    filename: str,
    last_update: datetime,
    total_length: int,
    max_combo: int,
    frozen: bool,
    plays: int,
    passes: int,
    mode: int,
    bpm: float,
    cs: float,
    ar: float,
    od: float,
    hp: float,
    diff: float,
) -> Map:
    """Insert a new beatmap row and return the created record."""
    column_values = {
        "id": id,
        "server": server,
        "set_id": set_id,
        "status": status,
        "md5": md5,
        "artist": artist,
        "title": title,
        "version": version,
        "creator": creator,
        "filename": filename,
        "last_update": last_update,
        "total_length": total_length,
        "max_combo": max_combo,
        "frozen": frozen,
        "plays": plays,
        "passes": passes,
        "mode": mode,
        "bpm": bpm,
        "cs": cs,
        "ar": ar,
        "od": od,
        "hp": hp,
        "diff": diff,
    }
    rec_id = await app.state.services.database.execute(
        insert(MapsTable).values(column_values),
    )
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(MapsTable.id == rec_id),
    )
    assert row is not None
    return cast(Map, row)
async def fetch_one(
    id: int | None = None,
    md5: str | None = None,
    filename: str | None = None,
) -> Map | None:
    """Fetch a single beatmap by id, md5 and/or filename."""
    lookups = (
        (MapsTable.id, id),
        (MapsTable.md5, md5),
        (MapsTable.filename, filename),
    )
    provided = [(col, val) for col, val in lookups if val is not None]
    if not provided:
        raise ValueError("Must provide at least one parameter.")

    stmt = select(*READ_PARAMS)
    for col, val in provided:
        stmt = stmt.where(col == val)

    row = await app.state.services.database.fetch_one(stmt)
    return cast(Map | None, row)
async def fetch_count(
    server: str | None = None,
    set_id: int | None = None,
    status: int | None = None,
    artist: str | None = None,
    creator: str | None = None,
    filename: str | None = None,
    mode: int | None = None,
    frozen: bool | None = None,
) -> int:
    """Count beatmaps matching the given optional filters."""
    stmt = select(func.count().label("count")).select_from(MapsTable)
    optional_filters = (
        (MapsTable.server, server),
        (MapsTable.set_id, set_id),
        (MapsTable.status, status),
        (MapsTable.artist, artist),
        (MapsTable.creator, creator),
        (MapsTable.filename, filename),
        (MapsTable.mode, mode),
        (MapsTable.frozen, frozen),
    )
    for column, value in optional_filters:
        if value is not None:
            stmt = stmt.where(column == value)

    row = await app.state.services.database.fetch_one(stmt)
    assert row is not None
    return cast(int, row["count"])
async def fetch_many(
    server: str | None = None,
    set_id: int | None = None,
    status: int | None = None,
    artist: str | None = None,
    creator: str | None = None,
    filename: str | None = None,
    mode: int | None = None,
    frozen: bool | None = None,
    page: int | None = None,
    page_size: int | None = None,
) -> list[Map]:
    """Fetch beatmaps matching the optional filters, with pagination."""
    stmt = select(*READ_PARAMS)
    optional_filters = (
        (MapsTable.server, server),
        (MapsTable.set_id, set_id),
        (MapsTable.status, status),
        (MapsTable.artist, artist),
        (MapsTable.creator, creator),
        (MapsTable.filename, filename),
        (MapsTable.mode, mode),
        (MapsTable.frozen, frozen),
    )
    for column, value in optional_filters:
        if value is not None:
            stmt = stmt.where(column == value)

    if page is not None and page_size is not None:
        stmt = stmt.offset((page - 1) * page_size).limit(page_size)

    rows = await app.state.services.database.fetch_all(stmt)
    return cast(list[Map], rows)
async def partial_update(
    id: int,
    server: str | _UnsetSentinel = UNSET,
    set_id: int | _UnsetSentinel = UNSET,
    status: int | _UnsetSentinel = UNSET,
    md5: str | _UnsetSentinel = UNSET,
    artist: str | _UnsetSentinel = UNSET,
    title: str | _UnsetSentinel = UNSET,
    version: str | _UnsetSentinel = UNSET,
    creator: str | _UnsetSentinel = UNSET,
    filename: str | _UnsetSentinel = UNSET,
    last_update: datetime | _UnsetSentinel = UNSET,
    total_length: int | _UnsetSentinel = UNSET,
    max_combo: int | _UnsetSentinel = UNSET,
    frozen: bool | _UnsetSentinel = UNSET,
    plays: int | _UnsetSentinel = UNSET,
    passes: int | _UnsetSentinel = UNSET,
    mode: int | _UnsetSentinel = UNSET,
    bpm: float | _UnsetSentinel = UNSET,
    cs: float | _UnsetSentinel = UNSET,
    ar: float | _UnsetSentinel = UNSET,
    od: float | _UnsetSentinel = UNSET,
    hp: float | _UnsetSentinel = UNSET,
    diff: float | _UnsetSentinel = UNSET,
) -> Map | None:
    """Update only the provided fields of a beatmap entry."""
    # map column name -> candidate value; UNSET entries are filtered out
    candidates = {
        "server": server,
        "set_id": set_id,
        "status": status,
        "md5": md5,
        "artist": artist,
        "title": title,
        "version": version,
        "creator": creator,
        "filename": filename,
        "last_update": last_update,
        "total_length": total_length,
        "max_combo": max_combo,
        "frozen": frozen,
        "plays": plays,
        "passes": passes,
        "mode": mode,
        "bpm": bpm,
        "cs": cs,
        "ar": ar,
        "od": od,
        "hp": hp,
        "diff": diff,
    }
    new_values = {
        column: value
        for column, value in candidates.items()
        if not isinstance(value, _UnsetSentinel)
    }

    update_stmt = update(MapsTable).where(MapsTable.id == id)
    if new_values:
        update_stmt = update_stmt.values(new_values)
    await app.state.services.database.execute(update_stmt)

    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(MapsTable.id == id),
    )
    return cast(Map | None, row)
async def delete_one(id: int) -> Map | None:
    """Delete a beatmap by id, returning the record that was removed."""
    existing = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(MapsTable.id == id),
    )
    if existing is None:
        return None
    await app.state.services.database.execute(
        delete(MapsTable).where(MapsTable.id == id),
    )
    return cast(Map, existing)

View File

@@ -0,0 +1,85 @@
from __future__ import annotations
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import insert
from sqlalchemy import select
from sqlalchemy.dialects.mysql import TINYINT
import app.state.services
from app.repositories import Base
class RatingsTable(Base):
    """Per-user star ratings of beatmaps; one row per (user, map) pair."""

    __tablename__ = "ratings"

    userid = Column("userid", Integer, nullable=False, primary_key=True)
    map_md5 = Column("map_md5", String(32), nullable=False, primary_key=True)
    rating = Column("rating", TINYINT(2), nullable=False)
# Columns selected for every read of the ratings table.
READ_PARAMS = (
    RatingsTable.userid,
    RatingsTable.map_md5,
    RatingsTable.rating,
)
class Rating(TypedDict):
    """A rating row as read back from the database."""

    userid: int
    map_md5: str
    rating: int
async def create(userid: int, map_md5: str, rating: int) -> Rating:
    """Create a new rating."""
    await app.state.services.database.execute(
        insert(RatingsTable).values(
            userid=userid,
            map_md5=map_md5,
            rating=rating,
        ),
    )
    # Read the row back so the caller receives exactly what was persisted.
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS)
        .where(RatingsTable.userid == userid)
        .where(RatingsTable.map_md5 == map_md5),
    )
    assert row is not None
    return cast(Rating, row)
async def fetch_many(
    userid: int | None = None,
    map_md5: str | None = None,
    page: int | None = 1,
    page_size: int | None = 50,
) -> list[Rating]:
    """Fetch multiple ratings, optionally with filter params and pagination."""
    query = select(*READ_PARAMS)
    # Apply only the filters the caller actually supplied.
    for column, value in (
        (RatingsTable.userid, userid),
        (RatingsTable.map_md5, map_md5),
    ):
        if value is not None:
            query = query.where(column == value)
    if page is not None and page_size is not None:
        query = query.limit(page_size).offset((page - 1) * page_size)
    rows = await app.state.services.database.fetch_all(query)
    return cast(list[Rating], rows)
async def fetch_one(userid: int, map_md5: str) -> Rating | None:
    """Fetch a single rating for a given user and map."""
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS)
        .where(RatingsTable.userid == userid)
        .where(RatingsTable.map_md5 == map_md5),
    )
    return cast(Rating | None, row)

246
app/repositories/scores.py Normal file
View File

@@ -0,0 +1,246 @@
from __future__ import annotations
from datetime import datetime
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.mysql import FLOAT
from sqlalchemy.dialects.mysql import TINYINT
import app.state.services
from app._typing import UNSET
from app._typing import _UnsetSentinel
from app.repositories import Base
class ScoresTable(Base):
    """A single submitted score, with its judgement counts and metadata."""

    __tablename__ = "scores"

    id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
    map_md5 = Column("map_md5", String(32), nullable=False)
    score = Column("score", Integer, nullable=False)
    pp = Column("pp", FLOAT(precision=6, scale=3), nullable=False)
    acc = Column("acc", FLOAT(precision=6, scale=3), nullable=False)
    max_combo = Column("max_combo", Integer, nullable=False)
    mods = Column("mods", Integer, nullable=False)
    # Per-judgement hit counts.
    n300 = Column("n300", Integer, nullable=False)
    n100 = Column("n100", Integer, nullable=False)
    n50 = Column("n50", Integer, nullable=False)
    nmiss = Column("nmiss", Integer, nullable=False)
    ngeki = Column("ngeki", Integer, nullable=False)
    nkatu = Column("nkatu", Integer, nullable=False)
    grade = Column("grade", String(2), nullable=False, server_default="N")
    status = Column("status", Integer, nullable=False)
    mode = Column("mode", Integer, nullable=False)
    play_time = Column("play_time", DateTime, nullable=False)
    time_elapsed = Column("time_elapsed", Integer, nullable=False)
    client_flags = Column("client_flags", Integer, nullable=False)
    userid = Column("userid", Integer, nullable=False)
    perfect = Column("perfect", TINYINT(1), nullable=False)
    online_checksum = Column("online_checksum", String(32), nullable=False)

    # Secondary indexes for the common filter columns used below.
    __table_args__ = (
        Index("scores_map_md5_index", map_md5),
        Index("scores_score_index", score),
        Index("scores_pp_index", pp),
        Index("scores_mods_index", mods),
        Index("scores_status_index", status),
        Index("scores_mode_index", mode),
        Index("scores_play_time_index", play_time),
        Index("scores_userid_index", userid),
        Index("scores_online_checksum_index", online_checksum),
    )


# Columns selected by every read query in this repository.
READ_PARAMS = (
    ScoresTable.id,
    ScoresTable.map_md5,
    ScoresTable.score,
    ScoresTable.pp,
    ScoresTable.acc,
    ScoresTable.max_combo,
    ScoresTable.mods,
    ScoresTable.n300,
    ScoresTable.n100,
    ScoresTable.n50,
    ScoresTable.nmiss,
    ScoresTable.ngeki,
    ScoresTable.nkatu,
    ScoresTable.grade,
    ScoresTable.status,
    ScoresTable.mode,
    ScoresTable.play_time,
    ScoresTable.time_elapsed,
    ScoresTable.client_flags,
    ScoresTable.userid,
    ScoresTable.perfect,
    ScoresTable.online_checksum,
)


class Score(TypedDict):
    """Shape of a scores row as returned by this repository."""

    id: int
    map_md5: str
    score: int
    pp: float
    acc: float
    max_combo: int
    mods: int
    n300: int
    n100: int
    n50: int
    nmiss: int
    ngeki: int
    nkatu: int
    grade: str
    status: int
    mode: int
    play_time: datetime
    time_elapsed: int
    client_flags: int
    userid: int
    perfect: int
    online_checksum: str
async def create(
    map_md5: str,
    score: int,
    pp: float,
    acc: float,
    max_combo: int,
    mods: int,
    n300: int,
    n100: int,
    n50: int,
    nmiss: int,
    ngeki: int,
    nkatu: int,
    grade: str,
    status: int,
    mode: int,
    play_time: datetime,
    time_elapsed: int,
    client_flags: int,
    user_id: int,
    perfect: int,
    online_checksum: str,
) -> Score:
    """Create a new score entry and return the persisted row.

    NOTE: the `user_id` parameter is stored in the `userid` column.
    """
    insert_stmt = insert(ScoresTable).values(
        map_md5=map_md5,
        score=score,
        pp=pp,
        acc=acc,
        max_combo=max_combo,
        mods=mods,
        n300=n300,
        n100=n100,
        n50=n50,
        nmiss=nmiss,
        ngeki=ngeki,
        nkatu=nkatu,
        grade=grade,
        status=status,
        mode=mode,
        play_time=play_time,
        time_elapsed=time_elapsed,
        client_flags=client_flags,
        userid=user_id,
        perfect=perfect,
        online_checksum=online_checksum,
    )
    # The auto-generated primary key is used to read the full row back.
    rec_id = await app.state.services.database.execute(insert_stmt)
    select_stmt = select(*READ_PARAMS).where(ScoresTable.id == rec_id)
    _score = await app.state.services.database.fetch_one(select_stmt)
    assert _score is not None
    return cast(Score, _score)
async def fetch_one(id: int) -> Score | None:
    """Fetch a single score by its id, or None if it does not exist."""
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(ScoresTable.id == id),
    )
    return cast(Score | None, row)
async def fetch_count(
    map_md5: str | None = None,
    mods: int | None = None,
    status: int | None = None,
    mode: int | None = None,
    user_id: int | None = None,
) -> int:
    """Count scores matching the given (all optional) filters."""
    query = select(func.count().label("count")).select_from(ScoresTable)
    # Apply each filter only when the caller supplied a value.
    for column, value in (
        (ScoresTable.map_md5, map_md5),
        (ScoresTable.mods, mods),
        (ScoresTable.status, status),
        (ScoresTable.mode, mode),
        (ScoresTable.userid, user_id),
    ):
        if value is not None:
            query = query.where(column == value)
    rec = await app.state.services.database.fetch_one(query)
    assert rec is not None
    return cast(int, rec["count"])
async def fetch_many(
    map_md5: str | None = None,
    mods: int | None = None,
    status: int | None = None,
    mode: int | None = None,
    user_id: int | None = None,
    page: int | None = None,
    page_size: int | None = None,
) -> list[Score]:
    """Fetch scores matching the given filters, with optional pagination."""
    query = select(*READ_PARAMS)
    for column, value in (
        (ScoresTable.map_md5, map_md5),
        (ScoresTable.mods, mods),
        (ScoresTable.status, status),
        (ScoresTable.mode, mode),
        (ScoresTable.userid, user_id),
    ):
        if value is not None:
            query = query.where(column == value)
    # Pagination is applied only when both page and page_size are given.
    if page is not None and page_size is not None:
        query = query.limit(page_size).offset((page - 1) * page_size)
    rows = await app.state.services.database.fetch_all(query)
    return cast(list[Score], rows)
async def partial_update(
    id: int,
    pp: float | _UnsetSentinel = UNSET,
    status: int | _UnsetSentinel = UNSET,
) -> Score | None:
    """Update an existing score."""
    update_stmt = update(ScoresTable).where(ScoresTable.id == id)
    # Only fields the caller explicitly passed (i.e. not UNSET) are written.
    if not isinstance(pp, _UnsetSentinel):
        update_stmt = update_stmt.values(pp=pp)
    if not isinstance(status, _UnsetSentinel):
        update_stmt = update_stmt.values(status=status)
    await app.state.services.database.execute(update_stmt)
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(ScoresTable.id == id),
    )
    return cast(Score | None, row)
# TODO: delete

237
app/repositories/stats.py Normal file
View File

@@ -0,0 +1,237 @@
from __future__ import annotations
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.mysql import FLOAT
from sqlalchemy.dialects.mysql import TINYINT
import app.state.services
from app._typing import UNSET
from app._typing import _UnsetSentinel
from app.repositories import Base
class StatsTable(Base):
    """Aggregated per-player, per-mode statistics."""

    __tablename__ = "stats"

    # Composite primary key: one row per (player id, game mode).
    id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
    mode = Column("mode", TINYINT(1), primary_key=True)
    tscore = Column("tscore", Integer, nullable=False, server_default="0")
    rscore = Column("rscore", Integer, nullable=False, server_default="0")
    pp = Column("pp", Integer, nullable=False, server_default="0")
    plays = Column("plays", Integer, nullable=False, server_default="0")
    playtime = Column("playtime", Integer, nullable=False, server_default="0")
    acc = Column(
        "acc",
        FLOAT(precision=6, scale=3),
        nullable=False,
        server_default="0.000",
    )
    max_combo = Column("max_combo", Integer, nullable=False, server_default="0")
    total_hits = Column("total_hits", Integer, nullable=False, server_default="0")
    replay_views = Column("replay_views", Integer, nullable=False, server_default="0")
    # Per-grade score counts.
    xh_count = Column("xh_count", Integer, nullable=False, server_default="0")
    x_count = Column("x_count", Integer, nullable=False, server_default="0")
    sh_count = Column("sh_count", Integer, nullable=False, server_default="0")
    s_count = Column("s_count", Integer, nullable=False, server_default="0")
    a_count = Column("a_count", Integer, nullable=False, server_default="0")

    __table_args__ = (
        Index("stats_mode_index", mode),
        Index("stats_pp_index", pp),
        Index("stats_tscore_index", tscore),
        Index("stats_rscore_index", rscore),
    )


# Columns selected by every read query in this repository.
READ_PARAMS = (
    StatsTable.id,
    StatsTable.mode,
    StatsTable.tscore,
    StatsTable.rscore,
    StatsTable.pp,
    StatsTable.plays,
    StatsTable.playtime,
    StatsTable.acc,
    StatsTable.max_combo,
    StatsTable.total_hits,
    StatsTable.replay_views,
    StatsTable.xh_count,
    StatsTable.x_count,
    StatsTable.sh_count,
    StatsTable.s_count,
    StatsTable.a_count,
)


class Stat(TypedDict):
    """Shape of a stats row as returned by this repository."""

    id: int
    mode: int
    tscore: int
    rscore: int
    pp: int
    plays: int
    playtime: int
    acc: float
    max_combo: int
    total_hits: int
    replay_views: int
    xh_count: int
    x_count: int
    sh_count: int
    s_count: int
    a_count: int
async def create(player_id: int, mode: int) -> Stat:
    """Create a new player stats entry in the database."""
    rec_id = await app.state.services.database.execute(
        insert(StatsTable).values(id=player_id, mode=mode),
    )
    # Read the inserted row back for the caller.
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(StatsTable.id == rec_id),
    )
    assert row is not None
    return cast(Stat, row)
async def create_all_modes(player_id: int) -> list[Stat]:
    """Create new player stats entries for each game mode in the database."""
    # NOTE(review): mode 7 is absent from this list — presumably rx!mania
    # is unsupported; confirm against the rest of the codebase.
    insert_stmt = insert(StatsTable).values(
        [
            {"id": player_id, "mode": mode}
            for mode in (
                0, # vn!std
                1, # vn!taiko
                2, # vn!catch
                3, # vn!mania
                4, # rx!std
                5, # rx!taiko
                6, # rx!catch
                8, # ap!std
            )
        ],
    )
    await app.state.services.database.execute(insert_stmt)
    # Return every mode row just created for this player.
    select_stmt = select(*READ_PARAMS).where(StatsTable.id == player_id)
    stats = await app.state.services.database.fetch_all(select_stmt)
    return cast(list[Stat], stats)
async def fetch_one(player_id: int, mode: int) -> Stat | None:
    """Fetch a player stats entry from the database."""
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS)
        .where(StatsTable.id == player_id)
        .where(StatsTable.mode == mode),
    )
    return cast(Stat | None, row)
async def fetch_count(
    player_id: int | None = None,
    mode: int | None = None,
) -> int:
    """Count stats rows matching the given (all optional) filters."""
    query = select(func.count().label("count")).select_from(StatsTable)
    for column, value in (
        (StatsTable.id, player_id),
        (StatsTable.mode, mode),
    ):
        if value is not None:
            query = query.where(column == value)
    rec = await app.state.services.database.fetch_one(query)
    assert rec is not None
    return cast(int, rec["count"])
async def fetch_many(
    player_id: int | None = None,
    mode: int | None = None,
    page: int | None = None,
    page_size: int | None = None,
) -> list[Stat]:
    """Fetch stats rows matching the given filters, with optional pagination."""
    query = select(*READ_PARAMS)
    for column, value in (
        (StatsTable.id, player_id),
        (StatsTable.mode, mode),
    ):
        if value is not None:
            query = query.where(column == value)
    if page is not None and page_size is not None:
        query = query.limit(page_size).offset((page - 1) * page_size)
    rows = await app.state.services.database.fetch_all(query)
    return cast(list[Stat], rows)
async def partial_update(
    player_id: int,
    mode: int,
    tscore: int | _UnsetSentinel = UNSET,
    rscore: int | _UnsetSentinel = UNSET,
    pp: int | _UnsetSentinel = UNSET,
    plays: int | _UnsetSentinel = UNSET,
    playtime: int | _UnsetSentinel = UNSET,
    acc: float | _UnsetSentinel = UNSET,
    max_combo: int | _UnsetSentinel = UNSET,
    total_hits: int | _UnsetSentinel = UNSET,
    replay_views: int | _UnsetSentinel = UNSET,
    xh_count: int | _UnsetSentinel = UNSET,
    x_count: int | _UnsetSentinel = UNSET,
    sh_count: int | _UnsetSentinel = UNSET,
    s_count: int | _UnsetSentinel = UNSET,
    a_count: int | _UnsetSentinel = UNSET,
) -> Stat | None:
    """Update a player stats entry in the database."""
    update_stmt = (
        update(StatsTable)
        .where(StatsTable.id == player_id)
        .where(StatsTable.mode == mode)
    )
    # Write only the columns the caller explicitly provided (UNSET = skip).
    for column_name, value in (
        ("tscore", tscore),
        ("rscore", rscore),
        ("pp", pp),
        ("plays", plays),
        ("playtime", playtime),
        ("acc", acc),
        ("max_combo", max_combo),
        ("total_hits", total_hits),
        ("replay_views", replay_views),
        ("xh_count", xh_count),
        ("x_count", x_count),
        ("sh_count", sh_count),
        ("s_count", s_count),
        ("a_count", a_count),
    ):
        if not isinstance(value, _UnsetSentinel):
            update_stmt = update_stmt.values(**{column_name: value})
    await app.state.services.database.execute(update_stmt)
    # Read the (possibly) updated row back for the caller.
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS)
        .where(StatsTable.id == player_id)
        .where(StatsTable.mode == mode),
    )
    return cast(Stat | None, row)
# TODO: delete?

View File

@@ -0,0 +1,137 @@
from __future__ import annotations
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import delete
from sqlalchemy import insert
from sqlalchemy import select
import app.state.services
from app.repositories import Base
class TourneyPoolMapsTable(Base):
    """A beatmap's membership in a tournament map pool, with its pick slot."""

    __tablename__ = "tourney_pool_maps"

    # Composite primary key: a map appears at most once per pool.
    map_id = Column("map_id", Integer, nullable=False, primary_key=True)
    pool_id = Column("pool_id", Integer, nullable=False, primary_key=True)
    mods = Column("mods", Integer, nullable=False)
    slot = Column("slot", Integer, nullable=False)

    __table_args__ = (
        Index("tourney_pool_maps_mods_slot_index", mods, slot),
        Index("tourney_pool_maps_tourney_pools_id_fk", pool_id),
    )


# Columns selected by every read query in this repository.
READ_PARAMS = (
    TourneyPoolMapsTable.map_id,
    TourneyPoolMapsTable.pool_id,
    TourneyPoolMapsTable.mods,
    TourneyPoolMapsTable.slot,
)


class TourneyPoolMap(TypedDict):
    """Shape of a tourney_pool_maps row as returned by this repository."""

    map_id: int
    pool_id: int
    mods: int
    slot: int
async def create(map_id: int, pool_id: int, mods: int, slot: int) -> TourneyPoolMap:
    """Create a new map pool entry in the database."""
    await app.state.services.database.execute(
        insert(TourneyPoolMapsTable).values(
            map_id=map_id,
            pool_id=pool_id,
            mods=mods,
            slot=slot,
        ),
    )
    # Read the row back by its composite (map_id, pool_id) key.
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS)
        .where(TourneyPoolMapsTable.map_id == map_id)
        .where(TourneyPoolMapsTable.pool_id == pool_id),
    )
    assert row is not None
    return cast(TourneyPoolMap, row)
async def fetch_many(
    pool_id: int | None = None,
    mods: int | None = None,
    slot: int | None = None,
    page: int | None = 1,
    page_size: int | None = 50,
) -> list[TourneyPoolMap]:
    """Fetch a list of map pool entries from the database."""
    query = select(*READ_PARAMS)
    for column, value in (
        (TourneyPoolMapsTable.pool_id, pool_id),
        (TourneyPoolMapsTable.mods, mods),
        (TourneyPoolMapsTable.slot, slot),
    ):
        if value is not None:
            query = query.where(column == value)
    # Pagination applies only when both values are truthy (0 disables it).
    if page and page_size:
        query = query.limit(page_size).offset((page - 1) * page_size)
    rows = await app.state.services.database.fetch_all(query)
    return cast(list[TourneyPoolMap], rows)
async def fetch_by_pool_and_pick(
    pool_id: int,
    mods: int,
    slot: int,
) -> TourneyPoolMap | None:
    """Fetch a map pool entry by pool and pick from the database."""
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS)
        .where(TourneyPoolMapsTable.pool_id == pool_id)
        .where(TourneyPoolMapsTable.mods == mods)
        .where(TourneyPoolMapsTable.slot == slot),
    )
    return cast(TourneyPoolMap | None, row)
async def delete_map_from_pool(pool_id: int, map_id: int) -> TourneyPoolMap | None:
    """Delete a map pool entry from a given tourney pool from the database."""
    # Fetch first so the deleted row can be returned to the caller.
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS)
        .where(TourneyPoolMapsTable.pool_id == pool_id)
        .where(TourneyPoolMapsTable.map_id == map_id),
    )
    if row is None:
        return None
    await app.state.services.database.execute(
        delete(TourneyPoolMapsTable)
        .where(TourneyPoolMapsTable.pool_id == pool_id)
        .where(TourneyPoolMapsTable.map_id == map_id),
    )
    return cast(TourneyPoolMap, row)
async def delete_all_in_pool(pool_id: int) -> list[TourneyPoolMap]:
    """Delete all map pool entries from a given tourney pool from the database."""
    rows = await app.state.services.database.fetch_all(
        select(*READ_PARAMS).where(TourneyPoolMapsTable.pool_id == pool_id),
    )
    if not rows:
        return []
    await app.state.services.database.execute(
        delete(TourneyPoolMapsTable).where(TourneyPoolMapsTable.pool_id == pool_id),
    )
    return cast(list[TourneyPoolMap], rows)

View File

@@ -0,0 +1,104 @@
from __future__ import annotations
from datetime import datetime
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import DateTime
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
import app.state.services
from app.repositories import Base
class TourneyPoolsTable(Base):
    """A tournament map pool, owned by the user who created it."""

    __tablename__ = "tourney_pools"

    id = Column("id", Integer, nullable=False, primary_key=True, autoincrement=True)
    name = Column("name", String(16), nullable=False)
    created_at = Column("created_at", DateTime, nullable=False)
    created_by = Column("created_by", Integer, nullable=False)

    __table_args__ = (Index("tourney_pools_users_id_fk", created_by),)


class TourneyPool(TypedDict):
    """Shape of a tourney_pools row as returned by this repository."""

    id: int
    name: str
    created_at: datetime
    created_by: int


# Columns selected by every read query in this repository.
READ_PARAMS = (
    TourneyPoolsTable.id,
    TourneyPoolsTable.name,
    TourneyPoolsTable.created_at,
    TourneyPoolsTable.created_by,
)
async def create(name: str, created_by: int) -> TourneyPool:
    """Create a new tourney pool entry in the database."""
    # created_at is assigned by the database via NOW().
    rec_id = await app.state.services.database.execute(
        insert(TourneyPoolsTable).values(
            name=name,
            created_at=func.now(),
            created_by=created_by,
        ),
    )
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(TourneyPoolsTable.id == rec_id),
    )
    assert row is not None
    return cast(TourneyPool, row)
async def fetch_many(
    id: int | None = None,
    created_by: int | None = None,
    page: int | None = 1,
    page_size: int | None = 50,
) -> list[TourneyPool]:
    """Fetch many tourney pools from the database."""
    query = select(*READ_PARAMS)
    for column, value in (
        (TourneyPoolsTable.id, id),
        (TourneyPoolsTable.created_by, created_by),
    ):
        if value is not None:
            query = query.where(column == value)
    # Pagination applies only when both values are truthy (0 disables it).
    if page and page_size:
        query = query.limit(page_size).offset((page - 1) * page_size)
    rows = await app.state.services.database.fetch_all(query)
    return cast(list[TourneyPool], rows)
async def fetch_by_name(name: str) -> TourneyPool | None:
    """Fetch a tourney pool by name from the database."""
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(TourneyPoolsTable.name == name),
    )
    return cast(TourneyPool | None, row)
async def fetch_by_id(id: int) -> TourneyPool | None:
    """Fetch a tourney pool by id from the database."""
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(TourneyPoolsTable.id == id),
    )
    return cast(TourneyPool | None, row)
async def delete_by_id(id: int) -> TourneyPool | None:
    """Delete a tourney pool by id from the database."""
    # Fetch first so the deleted row can be returned to the caller.
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS).where(TourneyPoolsTable.id == id),
    )
    if row is None:
        return None
    await app.state.services.database.execute(
        delete(TourneyPoolsTable).where(TourneyPoolsTable.id == id),
    )
    return cast(TourneyPool, row)

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import insert
from sqlalchemy import select
import app.state.services
from app._typing import UNSET
from app._typing import _UnsetSentinel
from app.repositories import Base
class UserAchievementsTable(Base):
    """A user's unlock of a single achievement (one row per user+achievement)."""

    __tablename__ = "user_achievements"

    # Composite primary key: an achievement is unlocked at most once per user.
    userid = Column("userid", Integer, nullable=False, primary_key=True)
    achid = Column("achid", Integer, nullable=False, primary_key=True)

    __table_args__ = (
        Index("user_achievements_achid_index", achid),
        Index("user_achievements_userid_index", userid),
    )


# Columns selected by every read query in this repository.
READ_PARAMS = (
    UserAchievementsTable.userid,
    UserAchievementsTable.achid,
)


class UserAchievement(TypedDict):
    """Shape of a user_achievements row as returned by this repository."""

    userid: int
    achid: int
async def create(user_id: int, achievement_id: int) -> UserAchievement:
    """Creates a new user achievement entry."""
    await app.state.services.database.execute(
        insert(UserAchievementsTable).values(
            userid=user_id,
            achid=achievement_id,
        ),
    )
    # Read the row back by its composite (userid, achid) key.
    row = await app.state.services.database.fetch_one(
        select(*READ_PARAMS)
        .where(UserAchievementsTable.userid == user_id)
        .where(UserAchievementsTable.achid == achievement_id),
    )
    assert row is not None
    return cast(UserAchievement, row)
async def fetch_many(
    user_id: int | _UnsetSentinel = UNSET,
    achievement_id: int | _UnsetSentinel = UNSET,
    page: int | None = None,
    page_size: int | None = None,
) -> list[UserAchievement]:
    """Fetch a list of user achievements."""
    query = select(*READ_PARAMS)
    # NOTE: unlike sibling repositories, filters here use the UNSET sentinel
    # (not None) to mean "no filter".
    if not isinstance(user_id, _UnsetSentinel):
        query = query.where(UserAchievementsTable.userid == user_id)
    if not isinstance(achievement_id, _UnsetSentinel):
        query = query.where(UserAchievementsTable.achid == achievement_id)
    # Pagination applies only when both values are truthy (0 disables it).
    if page and page_size:
        query = query.limit(page_size).offset((page - 1) * page_size)
    rows = await app.state.services.database.fetch_all(query)
    return cast(list[UserAchievement], rows)
# TODO: delete?

270
app/repositories/users.py Normal file
View File

@@ -0,0 +1,270 @@
from __future__ import annotations
from typing import TypedDict
from typing import cast
from sqlalchemy import Column
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import func
from sqlalchemy import insert
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.mysql import TINYINT
import app.state.services
from app._typing import UNSET
from app._typing import _UnsetSentinel
from app.repositories import Base
from app.utils import make_safe_name
class UsersTable(Base):
    """Account records: credentials, privileges, profile and clan state."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, nullable=False, autoincrement=True)
    name = Column(String(32, collation="utf8"), nullable=False)
    # Normalized copy of `name` (via make_safe_name) used for lookups.
    safe_name = Column(String(32, collation="utf8"), nullable=False)
    email = Column(String(254), nullable=False)
    priv = Column(Integer, nullable=False, server_default="1")
    pw_bcrypt = Column(String(60), nullable=False)
    country = Column(String(2), nullable=False, server_default="xx")
    silence_end = Column(Integer, nullable=False, server_default="0")
    donor_end = Column(Integer, nullable=False, server_default="0")
    creation_time = Column(Integer, nullable=False, server_default="0")
    latest_activity = Column(Integer, nullable=False, server_default="0")
    clan_id = Column(Integer, nullable=False, server_default="0")
    clan_priv = Column(TINYINT, nullable=False, server_default="0")
    preferred_mode = Column(Integer, nullable=False, server_default="0")
    play_style = Column(Integer, nullable=False, server_default="0")
    custom_badge_name = Column(String(16, collation="utf8"))
    custom_badge_icon = Column(String(64))
    userpage_content = Column(String(2048, collation="utf8"))
    api_key = Column(String(36))

    __table_args__ = (
        Index("users_priv_index", priv),
        Index("users_clan_id_index", clan_id),
        Index("users_clan_priv_index", clan_priv),
        Index("users_country_index", country),
        Index("users_api_key_uindex", api_key, unique=True),
        Index("users_email_uindex", email, unique=True),
        Index("users_name_uindex", name, unique=True),
        Index("users_safe_name_uindex", safe_name, unique=True),
    )


# Default read column set. Deliberately omits `email`, `pw_bcrypt` and
# `api_key` (see fetch_one's fetch_all_fields flag for the full row).
READ_PARAMS = (
    UsersTable.id,
    UsersTable.name,
    UsersTable.safe_name,
    UsersTable.priv,
    UsersTable.country,
    UsersTable.silence_end,
    UsersTable.donor_end,
    UsersTable.creation_time,
    UsersTable.latest_activity,
    UsersTable.clan_id,
    UsersTable.clan_priv,
    UsersTable.preferred_mode,
    UsersTable.play_style,
    UsersTable.custom_badge_name,
    UsersTable.custom_badge_icon,
    UsersTable.userpage_content,
)


class User(TypedDict):
    """Shape of a users row.

    NOTE(review): `pw_bcrypt` and `api_key` are not in READ_PARAMS, so rows
    read via the default column set will be missing these keys despite the
    annotation; they are only present when all fields are fetched.
    """

    id: int
    name: str
    safe_name: str
    priv: int
    pw_bcrypt: str
    country: str
    silence_end: int
    donor_end: int
    creation_time: int
    latest_activity: int
    clan_id: int
    clan_priv: int
    preferred_mode: int
    play_style: int
    custom_badge_name: str | None
    custom_badge_icon: str | None
    userpage_content: str | None
    api_key: str | None
async def create(
    name: str,
    email: str,
    pw_bcrypt: bytes,
    country: str,
) -> User:
    """Create a new user in the database.

    NOTE(review): the returned row is read via READ_PARAMS, which does not
    select `email`, `pw_bcrypt` or `api_key` — those keys will be missing
    from the returned mapping despite the `User` annotation. Also,
    `pw_bcrypt` is typed `bytes` here while the column is String(60);
    confirm the driver's encoding behavior.
    """
    insert_stmt = insert(UsersTable).values(
        name=name,
        safe_name=make_safe_name(name),
        email=email,
        pw_bcrypt=pw_bcrypt,
        country=country,
        # Timestamps are assigned by the database via UNIX_TIMESTAMP().
        creation_time=func.unix_timestamp(),
        latest_activity=func.unix_timestamp(),
    )
    rec_id = await app.state.services.database.execute(insert_stmt)
    select_stmt = select(*READ_PARAMS).where(UsersTable.id == rec_id)
    user = await app.state.services.database.fetch_one(select_stmt)
    assert user is not None
    return cast(User, user)
async def fetch_one(
    id: int | None = None,
    name: str | None = None,
    email: str | None = None,
    fetch_all_fields: bool = False,  # TODO: probably remove this if possible
) -> User | None:
    """Fetch a single user from the database.

    Filters are AND-ed together; at least one of `id`, `name` or `email`
    must be given (ValueError otherwise). Name lookups compare `safe_name`
    against `make_safe_name(name)`. When `fetch_all_fields` is False, the
    row is read via READ_PARAMS, which omits `email`, `pw_bcrypt` and
    `api_key` — those keys will be missing from the returned mapping.
    """
    if id is None and name is None and email is None:
        raise ValueError("Must provide at least one parameter.")
    if fetch_all_fields:
        select_stmt = select(UsersTable)
    else:
        select_stmt = select(*READ_PARAMS)
    if id is not None:
        select_stmt = select_stmt.where(UsersTable.id == id)
    if name is not None:
        select_stmt = select_stmt.where(UsersTable.safe_name == make_safe_name(name))
    if email is not None:
        select_stmt = select_stmt.where(UsersTable.email == email)
    user = await app.state.services.database.fetch_one(select_stmt)
    return cast(User | None, user)
async def fetch_count(
    priv: int | None = None,
    country: str | None = None,
    clan_id: int | None = None,
    clan_priv: int | None = None,
    preferred_mode: int | None = None,
    play_style: int | None = None,
) -> int:
    """Fetch the number of users in the database."""
    query = select(func.count().label("count")).select_from(UsersTable)
    # Apply each filter only when the caller supplied a value.
    for column, value in (
        (UsersTable.priv, priv),
        (UsersTable.country, country),
        (UsersTable.clan_id, clan_id),
        (UsersTable.clan_priv, clan_priv),
        (UsersTable.preferred_mode, preferred_mode),
        (UsersTable.play_style, play_style),
    ):
        if value is not None:
            query = query.where(column == value)
    rec = await app.state.services.database.fetch_one(query)
    assert rec is not None
    return cast(int, rec["count"])
async def fetch_many(
    priv: int | None = None,
    country: str | None = None,
    clan_id: int | None = None,
    clan_priv: int | None = None,
    preferred_mode: int | None = None,
    play_style: int | None = None,
    page: int | None = None,
    page_size: int | None = None,
) -> list[User]:
    """Fetch multiple users from the database."""
    query = select(*READ_PARAMS)
    for column, value in (
        (UsersTable.priv, priv),
        (UsersTable.country, country),
        (UsersTable.clan_id, clan_id),
        (UsersTable.clan_priv, clan_priv),
        (UsersTable.preferred_mode, preferred_mode),
        (UsersTable.play_style, play_style),
    ):
        if value is not None:
            query = query.where(column == value)
    # Pagination is applied only when both page and page_size are given.
    if page is not None and page_size is not None:
        query = query.limit(page_size).offset((page - 1) * page_size)
    rows = await app.state.services.database.fetch_all(query)
    return cast(list[User], rows)
async def partial_update(
    id: int,
    name: str | _UnsetSentinel = UNSET,
    email: str | _UnsetSentinel = UNSET,
    priv: int | _UnsetSentinel = UNSET,
    country: str | _UnsetSentinel = UNSET,
    silence_end: int | _UnsetSentinel = UNSET,
    donor_end: int | _UnsetSentinel = UNSET,
    # BUG FIX: was annotated `_UnsetSentinel | _UnsetSentinel`, which made it
    # impossible to (type-correctly) pass an actual timestamp value.
    creation_time: int | _UnsetSentinel = UNSET,
    latest_activity: int | _UnsetSentinel = UNSET,
    clan_id: int | _UnsetSentinel = UNSET,
    clan_priv: int | _UnsetSentinel = UNSET,
    preferred_mode: int | _UnsetSentinel = UNSET,
    play_style: int | _UnsetSentinel = UNSET,
    custom_badge_name: str | None | _UnsetSentinel = UNSET,
    custom_badge_icon: str | None | _UnsetSentinel = UNSET,
    userpage_content: str | None | _UnsetSentinel = UNSET,
    api_key: str | None | _UnsetSentinel = UNSET,
) -> User | None:
    """Update a user in the database.

    Only fields explicitly passed (i.e. not left as UNSET) are written;
    passing `name` also refreshes the derived `safe_name` column.
    Returns the updated user, or None if no user with `id` exists.
    """
    update_stmt = update(UsersTable).where(UsersTable.id == id)
    if not isinstance(name, _UnsetSentinel):
        update_stmt = update_stmt.values(name=name, safe_name=make_safe_name(name))
    if not isinstance(email, _UnsetSentinel):
        update_stmt = update_stmt.values(email=email)
    if not isinstance(priv, _UnsetSentinel):
        update_stmt = update_stmt.values(priv=priv)
    if not isinstance(country, _UnsetSentinel):
        update_stmt = update_stmt.values(country=country)
    if not isinstance(silence_end, _UnsetSentinel):
        update_stmt = update_stmt.values(silence_end=silence_end)
    if not isinstance(donor_end, _UnsetSentinel):
        update_stmt = update_stmt.values(donor_end=donor_end)
    if not isinstance(creation_time, _UnsetSentinel):
        update_stmt = update_stmt.values(creation_time=creation_time)
    if not isinstance(latest_activity, _UnsetSentinel):
        update_stmt = update_stmt.values(latest_activity=latest_activity)
    if not isinstance(clan_id, _UnsetSentinel):
        update_stmt = update_stmt.values(clan_id=clan_id)
    if not isinstance(clan_priv, _UnsetSentinel):
        update_stmt = update_stmt.values(clan_priv=clan_priv)
    if not isinstance(preferred_mode, _UnsetSentinel):
        update_stmt = update_stmt.values(preferred_mode=preferred_mode)
    if not isinstance(play_style, _UnsetSentinel):
        update_stmt = update_stmt.values(play_style=play_style)
    if not isinstance(custom_badge_name, _UnsetSentinel):
        update_stmt = update_stmt.values(custom_badge_name=custom_badge_name)
    if not isinstance(custom_badge_icon, _UnsetSentinel):
        update_stmt = update_stmt.values(custom_badge_icon=custom_badge_icon)
    if not isinstance(userpage_content, _UnsetSentinel):
        update_stmt = update_stmt.values(userpage_content=userpage_content)
    if not isinstance(api_key, _UnsetSentinel):
        update_stmt = update_stmt.values(api_key=api_key)

    await app.state.services.database.execute(update_stmt)

    # re-read the row so the caller gets the post-update state
    select_stmt = select(*READ_PARAMS).where(UsersTable.id == id)
    user = await app.state.services.database.fetch_one(select_stmt)
    return cast(User | None, user)
# TODO: delete?

73
app/settings.py Normal file
View File

@@ -0,0 +1,73 @@
from __future__ import annotations

import os
import tomllib
from urllib.parse import quote

from dotenv import load_dotenv

from app.settings_utils import read_bool
from app.settings_utils import read_list

# populate os.environ from a local .env file (if present)
load_dotenv()

# network binding for the bancho server itself
APP_HOST = os.environ["APP_HOST"]
APP_PORT = int(os.environ["APP_PORT"])

# mysql connection settings; the password is url-quoted for use in the DSN
DB_HOST = os.environ["DB_HOST"]
DB_PORT = int(os.environ["DB_PORT"])
DB_USER = os.environ["DB_USER"]
DB_PASS = quote(os.environ["DB_PASS"])
DB_NAME = os.environ["DB_NAME"]
DB_DSN = f"mysql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{DB_NAME}"

# redis connection settings; the auth segment is omitted unless both
# a user and password are configured
REDIS_HOST = os.environ["REDIS_HOST"]
REDIS_PORT = int(os.environ["REDIS_PORT"])
REDIS_USER = os.environ["REDIS_USER"]
REDIS_PASS = quote(os.environ["REDIS_PASS"])
REDIS_DB = int(os.environ["REDIS_DB"])
REDIS_AUTH_STRING = f"{REDIS_USER}:{REDIS_PASS}@" if REDIS_USER and REDIS_PASS else ""
REDIS_DSN = f"redis://{REDIS_AUTH_STRING}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB}"

# optional osu! api key; an empty string is normalized to None
OSU_API_KEY = os.environ.get("OSU_API_KEY") or None

DOMAIN = os.environ["DOMAIN"]
# beatmap mirror endpoints used for search & download
MIRROR_SEARCH_ENDPOINT = os.environ["MIRROR_SEARCH_ENDPOINT"]
MIRROR_DOWNLOAD_ENDPOINT = os.environ["MIRROR_DOWNLOAD_ENDPOINT"]
COMMAND_PREFIX = os.environ["COMMAND_PREFIX"]
SEASONAL_BGS = read_list(os.environ["SEASONAL_BGS"])
MENU_ICON_URL = os.environ["MENU_ICON_URL"]
MENU_ONCLICK_URL = os.environ["MENU_ONCLICK_URL"]
DATADOG_API_KEY = os.environ["DATADOG_API_KEY"]
DATADOG_APP_KEY = os.environ["DATADOG_APP_KEY"]
DEBUG = read_bool(os.environ["DEBUG"])
REDIRECT_OSU_URLS = read_bool(os.environ["REDIRECT_OSU_URLS"])
# accuracy values for which pp is precomputed & cached
PP_CACHED_ACCURACIES = [int(acc) for acc in read_list(os.environ["PP_CACHED_ACCS"])]
DISALLOWED_NAMES = read_list(os.environ["DISALLOWED_NAMES"])
DISALLOWED_PASSWORDS = read_list(os.environ["DISALLOWED_PASSWORDS"])
DISALLOW_OLD_CLIENTS = read_bool(os.environ["DISALLOW_OLD_CLIENTS"])
DISALLOW_INGAME_REGISTRATION = read_bool(os.environ["DISALLOW_INGAME_REGISTRATION"])
DISCORD_AUDIT_LOG_WEBHOOK = os.environ["DISCORD_AUDIT_LOG_WEBHOOK"]
# whether unhandled "strange occurrences" are uploaded to the devs
AUTOMATICALLY_REPORT_PROBLEMS = read_bool(os.environ["AUTOMATICALLY_REPORT_PROBLEMS"])
LOG_WITH_COLORS = read_bool(os.environ["LOG_WITH_COLORS"])

# advanced dev settings

## WARNING only touch this once you've
## read through what it enables.
## you could put your server at risk.
DEVELOPER_MODE = read_bool(os.environ["DEVELOPER_MODE"])

# server version, sourced from pyproject.toml (single source of truth)
with open("pyproject.toml", "rb") as f:
    VERSION = tomllib.load(f)["tool"]["poetry"]["version"]

48
app/settings_utils.py Normal file
View File

@@ -0,0 +1,48 @@
from __future__ import annotations
import os
from datetime import date
from app.logging import Ansi
from app.logging import log
def read_bool(value: str) -> bool:
    """Interpret an environment-variable string as a boolean flag."""
    truthy_values = ("true", "1", "yes")
    return value.lower() in truthy_values
def read_list(value: str) -> list[str]:
    """Split a comma-separated environment-variable string into stripped items."""
    items: list[str] = []
    for raw_item in value.split(","):
        items.append(raw_item.strip())
    return items
def support_deprecated_vars(
    new_name: str,
    deprecated_name: str,
    *,
    until: date,
    allow_empty_string: bool = False,
) -> str:
    """Read an env var by its new name, falling back to a deprecated alias.

    Logs a warning when the deprecated alias is used, and raises ValueError
    once the `until` date has passed. Raises KeyError when neither variable
    is set (or, without `allow_empty_string`, when both are empty).
    """
    # prefer the new variable whenever it holds a non-empty value
    val1 = os.getenv(new_name)
    if val1:
        return val1
    val2 = os.getenv(deprecated_name)
    if val2:
        if until < date.today():
            raise ValueError(
                f'The "{deprecated_name}" config option has been deprecated as of {until.isoformat()} and is no longer supported. Use {new_name} instead.',
            )
        log(
            f'The "{deprecated_name}" config option has been deprecated and will be supported until {until.isoformat()}. Use {new_name} instead.',
            Ansi.LYELLOW,
        )
        return val2
    if allow_empty_string:
        # NOTE(review): an empty string set under the deprecated name is
        # returned here *without* the deprecation warning/error above --
        # confirm this asymmetry is intended.
        if val1 is not None:
            return val1
        if val2 is not None:
            return val2
    raise KeyError(f"{new_name} is not set in the environment")

24
app/state/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
from __future__ import annotations
import asyncio
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import Literal
from . import cache
from . import services
from . import sessions
if TYPE_CHECKING:
from asyncio import AbstractEventLoop
from app.packets import BasePacket
from app.packets import ClientPackets
# the server's event loop.
# NOTE(review): bare annotation -- presumably assigned during startup; confirm.
loop: AbstractEventLoop

# per-key locks serializing concurrent score submissions;
# a new Lock is created on first access for each key
score_submission_locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)

# registered packet handlers, keyed by which handler set applies
# ("all" players vs the reduced "restricted" set)
packets: dict[Literal["all", "restricted"], dict[ClientPackets, type[BasePacket]]] = {
    "all": {},
    "restricted": {},
}

# flag set when the server begins shutting down
shutting_down = False

14
app/state/cache.py Normal file
View File

@@ -0,0 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.objects.beatmap import Beatmap
from app.objects.beatmap import BeatmapSet
# cache of previously-verified bcrypt results, avoiding costly re-hashing
bcrypt: dict[bytes, bytes] = {}  # {bcrypt: md5, ...}
# beatmaps cached under both their md5 and their id
beatmap: dict[str | int, Beatmap] = {}  # {md5: map, id: map, ...}
# beatmap sets cached by set id
beatmapset: dict[int, BeatmapSet] = {}  # {bsid: map_set}
# md5s of maps marked unsubmitted
unsubmitted: set[str] = set()  # {md5, ...}
# md5s of maps whose cached data needs refreshing
needs_update: set[str] = set()  # {md5, ...}

492
app/state/services.py Normal file
View File

@@ -0,0 +1,492 @@
from __future__ import annotations
import ipaddress
import logging
import pickle
import re
import secrets
from collections.abc import AsyncGenerator
from collections.abc import Mapping
from collections.abc import MutableMapping
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TypedDict
import datadog as datadog_module
import datadog.threadstats.base as datadog_client
import httpx
import pymysql
from redis import asyncio as aioredis
import app.settings
import app.state
from app._typing import IPAddress
from app.adapters.database import Database
from app.logging import Ansi
from app.logging import log
# directory where "strange occurrence" reports are pickled locally
STRANGE_LOG_DIR = Path.cwd() / ".data/logs"

# matches version-marker comment lines in the migrations file, e.g. "# v1.2.3"
VERSION_RGX = re.compile(r"^# v(?P<ver>\d+\.\d+\.\d+)$")
SQL_UPDATES_FILE = Path.cwd() / "migrations/migrations.sql"


""" session objects """

# shared async http client for all outbound requests
http_client = httpx.AsyncClient()
database = Database(app.settings.DB_DSN)
redis: aioredis.Redis = aioredis.from_url(app.settings.REDIS_DSN)

# datadog stats client; only initialized when both keys are configured
datadog: datadog_client.ThreadStats | None = None
if str(app.settings.DATADOG_API_KEY) and str(app.settings.DATADOG_APP_KEY):
    datadog_module.initialize(
        api_key=str(app.settings.DATADOG_API_KEY),
        app_key=str(app.settings.DATADOG_APP_KEY),
    )
    datadog = datadog_client.ThreadStats()

# NOTE(review): bare annotation -- presumably assigned during startup; confirm.
ip_resolver: IPResolver


""" session usecases """
""" session usecases """
class Country(TypedDict):
    # two-letter country code, lowercase (e.g. "us")
    acronym: str
    # numeric id from the module-level `country_codes` mapping
    numeric: int


class Geolocation(TypedDict):
    latitude: float
    longitude: float
    country: Country
# fmt: off
# lowercase ISO 3166-1 alpha-2 country code -> numeric country id
# (presumably the ids used by the osu! client protocol -- TODO confirm;
# "xx" appears to be the unknown-country placeholder)
country_codes = {
    "oc":   1, "eu":   2, "ad":   3, "ae":   4, "af":   5, "ag":   6, "ai":   7, "al":   8,
    "am":   9, "an":  10, "ao":  11, "aq":  12, "ar":  13, "as":  14, "at":  15, "au":  16,
    "aw":  17, "az":  18, "ba":  19, "bb":  20, "bd":  21, "be":  22, "bf":  23, "bg":  24,
    "bh":  25, "bi":  26, "bj":  27, "bm":  28, "bn":  29, "bo":  30, "br":  31, "bs":  32,
    "bt":  33, "bv":  34, "bw":  35, "by":  36, "bz":  37, "ca":  38, "cc":  39, "cd":  40,
    "cf":  41, "cg":  42, "ch":  43, "ci":  44, "ck":  45, "cl":  46, "cm":  47, "cn":  48,
    "co":  49, "cr":  50, "cu":  51, "cv":  52, "cx":  53, "cy":  54, "cz":  55, "de":  56,
    "dj":  57, "dk":  58, "dm":  59, "do":  60, "dz":  61, "ec":  62, "ee":  63, "eg":  64,
    "eh":  65, "er":  66, "es":  67, "et":  68, "fi":  69, "fj":  70, "fk":  71, "fm":  72,
    "fo":  73, "fr":  74, "fx":  75, "ga":  76, "gb":  77, "gd":  78, "ge":  79, "gf":  80,
    "gh":  81, "gi":  82, "gl":  83, "gm":  84, "gn":  85, "gp":  86, "gq":  87, "gr":  88,
    "gs":  89, "gt":  90, "gu":  91, "gw":  92, "gy":  93, "hk":  94, "hm":  95, "hn":  96,
    "hr":  97, "ht":  98, "hu":  99, "id": 100, "ie": 101, "il": 102, "in": 103, "io": 104,
    "iq": 105, "ir": 106, "is": 107, "it": 108, "jm": 109, "jo": 110, "jp": 111, "ke": 112,
    "kg": 113, "kh": 114, "ki": 115, "km": 116, "kn": 117, "kp": 118, "kr": 119, "kw": 120,
    "ky": 121, "kz": 122, "la": 123, "lb": 124, "lc": 125, "li": 126, "lk": 127, "lr": 128,
    "ls": 129, "lt": 130, "lu": 131, "lv": 132, "ly": 133, "ma": 134, "mc": 135, "md": 136,
    "mg": 137, "mh": 138, "mk": 139, "ml": 140, "mm": 141, "mn": 142, "mo": 143, "mp": 144,
    "mq": 145, "mr": 146, "ms": 147, "mt": 148, "mu": 149, "mv": 150, "mw": 151, "mx": 152,
    "my": 153, "mz": 154, "na": 155, "nc": 156, "ne": 157, "nf": 158, "ng": 159, "ni": 160,
    "nl": 161, "no": 162, "np": 163, "nr": 164, "nu": 165, "nz": 166, "om": 167, "pa": 168,
    "pe": 169, "pf": 170, "pg": 171, "ph": 172, "pk": 173, "pl": 174, "pm": 175, "pn": 176,
    "pr": 177, "ps": 178, "pt": 179, "pw": 180, "py": 181, "qa": 182, "re": 183, "ro": 184,
    "ru": 185, "rw": 186, "sa": 187, "sb": 188, "sc": 189, "sd": 190, "se": 191, "sg": 192,
    "sh": 193, "si": 194, "sj": 195, "sk": 196, "sl": 197, "sm": 198, "sn": 199, "so": 200,
    "sr": 201, "st": 202, "sv": 203, "sy": 204, "sz": 205, "tc": 206, "td": 207, "tf": 208,
    "tg": 209, "th": 210, "tj": 211, "tk": 212, "tm": 213, "tn": 214, "to": 215, "tl": 216,
    "tr": 217, "tt": 218, "tv": 219, "tw": 220, "tz": 221, "ua": 222, "ug": 223, "um": 224,
    "us": 225, "uy": 226, "uz": 227, "va": 228, "vc": 229, "ve": 230, "vg": 231, "vi": 232,
    "vn": 233, "vu": 234, "wf": 235, "ws": 236, "ye": 237, "yt": 238, "rs": 239, "za": 240,
    "zm": 241, "me": 242, "zw": 243, "xx": 244, "a2": 245, "o1": 246, "ax": 247, "gg": 248,
    "im": 249, "je": 250, "bl": 251, "mf": 252,
}
# fmt: on
class IPResolver:
    """Resolves client IP addresses from request headers, caching parses."""

    def __init__(self) -> None:
        # raw header string -> parsed address object
        self.cache: MutableMapping[str, IPAddress] = {}

    def get_ip(self, headers: Mapping[str, str]) -> IPAddress:
        """Resolve the IP address from the headers."""
        ip_str = headers.get("CF-Connecting-IP")
        if ip_str is None:
            forwards = headers["X-Forwarded-For"].split(",")
            # NOTE(review): with exactly one forwarded hop this falls back to
            # X-Real-IP instead of using the single entry -- confirm intended.
            if len(forwards) != 1:
                ip_str = forwards[0]
            else:
                ip_str = headers["X-Real-IP"]

        cached = self.cache.get(ip_str)
        if cached is None:
            cached = ipaddress.ip_address(ip_str)
            self.cache[ip_str] = cached
        return cached
async def fetch_geoloc(
    ip: IPAddress,
    headers: Mapping[str, str] | None = None,
) -> Geolocation | None:
    """Attempt to fetch geolocation data by any means necessary."""
    # prefer geolocation already provided by reverse-proxy headers (free);
    # only fall back to an external ip lookup when that fails
    if headers is not None:
        from_headers = _fetch_geoloc_from_headers(headers)
        if from_headers is not None:
            return from_headers
    return await _fetch_geoloc_from_ip(ip)
def _fetch_geoloc_from_headers(headers: Mapping[str, str]) -> Geolocation | None:
    """Attempt to fetch geolocation data from http headers."""
    # try each known reverse-proxy header scheme in order of preference
    for resolve in (__fetch_geoloc_cloudflare, __fetch_geoloc_nginx):
        geoloc = resolve(headers)
        if geoloc is not None:
            return geoloc
    return None
def __fetch_geoloc_cloudflare(headers: Mapping[str, str]) -> Geolocation | None:
    """Attempt to fetch geolocation data from cloudflare headers."""
    required_headers = ("CF-IPCountry", "CF-IPLatitude", "CF-IPLongitude")
    if any(key not in headers for key in required_headers):
        return None

    acronym = headers["CF-IPCountry"].lower()
    # NOTE(review): an unrecognized country code raises KeyError below --
    # confirm cloudflare only sends codes present in `country_codes`.
    return {
        "latitude": float(headers["CF-IPLatitude"]),
        "longitude": float(headers["CF-IPLongitude"]),
        "country": {
            "acronym": acronym,
            "numeric": country_codes[acronym],
        },
    }
def __fetch_geoloc_nginx(headers: Mapping[str, str]) -> Geolocation | None:
    """Attempt to fetch geolocation data from nginx headers."""
    required_headers = ("X-Country-Code", "X-Latitude", "X-Longitude")
    if any(key not in headers for key in required_headers):
        return None

    acronym = headers["X-Country-Code"].lower()
    # NOTE(review): an unrecognized country code raises KeyError below --
    # confirm nginx's geoip module only emits codes in `country_codes`.
    return {
        "latitude": float(headers["X-Latitude"]),
        "longitude": float(headers["X-Longitude"]),
        "country": {
            "acronym": acronym,
            "numeric": country_codes[acronym],
        },
    }
async def _fetch_geoloc_from_ip(ip: IPAddress) -> Geolocation | None:
    """Fetch geolocation data based on ip (using ip-api)."""
    # for private addresses, let ip-api geolocate the request's public origin
    if ip.is_private:
        url = "http://ip-api.com/line/"
    else:
        url = f"http://ip-api.com/line/{ip}"

    response = await http_client.get(
        url,
        params={
            "fields": ",".join(("status", "message", "countryCode", "lat", "lon")),
        },
    )
    if response.status_code != 200:
        log("Failed to get geoloc data: request failed.", Ansi.LRED)
        return None

    # line-based response: first line is the status, the rest are fields
    status, *lines = response.read().decode().split("\n")
    if status != "success":
        err_msg = lines[0]
        if err_msg == "invalid query":
            err_msg += f" ({url})"
        log(f"Failed to get geoloc data: {err_msg} for ip {ip}.", Ansi.LRED)
        return None

    acronym = lines[0].lower()
    return {
        "latitude": float(lines[1]),
        "longitude": float(lines[2]),
        "country": {
            "acronym": acronym,
            "numeric": country_codes[acronym],
        },
    }
async def log_strange_occurrence(obj: object) -> None:
    """Report an unexpected object/state for developer triage.

    When AUTOMATICALLY_REPORT_PROBLEMS is enabled, uploads a pickle of `obj`
    to cmyui's log server; otherwise (or on upload failure) pickles it to a
    local file under STRANGE_LOG_DIR and prompts the user to forward it.
    """
    pickled_obj: bytes = pickle.dumps(obj)
    uploaded = False

    if app.settings.AUTOMATICALLY_REPORT_PROBLEMS:
        # automatically reporting problems to cmyui's server
        response = await http_client.post(
            url="https://log.cmyui.xyz/",
            headers={
                "Bancho-Version": app.settings.VERSION,
                "Bancho-Domain": app.settings.DOMAIN,
            },
            content=pickled_obj,
        )
        if response.status_code == 200 and response.read() == b"ok":
            uploaded = True
            log(
                "Logged strange occurrence to cmyui's server. "
                "Thank you for your participation! <3",
                Ansi.LBLUE,
            )
        else:
            log(
                f"Autoupload to cmyui's server failed (HTTP {response.status_code})",
                Ansi.LRED,
            )

    if not uploaded:
        # log to a file locally, and prompt the user
        while True:
            log_file = STRANGE_LOG_DIR / f"strange_{secrets.token_hex(4)}.db"
            if not log_file.exists():
                break

        log_file.touch(exist_ok=False)
        log_file.write_bytes(pickled_obj)

        # BUG FIX: added the missing space after "to" so the path
        # isn't fused to the preceding word in the log output.
        log(
            "Logged strange occurrence to " + "/".join(log_file.parts[-4:]),
            Ansi.LYELLOW,
        )
        log(
            "It would be greatly appreciated if you could forward this to the "
            "bancho.py development team. To do so, please email josh@akatsuki.gg",
            Ansi.LYELLOW,
        )
# dependency management
class Version:
    """A simple semantic version (major.minor.micro) with total ordering."""

    def __init__(self, major: int, minor: int, micro: int) -> None:
        self.major = major
        self.minor = minor
        self.micro = micro

    def __repr__(self) -> str:
        return f"{self.major}.{self.minor}.{self.micro}"

    def __hash__(self) -> int:
        # hash/eq both delegate to the tuple form for consistency
        return hash(self.as_tuple)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Version):
            return NotImplemented
        return self.as_tuple == other.as_tuple

    def __lt__(self, other: Version) -> bool:
        return self.as_tuple < other.as_tuple

    def __le__(self, other: Version) -> bool:
        return self.as_tuple <= other.as_tuple

    def __gt__(self, other: Version) -> bool:
        return self.as_tuple > other.as_tuple

    def __ge__(self, other: Version) -> bool:
        return self.as_tuple >= other.as_tuple

    @property
    def as_tuple(self) -> tuple[int, int, int]:
        """The version as a (major, minor, micro) tuple."""
        return (self.major, self.minor, self.micro)

    @classmethod
    def from_str(cls, s: str) -> Version | None:
        """Parse "X.Y.Z" into a Version; None when not three dotted parts."""
        parts = s.split(".")
        if len(parts) != 3:
            return None
        return cls(
            major=int(parts[0]),
            minor=int(parts[1]),
            micro=int(parts[2]),
        )
async def _get_latest_dependency_versions() -> AsyncGenerator[
    tuple[str, Version, Version],
    None,
]:
    """Yield (name, current installed version, latest available version)
    for each pinned dependency in requirements.txt."""
    with open("requirements.txt") as f:
        dependencies = f.read().splitlines(keepends=False)

    # TODO: use asyncio.gather() to do all requests at once? or chunk them
    for dependency in dependencies:
        dependency_name, _, dependency_ver = dependency.partition("==")
        current_ver = Version.from_str(dependency_ver)
        if not current_ver:
            # the module uses some more advanced (and often hard to parse)
            # versioning system, so we won't be able to report updates.
            continue

        # TODO: split up and do the requests asynchronously
        url = f"https://pypi.org/pypi/{dependency_name}/json"
        response = await http_client.get(url)
        json = response.json()
        if response.status_code == 200 and json:
            latest_ver = Version.from_str(json["info"]["version"])
            if not latest_ver:
                # they've started using a more advanced versioning system.
                continue
            # BUG FIX: previously yielded (name, latest, current) while the
            # consumer (check_for_dependency_updates) unpacks
            # (name, current, latest) -- update detection was inverted.
            yield (dependency_name, current_ver, latest_ver)
        else:
            # pypi lookup failed; report "no update" by repeating current.
            yield (dependency_name, current_ver, current_ver)
async def check_for_dependency_updates() -> None:
    """Notify the developer of any dependency updates available."""
    outdated_count = 0
    async for module, current_ver, latest_ver in _get_latest_dependency_versions():
        if latest_ver > current_ver:
            outdated_count += 1
            log(
                f"{module} has an update available "
                f"[{current_ver!r} -> {latest_ver!r}]",
                Ansi.LMAGENTA,
            )

    if outdated_count:
        log(
            "Python modules can be updated with "
            "`python3.11 -m pip install -U <modules>`.",
            Ansi.LMAGENTA,
        )
# sql migrations
async def _get_current_sql_structure_version() -> Version | None:
    """Get the last launched version of the server."""
    row = await app.state.services.database.fetch_one(
        "SELECT ver_major, ver_minor, ver_micro "
        "FROM startups ORDER BY datetime DESC LIMIT 1",
    )
    if not row:
        # no startups recorded yet (first-ever launch)
        return None
    return Version(row["ver_major"], row["ver_minor"], row["ver_micro"])
async def run_sql_migrations() -> None:
    """Update the sql structure, if it has changed.

    Compares the current software version against the last version recorded
    in the `startups` table, and replays any migration statements from
    SQL_UPDATES_FILE whose version markers fall in that range.
    """
    software_version = Version.from_str(app.settings.VERSION)
    if software_version is None:
        raise RuntimeError(f"Invalid bancho.py version '{app.settings.VERSION}'")

    last_run_migration_version = await _get_current_sql_structure_version()
    if not last_run_migration_version:
        # Migrations have never run before - this is the first time starting the server.
        # We'll insert the current version into the database, so future versions know to migrate.
        await app.state.services.database.execute(
            "INSERT INTO startups (ver_major, ver_minor, ver_micro, datetime) "
            "VALUES (:major, :minor, :micro, NOW())",
            {
                "major": software_version.major,
                "minor": software_version.minor,
                "micro": software_version.micro,
            },
        )
        return  # already up to date (server has never run before)

    if software_version == last_run_migration_version:
        return  # already up to date

    # version changed; there may be sql changes.
    content = SQL_UPDATES_FILE.read_text()

    queries: list[str] = []
    q_lines: list[str] = []  # accumulates lines of a multi-line statement
    update_ver = None  # version of the migration section currently being read
    for line in content.splitlines():
        if not line:
            continue

        if line.startswith("#"):
            # may be normal comment or new version
            r_match = VERSION_RGX.fullmatch(line)
            if r_match:
                update_ver = Version.from_str(r_match["ver"])

            continue
        elif not update_ver:
            # statements before the first version marker are ignored
            continue

        # we only need the updates between the
        # previous and new version of the server.
        if last_run_migration_version < update_ver <= software_version:
            if line.endswith(";"):
                # statement terminator: flush any accumulated lines
                if q_lines:
                    q_lines.append(line)
                    queries.append(" ".join(q_lines))
                    q_lines = []
                else:
                    queries.append(line)
            else:
                q_lines.append(line)

    if queries:
        log(
            f"Updating mysql structure (v{last_run_migration_version!r} -> v{software_version!r}).",
            Ansi.LMAGENTA,
        )

    # XXX: we can't use a transaction here with mysql as structural changes to
    # tables implicitly commit: https://dev.mysql.com/doc/refman/5.7/en/implicit-commit.html
    for query in queries:
        try:
            await app.state.services.database.execute(query)
        except pymysql.err.MySQLError as exc:
            log(f"Failed: {query}", Ansi.GRAY)
            log(repr(exc))
            log(
                "SQL failed to update - unless you've been "
                "modifying sql and know what caused this, "
                "please contact @cmyui on Discord.",
                Ansi.LRED,
            )
            raise KeyboardInterrupt from exc
    else:
        # all queries executed successfully (for/else: no exception raised);
        # record the new version so this migration range won't replay.
        await app.state.services.database.execute(
            "INSERT INTO startups (ver_major, ver_minor, ver_micro, datetime) "
            "VALUES (:major, :minor, :micro, NOW())",
            {
                "major": software_version.major,
                "minor": software_version.minor,
                "micro": software_version.micro,
            },
        )

54
app/state/sessions.py Normal file
View File

@@ -0,0 +1,54 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from typing import Any
from app.logging import Ansi
from app.logging import log
from app.objects.collections import Channels
from app.objects.collections import Matches
from app.objects.collections import Players
if TYPE_CHECKING:
from app.objects.player import Player
# collections of live session objects
players = Players()
channels = Channels()
matches = Matches()

# mapping of api key -> user id
api_keys: dict[str, int] = {}

# strong references to background tasks, so they aren't garbage
# collected while running; cancelled via cancel_housekeeping_tasks()
housekeeping_tasks: set[asyncio.Task[Any]] = set()

# the server's bot player.
# NOTE(review): bare annotation -- presumably assigned during startup; confirm.
bot: Player
# use cases
async def cancel_housekeeping_tasks() -> None:
    """Cancel all housekeeping tasks and surface any unhandled errors."""
    log(
        f"-> Cancelling {len(housekeeping_tasks)} housekeeping tasks.",
        Ansi.LMAGENTA,
    )

    # cancel housekeeping tasks
    for task in housekeeping_tasks:
        task.cancel()

    # wait for them all to finish, swallowing cancellation errors
    await asyncio.gather(*housekeeping_tasks, return_exceptions=True)

    # forward any non-cancellation exceptions to the loop's handler
    loop = asyncio.get_running_loop()
    for task in housekeeping_tasks:
        if task.cancelled():
            continue
        exception = task.exception()
        if exception:
            loop.call_exception_handler(
                {
                    "message": "unhandled exception during loop shutdown",
                    "exception": exception,
                    "task": task,
                },
            )

27
app/timer.py Normal file
View File

@@ -0,0 +1,27 @@
from __future__ import annotations
import time
from types import TracebackType
class Timer:
    """Context manager measuring wall-clock time between enter and exit."""

    def __init__(self) -> None:
        # both remain None until the context manager has been used
        self.start_time: float | None = None
        self.end_time: float | None = None

    def __enter__(self) -> Timer:
        self.start_time = time.time()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # recorded even when the body raised, mirroring a `finally`
        self.end_time = time.time()

    def elapsed(self) -> float:
        """Return the measured duration in seconds.

        Raises ValueError if the timer was never entered and exited.
        """
        if self.start_time is None or self.end_time is None:
            raise ValueError("Timer has not been started or stopped.")
        duration = self.end_time - self.start_time
        return duration

0
app/usecases/__init__.py Normal file
View File

View File

@@ -0,0 +1,30 @@
from __future__ import annotations
import app.repositories.achievements
from app.repositories.achievements import Achievement
async def create(
    file: str,
    name: str,
    desc: str,
    cond: str,
) -> Achievement:
    """Create a new achievement via the achievements repository."""
    return await app.repositories.achievements.create(file, name, desc, cond)
async def fetch_many(
    page: int | None = None,
    page_size: int | None = None,
) -> list[Achievement]:
    """Fetch a (possibly paginated) list of achievements."""
    return await app.repositories.achievements.fetch_many(page, page_size)

138
app/usecases/performance.py Normal file
View File

@@ -0,0 +1,138 @@
from __future__ import annotations
import math
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TypedDict
from akatsuki_pp_py import Beatmap
from akatsuki_pp_py import Calculator
from app.constants.mods import Mods
@dataclass
class ScoreParams:
    """Parameters describing a single score for pp calculation."""

    # game mode the score was set in
    mode: int
    mods: int | None = None
    combo: int | None = None

    # caller may pass either acc OR 300/100/50/geki/katu/miss
    # passing both will result in a value error being raised
    acc: float | None = None

    n300: int | None = None
    n100: int | None = None
    n50: int | None = None
    ngeki: int | None = None
    nkatu: int | None = None
    nmiss: int | None = None
class PerformanceRating(TypedDict):
    # total pp; set to 0.0 when the calculator produced nan/inf
    pp: float
    pp_acc: float | None
    pp_aim: float | None
    pp_speed: float | None
    pp_flashlight: float | None
    effective_miss_count: float | None
    pp_difficulty: float | None
class DifficultyRating(TypedDict):
    # overall star rating; the remaining fields are mode-specific
    # attributes and may be None depending on the game mode
    stars: float
    aim: float | None
    speed: float | None
    flashlight: float | None
    slider_factor: float | None
    speed_note_count: float | None
    stamina: float | None
    color: float | None
    rhythm: float | None
    peak: float | None
class PerformanceResult(TypedDict):
    # combined performance & difficulty attributes for one score
    performance: PerformanceRating
    difficulty: DifficultyRating
def calculate_performances(
    osu_file_path: str,
    scores: Iterable[ScoreParams],
) -> list[PerformanceResult]:
    """\
    Calculate performance for multiple scores on a single beatmap.

    Typically most useful for mass-recalculation situations.

    TODO: Some level of error handling & returning to caller should be
    implemented here to handle cases where e.g. the beatmap file is invalid
    or there an issue during calculation.
    """
    parsed_beatmap = Beatmap(path=osu_file_path)

    results: list[PerformanceResult] = []
    for score in scores:
        # accuracy and explicit judgement counts are mutually exclusive inputs
        judgement_counts_given = any(
            (score.n300, score.n100, score.n50, score.ngeki, score.nkatu),
        )
        if score.acc and judgement_counts_given:
            raise ValueError(
                "Must not specify accuracy AND 300/100/50/geki/katu. Only one or the other.",
            )

        # rosupp ignores NC and requires DT
        if score.mods is not None and score.mods & Mods.NIGHTCORE:
            score.mods |= Mods.DOUBLETIME

        calculator = Calculator(
            mode=score.mode,
            mods=score.mods or 0,
            combo=score.combo,
            acc=score.acc,
            n300=score.n300,
            n100=score.n100,
            n50=score.n50,
            n_geki=score.ngeki,
            n_katu=score.nkatu,
            n_misses=score.nmiss,
        )
        attributes = calculator.performance(parsed_beatmap)

        pp_value = attributes.pp
        if math.isnan(pp_value) or math.isinf(pp_value):
            # TODO: report to logserver
            pp_value = 0.0
        else:
            pp_value = round(pp_value, 3)

        performance: PerformanceRating = {
            "pp": pp_value,
            "pp_acc": attributes.pp_acc,
            "pp_aim": attributes.pp_aim,
            "pp_speed": attributes.pp_speed,
            "pp_flashlight": attributes.pp_flashlight,
            "effective_miss_count": attributes.effective_miss_count,
            "pp_difficulty": attributes.pp_difficulty,
        }
        difficulty: DifficultyRating = {
            "stars": attributes.difficulty.stars,
            "aim": attributes.difficulty.aim,
            "speed": attributes.difficulty.speed,
            "flashlight": attributes.difficulty.flashlight,
            "slider_factor": attributes.difficulty.slider_factor,
            "speed_note_count": attributes.difficulty.speed_note_count,
            "stamina": attributes.difficulty.stamina,
            "color": attributes.difficulty.color,
            "rhythm": attributes.difficulty.rhythm,
            "peak": attributes.difficulty.peak,
        }
        results.append({"performance": performance, "difficulty": difficulty})

    return results

View File

@@ -0,0 +1,27 @@
from __future__ import annotations
import app.repositories.user_achievements
from app._typing import UNSET
from app._typing import _UnsetSentinel
from app.repositories.user_achievements import UserAchievement
async def create(user_id: int, achievement_id: int) -> UserAchievement:
    """Record an achievement unlock for a user via the repository layer."""
    return await app.repositories.user_achievements.create(user_id, achievement_id)
async def fetch_many(
    user_id: int | _UnsetSentinel = UNSET,
    page: int | None = None,
    page_size: int | None = None,
) -> list[UserAchievement]:
    """Fetch a (possibly filtered, possibly paginated) list of user achievements."""
    return await app.repositories.user_achievements.fetch_many(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )

254
app/utils.py Normal file
View File

@@ -0,0 +1,254 @@
from __future__ import annotations
import ctypes
import inspect
import os
import socket
import sys
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import TypedDict
from typing import TypeVar
import httpx
import pymysql
import app.settings
from app.logging import Ansi
from app.logging import log
if TYPE_CHECKING:
from app.repositories.users import User
# generic type parameter (used by pymysql_encode's decorator below)
T = TypeVar("T")

# root of the server's persistent on-disk data
DATA_PATH = Path.cwd() / ".data"
# where downloaded achievement/medal images are stored
ACHIEVEMENTS_ASSETS_PATH = DATA_PATH / "assets/medals/client"
# avatar used as the server's default (downloaded on first startup)
DEFAULT_AVATAR_PATH = DATA_PATH / "avatars/default.jpg"
def make_safe_name(name: str) -> str:
    """Return a name safe for usage in sql."""
    # lowercase and replace spaces with underscores
    return name.replace(" ", "_").lower()
def determine_highest_ranking_clan_member(members: list[User]) -> User:
    """Return the clan member holding the highest clan privilege level."""
    by_privilege_desc = sorted(
        members,
        key=lambda member: member["clan_priv"],
        reverse=True,
    )
    return next(iter(by_privilege_desc))
def _download_achievement_images_osu(achievements_path: Path) -> bool:
    """Download all used achievement images (one by one, from osu!)."""
    filenames: list[str] = []
    for resolution in ("", "@2x"):
        for mode in ("osu", "taiko", "fruits", "mania"):
            # only osu!std has 9 & 10 star pass/fc medals.
            highest_star_rating = 10 if mode == "osu" else 8
            for star_rating in range(1, 1 + highest_star_rating):
                filenames.append(f"{mode}-skill-pass-{star_rating}{resolution}.png")
                filenames.append(f"{mode}-skill-fc-{star_rating}{resolution}.png")

        for combo in (500, 750, 1000, 2000):
            filenames.append(f"osu-combo-{combo}{resolution}.png")

        for mod in (
            "suddendeath",
            "hidden",
            "perfect",
            "hardrock",
            "doubletime",
            "flashlight",
            "easy",
            "nofail",
            "nightcore",
            "halftime",
            "spunout",
        ):
            filenames.append(f"all-intro-{mod}{resolution}.png")

    log("Downloading achievement images from osu!.", Ansi.LCYAN)

    for filename in filenames:
        response = httpx.get(f"https://assets.ppy.sh/medals/client/{filename}")
        if response.status_code != 200:
            # abort on the first failed download; the caller handles fallback
            return False

        log(f"Saving achievement: {filename}", Ansi.LCYAN)
        (achievements_path / filename).write_bytes(response.content)

    return True
def download_achievement_images(achievements_path: Path) -> None:
    """Download all used achievement images (using the best available source)."""
    # download individual files from the official osu! servers
    if _download_achievement_images_osu(achievements_path):
        log("Downloaded all achievement images.", Ansi.LGREEN)
        return

    # TODO: make the code safe in this state
    log("Failed to download achievement images.", Ansi.LRED)
    achievements_path.rmdir()

    # allow passthrough (don't hard crash).
    # the server will *mostly* work in this state.
def download_default_avatar(default_avatar_path: Path) -> None:
    """Download an avatar to use as the server's default."""
    response = httpx.get("https://i.cmyui.xyz/U24XBZw-4wjVME-JaEz3.png")
    if response.status_code != 200:
        log("Failed to fetch default avatar.", Ansi.LRED)
        return

    log("Downloaded default avatar.", Ansi.LGREEN)
    default_avatar_path.write_bytes(response.content)
def has_internet_connectivity(timeout: float = 1.0) -> bool:
    """Check for an active internet connection."""
    well_known_dns_hosts = (
        # Cloudflare
        "1.1.1.1",
        "1.0.0.1",
        # Google
        "8.8.8.8",
        "8.8.4.4",
    )
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client:
        client.settimeout(timeout)
        for host in well_known_dns_hosts:
            try:
                # a successful tcp connect to any public dns host (port 53)
                # is treated as proof of connectivity
                client.connect((host, 53))
            except OSError:
                continue
            return True

    # all connections failed
    return False
class FrameInfo(TypedDict):
    # name of the function the frame belongs to
    function: str
    # basename of the source file (directory stripped)
    filename: str
    lineno: int
    # character index within the line (0 when unavailable)
    charno: int
    # repr() of each local variable in the frame
    locals: dict[str, str]
def get_appropriate_stacktrace() -> list[FrameInfo]:
    """Return information of all frames related to cmyui_pkg and below.

    Walks the call stack (excluding this function) up to the nearest frame
    whose function is named "run", returning the frames below it in
    python-like (most recent last... reversed to most recent first) order.
    """
    stack = inspect.stack()[1:]
    for idx, frame in enumerate(stack):
        if frame.function == "run":
            break
    else:
        # BUG FIX: previously raised a bare, message-less `Exception`;
        # RuntimeError (a subclass) keeps `except Exception` callers working.
        raise RuntimeError(
            "could not find a 'run' function frame in the call stack",
        )

    return [
        {
            "function": frame.function,
            "filename": Path(frame.filename).name,
            "lineno": frame.lineno,
            "charno": frame.index or 0,
            "locals": {k: repr(v) for k, v in frame.frame.f_locals.items()},
        }
        # reverse for python-like stacktrace
        # ordering; puts the most recent
        # call closest to the command line
        for frame in reversed(stack[:idx])
    ]
def pymysql_encode(
    conv: Callable[[Any, dict[object, object] | None], str],
) -> Callable[[type[T]], type[T]]:
    """Decorator to allow for adding to pymysql's encoders.

    Registers `conv` as the encoder for the decorated class in pymysql's
    global converter table, so instances can be passed as query parameters.
    """

    def wrapper(cls: type[T]) -> type[T]:
        # side effect: mutates pymysql's module-level encoder registry
        pymysql.converters.encoders[cls] = conv
        return cls

    return wrapper
def escape_enum(
    val: Any,
    _: dict[object, object] | None = None,
) -> str:  # used for ^
    """Encode a value as its integer form, rendered as a string."""
    as_integer = int(val)
    return str(as_integer)
def ensure_persistent_volumes_are_available() -> None:
    """Create the on-disk .data directory layout and fetch default assets."""
    # create /.data and its subdirectories
    DATA_PATH.mkdir(exist_ok=True)
    for name in ("avatars", "logs", "osu", "osr", "ss"):
        (DATA_PATH / name).mkdir(exist_ok=True)

    # download achievement images from osu! if not already present
    if not ACHIEVEMENTS_ASSETS_PATH.exists():
        ACHIEVEMENTS_ASSETS_PATH.mkdir(parents=True)
        download_achievement_images(ACHIEVEMENTS_ASSETS_PATH)

    # download a default avatar image for new users if not already present
    if not DEFAULT_AVATAR_PATH.exists():
        download_default_avatar(DEFAULT_AVATAR_PATH)
def is_running_as_admin() -> bool:
    """Return whether the process is running as root (unix) or admin (windows).

    :raises Exception: on platforms exposing neither mechanism.
    """
    # unix-like platforms expose geteuid(); root is always euid 0
    if hasattr(os, "geteuid"):
        return os.geteuid() == 0  # type: ignore[attr-defined, no-any-return, unused-ignore]

    # windows exposes an elevation check through shell32
    try:
        return ctypes.windll.shell32.IsUserAnAdmin() == 1  # type: ignore[attr-defined, no-any-return, unused-ignore]
    except AttributeError:
        raise Exception(
            f"{sys.platform} is not currently supported on bancho.py, please create a github issue!",
        )
def display_startup_dialog() -> None:
    """Print any general information or warnings to the console."""
    if app.settings.DEVELOPER_MODE:
        log("running in advanced mode", Ansi.LYELLOW)
    if app.settings.DEBUG:
        log("running in debug mode", Ansi.LMAGENTA)

    # running on root/admin grants the software potentially dangerous and
    # unnecessary power over the operating system and is not advised.
    if is_running_as_admin():
        warning = "It is not recommended to run bancho.py as root/admin, especially in production."
        if app.settings.DEVELOPER_MODE:
            warning += " You are at increased risk as developer mode is enabled."
        log(warning, Ansi.LYELLOW)

    if not has_internet_connectivity():
        log("No internet connectivity detected", Ansi.LYELLOW)
def has_jpeg_headers_and_trailers(data_view: memoryview) -> bool:
    """Whether the buffer begins with a JFIF-style JPEG header."""
    # SOI + APP0 markers, followed by the "JFIF\0" identifier at offset 6
    has_soi_app0 = data_view[:4] == b"\xff\xd8\xff\xe0"
    has_jfif_id = data_view[6:11] == b"JFIF\x00"
    return has_soi_app0 and has_jfif_id
def has_png_headers_and_trailers(data_view: memoryview) -> bool:
    """Whether the buffer begins with a PNG signature and ends with IEND."""
    # 8-byte PNG file signature
    if data_view[:8] != b"\x89PNG\r\n\x1a\n":
        return False
    # the final 8 bytes of a PNG are the IEND chunk type plus its CRC
    return data_view[-8:] == b"\x49END\xae\x42\x60\x82"

99
docker-compose.test.yml Normal file
View File

@@ -0,0 +1,99 @@
services:
## shared services
mysql-test:
image: mysql:latest
# ports:
# - ${DB_PORT}:${DB_PORT}
environment:
MYSQL_USER: ${DB_USER}
MYSQL_PASSWORD: ${DB_PASS}
MYSQL_DATABASE: ${DB_NAME}
MYSQL_HOST: ${DB_HOST}
MYSQL_PORT: ${DB_PORT}
MYSQL_ROOT_PASSWORD: ${DB_PASS}
volumes:
- ./migrations/base.sql:/docker-entrypoint-initdb.d/init.sql:ro
- test-db-data:/var/lib/mysql
networks:
- test-network
healthcheck:
test: "/usr/bin/mysql --user=$$MYSQL_USER --password=$$MYSQL_PASSWORD --execute \"SHOW DATABASES;\""
interval: 2s
timeout: 20s
retries: 10
redis-test:
image: bitnami/redis:latest
# ports:
# - ${REDIS_PORT}:${REDIS_PORT}
user: root
volumes:
- test-redis-data:/bitnami/redis/data
networks:
- test-network
environment:
- ALLOW_EMPTY_PASSWORD=yes
- REDIS_PASSWORD=${REDIS_PASS}
## application services
bancho-test:
# we also have a public image: osuakatsuki/bancho.py:latest
image: bancho:latest
depends_on:
mysql-test:
condition: service_healthy
redis-test:
condition: service_started
tty: true
init: true
volumes:
- .:/srv/root
- test-data:/srv/root/.data
networks:
- test-network
environment:
- APP_HOST=${APP_HOST}
- APP_PORT=${APP_PORT}
- DB_USER=${DB_USER}
- DB_PASS=${DB_PASS}
- DB_NAME=${DB_NAME}
- DB_HOST=${DB_HOST}
- DB_PORT=${DB_PORT}
- REDIS_USER=${REDIS_USER}
- REDIS_PASS=${REDIS_PASS}
- REDIS_HOST=${REDIS_HOST}
- REDIS_PORT=${REDIS_PORT}
- REDIS_DB=${REDIS_DB}
- OSU_API_KEY=${OSU_API_KEY}
- MIRROR_SEARCH_ENDPOINT=${MIRROR_SEARCH_ENDPOINT}
- MIRROR_DOWNLOAD_ENDPOINT=${MIRROR_DOWNLOAD_ENDPOINT}
- DOMAIN=${DOMAIN}
- COMMAND_PREFIX=${COMMAND_PREFIX}
- SEASONAL_BGS=${SEASONAL_BGS}
- MENU_ICON_URL=${MENU_ICON_URL}
- MENU_ONCLICK_URL=${MENU_ONCLICK_URL}
- DATADOG_API_KEY=${DATADOG_API_KEY}
- DATADOG_APP_KEY=${DATADOG_APP_KEY}
- DEBUG=${DEBUG}
- REDIRECT_OSU_URLS=${REDIRECT_OSU_URLS}
- PP_CACHED_ACCS=${PP_CACHED_ACCS}
- DISALLOWED_NAMES=${DISALLOWED_NAMES}
- DISALLOWED_PASSWORDS=${DISALLOWED_PASSWORDS}
- DISALLOW_OLD_CLIENTS=${DISALLOW_OLD_CLIENTS}
- DISALLOW_INGAME_REGISTRATION=${DISALLOW_INGAME_REGISTRATION}
- DISCORD_AUDIT_LOG_WEBHOOK=${DISCORD_AUDIT_LOG_WEBHOOK}
- AUTOMATICALLY_REPORT_PROBLEMS=${AUTOMATICALLY_REPORT_PROBLEMS}
- LOG_WITH_COLORS=${LOG_WITH_COLORS}
- SSL_CERT_PATH=${SSL_CERT_PATH}
- SSL_KEY_PATH=${SSL_KEY_PATH}
- DEVELOPER_MODE=${DEVELOPER_MODE}
volumes:
test-data:
test-db-data:
test-redis-data:
networks:
test-network:

92
docker-compose.yml Normal file
View File

@@ -0,0 +1,92 @@
services:
## shared services
mysql:
image: mysql:latest
# ports:
# - ${DB_PORT}:${DB_PORT}
environment:
MYSQL_USER: ${DB_USER}
MYSQL_PASSWORD: ${DB_PASS}
MYSQL_DATABASE: ${DB_NAME}
MYSQL_HOST: ${DB_HOST}
MYSQL_PORT: ${DB_PORT}
MYSQL_RANDOM_ROOT_PASSWORD: "true"
volumes:
- ./migrations/base.sql:/docker-entrypoint-initdb.d/init.sql:ro
- db-data:/var/lib/mysql
healthcheck:
test: "/usr/bin/mysql --user=$$MYSQL_USER --password=$$MYSQL_PASSWORD --execute \"SHOW DATABASES;\""
interval: 2s
timeout: 20s
retries: 10
redis:
image: bitnami/redis:latest
# ports:
# - ${REDIS_PORT}:${REDIS_PORT}
user: root
volumes:
- redis-data:/bitnami/redis/data
environment:
- ALLOW_EMPTY_PASSWORD=yes
- REDIS_PASSWORD=${REDIS_PASS}
## application services
bancho:
# we also have a public image: osuakatsuki/bancho.py:latest
image: bancho:latest
ports:
- ${APP_PORT}:${APP_PORT}
depends_on:
mysql:
condition: service_healthy
redis:
condition: service_started
tty: true
init: true
volumes:
- .:/srv/root
- data:/srv/root/.data
environment:
- APP_HOST=${APP_HOST}
- APP_PORT=${APP_PORT}
- DB_USER=${DB_USER}
- DB_PASS=${DB_PASS}
- DB_NAME=${DB_NAME}
- DB_HOST=${DB_HOST}
- DB_PORT=${DB_PORT}
- REDIS_USER=${REDIS_USER}
- REDIS_PASS=${REDIS_PASS}
- REDIS_HOST=${REDIS_HOST}
- REDIS_PORT=${REDIS_PORT}
- REDIS_DB=${REDIS_DB}
- OSU_API_KEY=${OSU_API_KEY}
- MIRROR_SEARCH_ENDPOINT=${MIRROR_SEARCH_ENDPOINT}
- MIRROR_DOWNLOAD_ENDPOINT=${MIRROR_DOWNLOAD_ENDPOINT}
- DOMAIN=${DOMAIN}
- COMMAND_PREFIX=${COMMAND_PREFIX}
- SEASONAL_BGS=${SEASONAL_BGS}
- MENU_ICON_URL=${MENU_ICON_URL}
- MENU_ONCLICK_URL=${MENU_ONCLICK_URL}
- DATADOG_API_KEY=${DATADOG_API_KEY}
- DATADOG_APP_KEY=${DATADOG_APP_KEY}
- DEBUG=${DEBUG}
- REDIRECT_OSU_URLS=${REDIRECT_OSU_URLS}
- PP_CACHED_ACCS=${PP_CACHED_ACCS}
- DISALLOWED_NAMES=${DISALLOWED_NAMES}
- DISALLOWED_PASSWORDS=${DISALLOWED_PASSWORDS}
- DISALLOW_OLD_CLIENTS=${DISALLOW_OLD_CLIENTS}
- DISALLOW_INGAME_REGISTRATION=${DISALLOW_INGAME_REGISTRATION}
- DISCORD_AUDIT_LOG_WEBHOOK=${DISCORD_AUDIT_LOG_WEBHOOK}
- AUTOMATICALLY_REPORT_PROBLEMS=${AUTOMATICALLY_REPORT_PROBLEMS}
- LOG_WITH_COLORS=${LOG_WITH_COLORS}
- SSL_CERT_PATH=${SSL_CERT_PATH}
- SSL_KEY_PATH=${SSL_KEY_PATH}
- DEVELOPER_MODE=${DEVELOPER_MODE}
volumes:
data:
db-data:
redis-data:

30
ext/Caddyfile Normal file
View File

@@ -0,0 +1,30 @@
# Comment this out if you need to explicitly
# use self-signed certs.
# NOTE: Not necessary if using a '.local' domain
#
# {
# local_certs
# }
c.{$DOMAIN}, ce.{$DOMAIN}, c4.{$DOMAIN}, osu.{$DOMAIN}, b.{$DOMAIN}, api.{$DOMAIN} {
encode gzip
reverse_proxy * 127.0.0.1:{$APP_PORT} {
header_up X-Real-IP {remote_host}
}
request_body {
max_size 20MB
}
}
assets.{$DOMAIN} {
encode gzip
root * {$DATA_DIRECTORY}/assets
file_server
}
a.{$DOMAIN} {
encode gzip
root * {$DATA_DIRECTORY}/avatars
try_files {path} {file.base}.png {file.base}.jpg {file.base}.gif {file.base}.jpeg {file.base}.jfif default.jpg =404
file_server
}

54
ext/nginx.conf.example Normal file
View File

@@ -0,0 +1,54 @@
# c[e4]?.ppy.sh is used for bancho
# osu.ppy.sh is used for /web, /api, etc.
# a.ppy.sh is used for osu! avatars
upstream bancho {
server 127.0.0.1:${APP_PORT};
}
server {
listen 443 ssl;
server_name c.${DOMAIN} ce.${DOMAIN} c4.${DOMAIN} osu.${DOMAIN} b.${DOMAIN} api.${DOMAIN};
client_max_body_size 20M;
ssl_certificate ${SSL_CERT_PATH};
ssl_certificate_key ${SSL_KEY_PATH};
ssl_ciphers "EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH:@SECLEVEL=1";
location / {
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $http_host;
add_header Access-Control-Allow-Origin *;
proxy_redirect off;
proxy_pass http://bancho;
}
}
server {
listen 443 ssl;
server_name assets.${DOMAIN};
ssl_certificate ${SSL_CERT_PATH};
ssl_certificate_key ${SSL_KEY_PATH};
ssl_ciphers "EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH:@SECLEVEL=1";
location / {
default_type image/png;
root ${DATA_DIRECTORY}/assets;
}
}
server {
listen 443 ssl;
server_name a.${DOMAIN};
ssl_certificate ${SSL_CERT_PATH};
ssl_certificate_key ${SSL_KEY_PATH};
ssl_ciphers "EECDH+AESGCM:EDH+AESGCM:AES256+EECDH:AES256+EDH:@SECLEVEL=1";
location / {
root ${DATA_DIRECTORY}/avatars;
try_files $uri $uri.png $uri.jpg $uri.gif $uri.jpeg $uri.jfif /default.jpg =404;
}
}

36
logging.yaml.example Normal file
View File

@@ -0,0 +1,36 @@
version: 1
disable_existing_loggers: true
loggers:
httpx:
level: WARNING
handlers: [console]
propagate: no
httpcore:
level: WARNING
handlers: [console]
propagate: no
multipart.multipart:
level: ERROR
handlers: [console]
propagate: no
handlers:
console:
class: logging.StreamHandler
level: INFO
formatter: plaintext
stream: ext://sys.stdout
# file:
# class: logging.FileHandler
# level: INFO
# formatter: json
# filename: logs.log
formatters:
plaintext:
format: '[%(asctime)s] %(levelname)s %(message)s'
datefmt: '%Y-%m-%d %H:%M:%S'
# json:
# class: pythonjsonlogger.jsonlogger.JsonFormatter
# format: '%(asctime)s %(name)s %(levelname)s %(message)s'
root:
level: INFO
handlers: [console] # , file]

31
main.py Normal file
View File

@@ -0,0 +1,31 @@
#!/usr/bin/env python3.11
from __future__ import annotations
import logging
import uvicorn
import app.logging
import app.settings
import app.utils
app.logging.configure_logging()
def main() -> int:
    """Entry point: show startup info, then serve the ASGI app via uvicorn."""
    app.utils.display_startup_dialog()

    uvicorn.run(
        "app.api.init_api:asgi_app",
        host=app.settings.APP_HOST,
        port=app.settings.APP_PORT,
        reload=app.settings.DEBUG,
        log_level=logging.WARNING,
        server_header=False,
        date_header=False,
        headers=[("bancho-version", app.settings.VERSION)],
    )
    return 0
if __name__ == "__main__":
exit(main())

1953
poetry.lock generated Normal file

File diff suppressed because it is too large Load Diff

92
pyproject.toml Normal file
View File

@@ -0,0 +1,92 @@
[tool.mypy]
plugins = ["pydantic.mypy"]
strict = true
disallow_untyped_calls = false
[[tool.mypy.overrides]]
module = "tests.*"
disable_error_code = ["var-annotated", "has-type"]
disallow_untyped_defs = false
[[tool.mypy.overrides]]
module = [
"aiomysql.*",
"mitmproxy.*",
"py3rijndael.*",
"timeago.*",
"pytimeparse.*",
"cpuinfo.*",
]
ignore_missing_imports = true
[tool.pydantic-mypy]
init_forbid_extra = true
init_typed = true
warn_required_dynamic_aliases = true
[tool.pytest.ini_options]
asyncio_mode = "auto"
[tool.isort]
add_imports = ["from __future__ import annotations"]
force_single_line = true
profile = "black"
[tool.poetry]
name = "bancho-py"
version = "5.2.2"
description = "An osu! server implementation optimized for maintainability in modern python"
authors = ["Akatsuki Team"]
license = "MIT"
readme = "README.md"
packages = [{ include = "bancho" }]
[tool.poetry.dependencies]
python = "^3.11"
async-timeout = "4.0.3"
bcrypt = "4.1.2"
datadog = "0.48.0"
fastapi = "0.109.2"
orjson = "3.9.13"
psutil = "5.9.8"
python-dotenv = "1.0.1"
python-multipart = "0.0.9"
requests = "2.31.0"
timeago = "1.0.16"
uvicorn = "0.27.1"
uvloop = { markers = "sys_platform != 'win32'", version = "0.19.0" }
winloop = { platform = "win32", version = "0.1.1" }
py3rijndael = "0.3.3"
pytimeparse = "1.1.8"
pydantic = "2.6.1"
redis = { extras = ["hiredis"], version = "5.0.1" }
sqlalchemy = ">=1.4.42,<1.5"
akatsuki-pp-py = "1.0.5"
cryptography = "42.0.2"
tenacity = "8.2.3"
httpx = "0.26.0"
py-cpuinfo = "9.0.0"
pytest = "8.0.0"
pytest-asyncio = "0.23.5"
asgi-lifespan = "2.1.0"
respx = "0.20.2"
tzdata = "2024.1"
coverage = "^7.4.1"
databases = { version = "^0.8.0", extras = ["mysql"] }
python-json-logger = "^2.0.7"
[tool.poetry.group.dev.dependencies]
pre-commit = "3.6.1"
black = "24.1.1"
isort = "5.13.2"
autoflake = "2.2.1"
types-psutil = "5.9.5.20240205"
types-pymysql = "1.1.0.1"
types-requests = "2.31.0.20240125"
mypy = "1.8.0"
types-pyyaml = "^6.0.12.12"
sqlalchemy2-stubs = "^0.0.2a38"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"