Add files via upload

This commit is contained in:
purr
2025-04-04 21:30:31 +09:00
committed by GitHub
parent 5763658177
commit 966e7691a3
90 changed files with 20938 additions and 0 deletions

11
app/objects/__init__.py Normal file
View File

@@ -0,0 +1,11 @@
# type: ignore
# isort: dont-add-imports
from . import achievement
from . import beatmap
from . import channel
from . import collections
from . import match
from . import models
from . import player
from . import score

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from app.objects.score import Score
class Achievement:
"""A class to represent a single osu! achievement."""
def __init__(
self,
id: int,
file: str,
name: str,
desc: str,
cond: Callable[[Score, int], bool], # (score, mode) -> unlocked
) -> None:
self.id = id
self.file = file
self.name = name
self.desc = desc
self.cond = cond
def __repr__(self) -> str:
return f"{self.file}+{self.name}+{self.desc}"

996
app/objects/beatmap.py Normal file
View File

@@ -0,0 +1,996 @@
from __future__ import annotations
import functools
import hashlib
from collections import defaultdict
from collections.abc import Mapping
from datetime import datetime
from datetime import timedelta
from enum import IntEnum
from enum import unique
from pathlib import Path
from typing import Any
from typing import TypedDict
import httpx
from tenacity import retry
from tenacity.stop import stop_after_attempt
import app.settings
import app.state
import app.utils
from app.constants.gamemodes import GameMode
from app.logging import Ansi
from app.logging import log
from app.repositories import maps as maps_repo
from app.utils import escape_enum
from app.utils import pymysql_encode
# from dataclasses import dataclass
BEATMAPS_PATH = Path.cwd() / ".data/osu"
DEFAULT_LAST_UPDATE = datetime(1970, 1, 1)
IGNORED_BEATMAP_CHARS = dict.fromkeys(map(ord, r':\/*<>?"|'), None)
class BeatmapApiResponse(TypedDict):
data: list[dict[str, Any]] | None
status_code: int
@retry(reraise=True, stop=stop_after_attempt(3))
async def api_get_beatmaps(**params: Any) -> BeatmapApiResponse:
"""\
Fetch data from the osu!api with a beatmap's md5.
Optionally use osu.direct's API if the user has not provided an osu! api key.
"""
if app.settings.DEBUG:
log(f"Doing api (getbeatmaps) request {params}", Ansi.LMAGENTA)
if app.settings.OSU_API_KEY:
# https://github.com/ppy/osu-api/wiki#apiget_beatmaps
url = "https://old.ppy.sh/api/get_beatmaps"
params["k"] = str(app.settings.OSU_API_KEY)
else:
# https://osu.direct/doc
url = "https://osu.direct/api/get_beatmaps"
response = await app.state.services.http_client.get(url, params=params)
response_data = response.json()
if response.status_code == 200 and response_data: # (data may be [])
return {"data": response_data, "status_code": response.status_code}
return {"data": None, "status_code": response.status_code}
@retry(reraise=True, stop=stop_after_attempt(3))
async def api_get_osu_file(beatmap_id: int) -> bytes:
url = f"https://old.ppy.sh/osu/{beatmap_id}"
response = await app.state.services.http_client.get(url)
response.raise_for_status()
return response.read()
def disk_has_expected_osu_file(
beatmap_id: int,
expected_md5: str | None = None,
) -> bool:
osu_file_path = BEATMAPS_PATH / f"{beatmap_id}.osu"
file_exists = osu_file_path.exists()
if file_exists and expected_md5 is not None:
osu_file_md5 = hashlib.md5(osu_file_path.read_bytes()).hexdigest()
return osu_file_md5 == expected_md5
return file_exists
def write_osu_file_to_disk(beatmap_id: int, data: bytes) -> None:
osu_file_path = BEATMAPS_PATH / f"{beatmap_id}.osu"
osu_file_path.write_bytes(data)
async def ensure_osu_file_is_available(
beatmap_id: int,
expected_md5: str | None = None,
) -> bool:
"""\
Download the .osu file for a beatmap if it's not already present.
If `expected_md5` is provided, the file will be downloaded if it
does not match the expected md5 hash -- this is typically used for
ensuring a file is the latest expected version.
Returns whether the file is available for use.
"""
if disk_has_expected_osu_file(beatmap_id, expected_md5):
return True
try:
latest_osu_file = await api_get_osu_file(beatmap_id)
except httpx.HTTPStatusError:
return False
except Exception:
log(f"Failed to fetch osu file for {beatmap_id}", Ansi.LRED)
return False
write_osu_file_to_disk(beatmap_id, latest_osu_file)
return True
# for some ungodly reason, different values are used to
# represent different ranked statuses all throughout osu!
# This drives me and probably everyone else pretty insane,
# but we have nothing to do but deal with it B).
@unique
@pymysql_encode(escape_enum)
class RankedStatus(IntEnum):
"""Server side osu! beatmap ranked statuses.
Same as used in osu!'s /web/getscores.php.
"""
NotSubmitted = -1
Pending = 0
UpdateAvailable = 1
Ranked = 2
Approved = 3
Qualified = 4
Loved = 5
def __str__(self) -> str:
return {
self.NotSubmitted: "Unsubmitted",
self.Pending: "Unranked",
self.UpdateAvailable: "Outdated",
self.Ranked: "Ranked",
self.Approved: "Approved",
self.Qualified: "Qualified",
self.Loved: "Loved",
}[self]
@functools.cached_property
def osu_api(self) -> int:
"""Convert the value to osu!api status."""
# XXX: only the ones that exist are mapped.
return {
self.Pending: 0,
self.Ranked: 1,
self.Approved: 2,
self.Qualified: 3,
self.Loved: 4,
}[self]
@classmethod
@functools.cache
def from_osuapi(cls, osuapi_status: int) -> RankedStatus:
"""Convert from osu!api status."""
mapping: Mapping[int, RankedStatus] = defaultdict(
lambda: cls.UpdateAvailable,
{
-2: cls.Pending, # graveyard
-1: cls.Pending, # wip
0: cls.Pending,
1: cls.Ranked,
2: cls.Approved,
3: cls.Qualified,
4: cls.Loved,
},
)
return mapping[osuapi_status]
@classmethod
@functools.cache
def from_osudirect(cls, osudirect_status: int) -> RankedStatus:
"""Convert from osu!direct status."""
mapping: Mapping[int, RankedStatus] = defaultdict(
lambda: cls.UpdateAvailable,
{
0: cls.Ranked,
2: cls.Pending,
3: cls.Qualified,
# 4: all ranked statuses lol
5: cls.Pending, # graveyard
7: cls.Ranked, # played before
8: cls.Loved,
},
)
return mapping[osudirect_status]
@classmethod
@functools.cache
def from_str(cls, status_str: str) -> RankedStatus:
"""Convert from string value.""" # could perhaps have `'unranked': cls.Pending`?
mapping: Mapping[str, RankedStatus] = defaultdict(
lambda: cls.UpdateAvailable,
{
"pending": cls.Pending,
"ranked": cls.Ranked,
"approved": cls.Approved,
"qualified": cls.Qualified,
"loved": cls.Loved,
},
)
return mapping[status_str]
# @dataclass
# class BeatmapInfoRequest:
# filenames: Sequence[str]
# ids: Sequence[int]
# @dataclass
# class BeatmapInfo:
# id: int # i16
# map_id: int # i32
# set_id: int # i32
# thread_id: int # i32
# status: int # u8
# osu_rank: int # u8
# fruits_rank: int # u8
# taiko_rank: int # u8
# mania_rank: int # u8
# map_md5: str
class Beatmap:
"""A class representing an osu! beatmap.
This class provides a high level api which should always be the
preferred method of fetching beatmaps due to its housekeeping.
It will perform caching & invalidation, handle map updates while
minimizing osu!api requests, and always use the most efficient
method available to fetch the beatmap's information, while
maintaining a low overhead.
The only methods you should need are:
await Beatmap.from_md5(md5: str, set_id: int = -1) -> Beatmap | None
await Beatmap.from_bid(bid: int) -> Beatmap | None
Properties:
Beatmap.full -> str # Artist - Title [Version]
Beatmap.url -> str # https://osu.cmyui.xyz/b/321
Beatmap.embed -> str # [{url} {full}]
Beatmap.has_leaderboard -> bool
Beatmap.awards_ranked_pp -> bool
Beatmap.as_dict -> dict[str, object]
Lower level API:
Beatmap._from_md5_cache(md5: str, check_updates: bool = True) -> Beatmap | None
Beatmap._from_bid_cache(bid: int, check_updates: bool = True) -> Beatmap | None
Beatmap._from_md5_sql(md5: str) -> Beatmap | None
Beatmap._from_bid_sql(bid: int) -> Beatmap | None
Beatmap._parse_from_osuapi_resp(osuapi_resp: dict[str, object]) -> None
Note that the BeatmapSet class also provides a similar API.
Possibly confusing attributes
-----------
frozen: `bool`
Whether the beatmap's status is to be kept when a newer
version is found in the osu!api.
# XXX: This is set when a map's status is manually changed.
"""
def __init__(
self,
map_set: BeatmapSet,
md5: str = "",
id: int = 0,
set_id: int = 0,
artist: str = "",
title: str = "",
version: str = "",
creator: str = "",
last_update: datetime = DEFAULT_LAST_UPDATE,
total_length: int = 0,
max_combo: int = 0,
status: RankedStatus = RankedStatus.Pending,
frozen: bool = False,
plays: int = 0,
passes: int = 0,
mode: GameMode = GameMode.VANILLA_OSU,
bpm: float = 0.0,
cs: float = 0.0,
od: float = 0.0,
ar: float = 0.0,
hp: float = 0.0,
diff: float = 0.0,
filename: str = "",
) -> None:
self.set = map_set
self.md5 = md5
self.id = id
self.set_id = set_id
self.artist = artist
self.title = title
self.version = version
self.creator = creator
self.last_update = last_update
self.total_length = total_length
self.max_combo = max_combo
self.status = status
self.frozen = frozen
self.plays = plays
self.passes = passes
self.mode = mode
self.bpm = bpm
self.cs = cs
self.od = od
self.ar = ar
self.hp = hp
self.diff = diff
self.filename = filename
def __repr__(self) -> str:
return self.full_name
@property
def full_name(self) -> str:
"""The full osu! formatted name `self`."""
return f"{self.artist} - {self.title} [{self.version}]"
@property
def url(self) -> str:
"""The osu! beatmap url for `self`."""
return f"https://osu.{app.settings.DOMAIN}/b/{self.id}"
@property
def embed(self) -> str:
"""An osu! chat embed to `self`'s osu! beatmap page."""
return f"[{self.url} {self.full_name}]"
@property
def has_leaderboard(self) -> bool:
"""Return whether the map has a ranked leaderboard."""
return self.status in (
RankedStatus.Ranked,
RankedStatus.Approved,
RankedStatus.Loved,
)
@property
def awards_ranked_pp(self) -> bool:
"""Return whether the map's status awards ranked pp for scores."""
return self.status in (RankedStatus.Ranked, RankedStatus.Approved)
@property # perhaps worth caching some of?
def as_dict(self) -> dict[str, object]:
return {
"md5": self.md5,
"id": self.id,
"set_id": self.set_id,
"artist": self.artist,
"title": self.title,
"version": self.version,
"creator": self.creator,
"last_update": self.last_update,
"total_length": self.total_length,
"max_combo": self.max_combo,
"status": self.status,
"plays": self.plays,
"passes": self.passes,
"mode": self.mode,
"bpm": self.bpm,
"cs": self.cs,
"od": self.od,
"ar": self.ar,
"hp": self.hp,
"diff": self.diff,
}
""" High level API """
# There are three levels of storage used for beatmaps,
# the cache (ram), the db (disk), and the osu!api (web).
# Going down this list gets exponentially slower, so
# we always prioritze what's fastest when possible.
# These methods will keep beatmaps reasonably up to
# date and use the fastest storage available, while
# populating the higher levels of the cache with new maps.
@classmethod
async def from_md5(cls, md5: str, set_id: int = -1) -> Beatmap | None:
"""Fetch a map from the cache, database, or osuapi by md5."""
bmap = await cls._from_md5_cache(md5)
if not bmap:
# map not found in cache
# to be efficient, we want to cache the whole set
# at once rather than caching the individual map
if set_id <= 0:
# set id not provided - fetch it from the map md5
rec = await maps_repo.fetch_one(md5=md5)
if rec is not None:
# set found in db
set_id = rec["set_id"]
else:
# set not found in db, try api
api_data = await api_get_beatmaps(h=md5)
if api_data["data"] is None:
return None
api_response = api_data["data"]
set_id = int(api_response[0]["beatmapset_id"])
# fetch (and cache) beatmap set
beatmap_set = await BeatmapSet.from_bsid(set_id)
if beatmap_set is not None:
# the beatmap set has been cached - fetch beatmap from cache
bmap = await cls._from_md5_cache(md5)
# XXX:HACK in this case, BeatmapSet.from_bsid will have
# ensured the map is up to date, so we can just return it
return bmap
if bmap is not None:
if bmap.set._cache_expired():
await bmap.set._update_if_available()
return bmap
@classmethod
async def from_bid(cls, bid: int) -> Beatmap | None:
"""Fetch a map from the cache, database, or osuapi by id."""
bmap = await cls._from_bid_cache(bid)
if not bmap:
# map not found in cache
# to be efficient, we want to cache the whole set
# at once rather than caching the individual map
rec = await maps_repo.fetch_one(id=bid)
if rec is not None:
# set found in db
set_id = rec["set_id"]
else:
# set not found in db, try getting via api
api_data = await api_get_beatmaps(b=bid)
if api_data["data"] is None:
return None
api_response = api_data["data"]
set_id = int(api_response[0]["beatmapset_id"])
# fetch (and cache) beatmap set
beatmap_set = await BeatmapSet.from_bsid(set_id)
if beatmap_set is not None:
# the beatmap set has been cached - fetch beatmap from cache
bmap = await cls._from_bid_cache(bid)
# XXX:HACK in this case, BeatmapSet.from_bsid will have
# ensured the map is up to date, so we can just return it
return bmap
if bmap is not None:
if bmap.set._cache_expired():
await bmap.set._update_if_available()
return bmap
""" Lower level API """
# These functions are meant for internal use under
# all normal circumstances and should only be used
# if you're really modifying bancho.py by adding new
# features, or perhaps optimizing parts of the code.
def _parse_from_osuapi_resp(self, osuapi_resp: dict[str, Any]) -> None:
"""Change internal data with the data in osu!api format."""
# NOTE: `self` is not guaranteed to have any attributes
# initialized when this is called.
self.md5 = osuapi_resp["file_md5"]
# self.id = int(osuapi_resp['beatmap_id'])
self.set_id = int(osuapi_resp["beatmapset_id"])
self.artist, self.title, self.version, self.creator = (
osuapi_resp["artist"],
osuapi_resp["title"],
osuapi_resp["version"],
osuapi_resp["creator"],
)
self.filename = (
("{artist} - {title} ({creator}) [{version}].osu")
.format(**osuapi_resp)
.translate(IGNORED_BEATMAP_CHARS)
)
# quite a bit faster than using dt.strptime.
_last_update = osuapi_resp["last_update"]
self.last_update = datetime(
year=int(_last_update[0:4]),
month=int(_last_update[5:7]),
day=int(_last_update[8:10]),
hour=int(_last_update[11:13]),
minute=int(_last_update[14:16]),
second=int(_last_update[17:19]),
)
self.total_length = int(osuapi_resp["total_length"])
if osuapi_resp["max_combo"] is not None:
self.max_combo = int(osuapi_resp["max_combo"])
else:
self.max_combo = 0
# if a map is 'frozen', we keep its status
# even after an update from the osu!api.
if not getattr(self, "frozen", False):
osuapi_status = int(osuapi_resp["approved"])
self.status = RankedStatus.from_osuapi(osuapi_status)
self.mode = GameMode(int(osuapi_resp["mode"]))
if osuapi_resp["bpm"] is not None:
self.bpm = float(osuapi_resp["bpm"])
else:
self.bpm = 0.0
self.cs = float(osuapi_resp["diff_size"])
self.od = float(osuapi_resp["diff_overall"])
self.ar = float(osuapi_resp["diff_approach"])
self.hp = float(osuapi_resp["diff_drain"])
self.diff = float(osuapi_resp["difficultyrating"])
@staticmethod
async def _from_md5_cache(md5: str) -> Beatmap | None:
"""Fetch a map from the cache by md5."""
return app.state.cache.beatmap.get(md5, None)
@staticmethod
async def _from_bid_cache(bid: int) -> Beatmap | None:
"""Fetch a map from the cache by id."""
return app.state.cache.beatmap.get(bid, None)
class BeatmapSet:
"""A class to represent an osu! beatmap set.
Like the Beatmap class, this class provides a high level api
which should always be the preferred method of fetching beatmaps
due to its housekeeping. It will perform caching & invalidation,
handle map updates while minimizing osu!api requests, and always
use the most efficient method available to fetch the beatmap's
information, while maintaining a low overhead.
The only methods you should need are:
await BeatmapSet.from_bsid(bsid: int) -> BeatmapSet | None
BeatmapSet.all_officially_ranked_or_approved() -> bool
BeatmapSet.all_officially_loved() -> bool
Properties:
BeatmapSet.url -> str # https://osu.cmyui.xyz/s/123
Lower level API:
await BeatmapSet._from_bsid_cache(bsid: int) -> BeatmapSet | None
await BeatmapSet._from_bsid_sql(bsid: int) -> BeatmapSet | None
await BeatmapSet._from_bsid_osuapi(bsid: int) -> BeatmapSet | None
BeatmapSet._cache_expired() -> bool
await BeatmapSet._update_if_available() -> None
await BeatmapSet._save_to_sql() -> None
"""
def __init__(
self,
id: int,
last_osuapi_check: datetime,
maps: list[Beatmap] | None = None,
) -> None:
self.id = id
self.maps = maps or []
self.last_osuapi_check = last_osuapi_check
def __repr__(self) -> str:
map_names = []
for bmap in self.maps:
name = f"{bmap.artist} - {bmap.title}"
if name not in map_names:
map_names.append(name)
return ", ".join(map_names)
@property
def url(self) -> str:
"""The online url for this beatmap set."""
return f"https://osu.{app.settings.DOMAIN}/s/{self.id}"
def any_beatmaps_have_official_leaderboards(self) -> bool:
"""Whether all the maps in the set have leaderboards on official servers."""
leaderboard_having_statuses = (
RankedStatus.Loved,
RankedStatus.Ranked,
RankedStatus.Approved,
)
return any(bmap.status in leaderboard_having_statuses for bmap in self.maps)
def _cache_expired(self) -> bool:
"""Whether the cached version of the set is
expired and needs an update from the osu!api."""
current_datetime = datetime.now()
if not self.maps:
return True
# the delta between cache invalidations will increase depending
# on how long it's been since the map was last updated on osu!
last_map_update = max(bmap.last_update for bmap in self.maps)
update_delta = current_datetime - last_map_update
# with a minimum of 2 hours, add 5 hours per year since its update.
# the formula for this is subject to adjustment in the future.
check_delta = timedelta(hours=2 + ((5 / 365) * update_delta.days))
# it's much less likely that a beatmapset who has beatmaps with
# leaderboards on official servers will be updated.
if self.any_beatmaps_have_official_leaderboards():
check_delta *= 4
# we'll cache for an absolute maximum of 1 day.
check_delta = min(check_delta, timedelta(days=1))
return current_datetime > (self.last_osuapi_check + check_delta)
async def _update_if_available(self) -> None:
"""Fetch the newest data from the api, check for differences
and propogate any update into our cache & database."""
try:
api_data = await api_get_beatmaps(s=self.id)
except (httpx.TransportError, httpx.DecodingError):
# NOTE: TransportError is directly caused by the API being unavailable
# NOTE: DecodingError is caused by the API returning HTML and
# normally happens when CF protection is enabled while
# osu! recovers from a DDOS attack
# we do not want to delete the beatmap in this case, so we simply return
# but do not set the last check, as we would like to retry these ASAP
return
if api_data["data"] is not None:
api_response = api_data["data"]
old_maps = {bmap.id: bmap for bmap in self.maps}
new_maps = {int(api_map["beatmap_id"]): api_map for api_map in api_response}
self.last_osuapi_check = datetime.now()
# delete maps from old_maps where old.id not in new_maps
# update maps from old_maps where old.md5 != new.md5
# add maps to old_maps where new.id not in old_maps
updated_maps: list[Beatmap] = []
map_md5s_to_delete: set[str] = set()
# temp value for building the new beatmap
bmap: Beatmap
# find maps in our current state that've been deleted, or need updates
for old_id, old_map in old_maps.items():
if old_id not in new_maps:
# delete map from old_maps
map_md5s_to_delete.add(old_map.md5)
else:
new_map = new_maps[old_id]
new_ranked_status = RankedStatus.from_osuapi(
int(new_map["approved"]),
)
if (
old_map.md5 != new_map["file_md5"]
or old_map.status != new_ranked_status
):
# update map from old_maps
bmap = old_maps[old_id]
bmap._parse_from_osuapi_resp(new_map)
updated_maps.append(bmap)
else:
# map is the same, make no changes
updated_maps.append(old_map) # (this line is _maybe_ needed?)
# find maps that aren't in our current state, and add them
for new_id, new_map in new_maps.items():
if new_id not in old_maps:
# new map we don't have locally, add it
bmap = Beatmap.__new__(Beatmap)
bmap.id = new_id
bmap._parse_from_osuapi_resp(new_map)
# (some implementation-specific stuff not given by api)
bmap.frozen = False
bmap.passes = 0
bmap.plays = 0
bmap.set = self
updated_maps.append(bmap)
# save changes to cache
self.maps = updated_maps
# save changes to sql
if map_md5s_to_delete:
# delete maps
await app.state.services.database.execute(
"DELETE FROM maps WHERE md5 IN :map_md5s",
{"map_md5s": map_md5s_to_delete},
)
# delete scores on the maps
await app.state.services.database.execute(
"DELETE FROM scores WHERE map_md5 IN :map_md5s",
{"map_md5s": map_md5s_to_delete},
)
# update last_osuapi_check
await app.state.services.database.execute(
"REPLACE INTO mapsets "
"(id, server, last_osuapi_check) "
"VALUES (:id, :server, :last_osuapi_check)",
{
"id": self.id,
"server": "osu!",
"last_osuapi_check": self.last_osuapi_check,
},
)
# update maps in sql
await self._save_to_sql()
elif api_data["status_code"] in (404, 200):
# NOTE: 200 can return an empty array of beatmaps,
# so we still delete in this case if the beatmap data is None
# TODO: a couple of open questions here:
# - should we delete the beatmap from the database if it's not in the osu!api?
# - are 404 and 200 the only cases where we should delete the beatmap?
if self.maps:
map_md5s_to_delete = {bmap.md5 for bmap in self.maps}
# delete maps
await app.state.services.database.execute(
"DELETE FROM maps WHERE md5 IN :map_md5s",
{"map_md5s": map_md5s_to_delete},
)
# delete scores on the maps
await app.state.services.database.execute(
"DELETE FROM scores WHERE map_md5 IN :map_md5s",
{"map_md5s": map_md5s_to_delete},
)
# delete set
await app.state.services.database.execute(
"DELETE FROM mapsets WHERE id = :set_id",
{"set_id": self.id},
)
async def _save_to_sql(self) -> None:
"""Save the object's attributes into the database."""
await app.state.services.database.execute_many(
"REPLACE INTO maps ("
"md5, id, server, set_id, "
"artist, title, version, creator, "
"filename, last_update, total_length, "
"max_combo, status, frozen, "
"plays, passes, mode, bpm, "
"cs, od, ar, hp, diff"
") VALUES ("
":md5, :id, :server, :set_id, "
":artist, :title, :version, :creator, "
":filename, :last_update, :total_length, "
":max_combo, :status, :frozen, "
":plays, :passes, :mode, :bpm, "
":cs, :od, :ar, :hp, :diff"
")",
[
{
"md5": bmap.md5,
"id": bmap.id,
"server": "osu!",
"set_id": bmap.set_id,
"artist": bmap.artist,
"title": bmap.title,
"version": bmap.version,
"creator": bmap.creator,
"filename": bmap.filename,
"last_update": bmap.last_update,
"total_length": bmap.total_length,
"max_combo": bmap.max_combo,
"status": bmap.status,
"frozen": bmap.frozen,
"plays": bmap.plays,
"passes": bmap.passes,
"mode": bmap.mode,
"bpm": bmap.bpm,
"cs": bmap.cs,
"od": bmap.od,
"ar": bmap.ar,
"hp": bmap.hp,
"diff": bmap.diff,
}
for bmap in self.maps
],
)
@staticmethod
async def _from_bsid_cache(bsid: int) -> BeatmapSet | None:
"""Fetch a mapset from the cache by set id."""
return app.state.cache.beatmapset.get(bsid, None)
@classmethod
async def _from_bsid_sql(cls, bsid: int) -> BeatmapSet | None:
"""Fetch a mapset from the database by set id."""
last_osuapi_check = await app.state.services.database.fetch_val(
"SELECT last_osuapi_check FROM mapsets WHERE id = :set_id",
{"set_id": bsid},
column=0, # last_osuapi_check
)
if last_osuapi_check is None:
return None
bmap_set = cls(id=bsid, last_osuapi_check=last_osuapi_check)
for row in await maps_repo.fetch_many(set_id=bsid):
bmap = Beatmap(
md5=row["md5"],
id=row["id"],
set_id=row["set_id"],
artist=row["artist"],
title=row["title"],
version=row["version"],
creator=row["creator"],
last_update=row["last_update"],
total_length=row["total_length"],
max_combo=row["max_combo"],
status=RankedStatus(row["status"]),
frozen=row["frozen"],
plays=row["plays"],
passes=row["passes"],
mode=GameMode(row["mode"]),
bpm=row["bpm"],
cs=row["cs"],
od=row["od"],
ar=row["ar"],
hp=row["hp"],
diff=row["diff"],
filename=row["filename"],
map_set=bmap_set,
)
# XXX: tempfix for bancho.py <v3.4.1,
# where filenames weren't stored.
if not bmap.filename:
bmap.filename = (
("{artist} - {title} ({creator}) [{version}].osu")
.format(
artist=row["artist"],
title=row["title"],
creator=row["creator"],
version=row["version"],
)
.translate(IGNORED_BEATMAP_CHARS)
)
await maps_repo.partial_update(bmap.id, filename=bmap.filename)
bmap_set.maps.append(bmap)
return bmap_set
@classmethod
async def _from_bsid_osuapi(cls, bsid: int) -> BeatmapSet | None:
"""Fetch a mapset from the osu!api by set id."""
api_data = await api_get_beatmaps(s=bsid)
if api_data["data"] is not None:
api_response = api_data["data"]
self = cls(id=bsid, last_osuapi_check=datetime.now())
# XXX: pre-mapset bancho.py support
# select all current beatmaps
# that're frozen in the db
res = await app.state.services.database.fetch_all(
"SELECT id, status FROM maps WHERE set_id = :set_id AND frozen = 1",
{"set_id": bsid},
)
current_maps = {row["id"]: row["status"] for row in res}
for api_bmap in api_response:
# newer version available for this map
bmap: Beatmap = Beatmap.__new__(Beatmap)
bmap.id = int(api_bmap["beatmap_id"])
if bmap.id in current_maps:
# map is currently frozen, keep it's status.
bmap.status = RankedStatus(current_maps[bmap.id])
bmap.frozen = True
else:
bmap.frozen = False
bmap._parse_from_osuapi_resp(api_bmap)
# (some implementation-specific stuff not given by api)
bmap.passes = 0
bmap.plays = 0
bmap.set = self
self.maps.append(bmap)
await app.state.services.database.execute(
"REPLACE INTO mapsets "
"(id, server, last_osuapi_check) "
"VALUES (:id, :server, :last_osuapi_check)",
{
"id": self.id,
"server": "osu!",
"last_osuapi_check": self.last_osuapi_check,
},
)
await self._save_to_sql()
return self
return None
@classmethod
async def from_bsid(cls, bsid: int) -> BeatmapSet | None:
"""Cache all maps in a set from the osuapi, optionally
returning beatmaps by their md5 or id."""
bmap_set = await cls._from_bsid_cache(bsid)
did_api_request = False
if not bmap_set:
bmap_set = await cls._from_bsid_sql(bsid)
if not bmap_set:
bmap_set = await cls._from_bsid_osuapi(bsid)
if not bmap_set:
return None
did_api_request = True
# TODO: this can be done less often for certain types of maps,
# such as ones that're ranked on bancho and won't be updated,
# and perhaps ones that haven't been updated in a long time.
if not did_api_request and bmap_set._cache_expired():
await bmap_set._update_if_available()
# cache the beatmap set, and beatmaps
# to be efficient in future requests
cache_beatmap_set(bmap_set)
return bmap_set
def cache_beatmap(beatmap: Beatmap) -> None:
"""Add the beatmap to the cache."""
app.state.cache.beatmap[beatmap.md5] = beatmap
app.state.cache.beatmap[beatmap.id] = beatmap
def cache_beatmap_set(beatmap_set: BeatmapSet) -> None:
"""Add the beatmap set, and each beatmap to the cache."""
app.state.cache.beatmapset[beatmap_set.id] = beatmap_set
for beatmap in beatmap_set.maps:
cache_beatmap(beatmap)

138
app/objects/channel.py Normal file
View File

@@ -0,0 +1,138 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING
import app.packets
import app.state
from app.constants.privileges import Privileges
if TYPE_CHECKING:
from app.objects.player import Player
class Channel:
"""An osu! chat channel.
Possibly confusing attributes
-----------
_name: `str`
A name string of the channel.
The cls.`name` property wraps handling for '#multiplayer' and
'#spectator' when communicating with the osu! client; only use
this attr when you need the channel's true name; otherwise you
should use the `name` property described below.
instance: `bool`
Instanced channels are deleted when all players have left;
this is useful for things like multiplayer, spectator, etc.
"""
def __init__(
self,
name: str,
topic: str,
read_priv: Privileges = Privileges.UNRESTRICTED,
write_priv: Privileges = Privileges.UNRESTRICTED,
auto_join: bool = True,
instance: bool = False,
) -> None:
# TODO: think of better names than `_name` and `name`
self._name = name # 'real' name ('#{multi/spec}_{id}')
if self._name.startswith("#spec_"):
self.name = "#spectator"
elif self._name.startswith("#multi_"):
self.name = "#multiplayer"
else:
self.name = self._name
self.topic = topic
self.read_priv = read_priv
self.write_priv = write_priv
self.auto_join = auto_join
self.instance = instance
self.players: list[Player] = []
def __repr__(self) -> str:
return f"<{self._name}>"
def __contains__(self, player: Player) -> bool:
return player in self.players
# XXX: should this be cached differently?
def can_read(self, priv: Privileges) -> bool:
if not self.read_priv:
return True
return priv & self.read_priv != 0
def can_write(self, priv: Privileges) -> bool:
if not self.write_priv:
return True
return priv & self.write_priv != 0
def send(self, msg: str, sender: Player, to_self: bool = False) -> None:
"""Enqueue `msg` to all appropriate clients from `sender`."""
data = app.packets.send_message(
sender=sender.name,
msg=msg,
recipient=self.name,
sender_id=sender.id,
)
for player in self.players:
if sender.id not in player.blocks and (to_self or player.id != sender.id):
player.enqueue(data)
def send_bot(self, msg: str) -> None:
"""Enqueue `msg` to all connected clients from bot."""
bot = app.state.sessions.bot
msg_len = len(msg)
if msg_len >= 31979: # TODO ??????????
msg = f"message would have crashed games ({msg_len} chars)"
self.enqueue(
app.packets.send_message(
sender=bot.name,
msg=msg,
recipient=self.name,
sender_id=bot.id,
),
)
def send_selective(
self,
msg: str,
sender: Player,
recipients: set[Player],
) -> None:
"""Enqueue `sender`'s `msg` to `recipients`."""
for player in recipients:
if player in self:
player.send(msg, sender=sender, chan=self)
def append(self, player: Player) -> None:
"""Add `player` to the channel's players."""
self.players.append(player)
def remove(self, player: Player) -> None:
"""Remove `player` from the channel's players."""
self.players.remove(player)
if not self.players and self.instance:
# if it's an instance channel and this
# is the last member leaving, just remove
# the channel from the global list.
app.state.sessions.channels.remove(self)
def enqueue(self, data: bytes, immune: Sequence[int] = []) -> None:
"""Enqueue `data` to all connected clients not in `immune`."""
for player in self.players:
if player.id not in immune:
player.enqueue(data)

314
app/objects/collections.py Normal file
View File

@@ -0,0 +1,314 @@
from __future__ import annotations
from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import Sequence
from typing import Any
import databases.core
import app.settings
import app.state
import app.utils
from app.constants.privileges import ClanPrivileges
from app.constants.privileges import Privileges
from app.logging import Ansi
from app.logging import log
from app.objects.channel import Channel
from app.objects.match import Match
from app.objects.player import Player
from app.repositories import channels as channels_repo
from app.repositories import clans as clans_repo
from app.repositories import users as users_repo
from app.utils import make_safe_name
class Channels(list[Channel]):
"""The currently active chat channels on the server."""
def __iter__(self) -> Iterator[Channel]:
return super().__iter__()
def __contains__(self, o: object) -> bool:
"""Check whether internal list contains `o`."""
# Allow string to be passed to compare vs. name.
if isinstance(o, str):
return o in (chan.name for chan in self)
else:
return super().__contains__(o)
def __repr__(self) -> str:
# XXX: we use the "real" name, aka
# #multi_1 instead of #multiplayer
# #spect_1 instead of #spectator.
return f'[{", ".join(c._name for c in self)}]'
def get_by_name(self, name: str) -> Channel | None:
"""Get a channel from the list by `name`."""
for channel in self:
if channel._name == name:
return channel
return None
def append(self, channel: Channel) -> None:
"""Append `channel` to the list."""
super().append(channel)
if app.settings.DEBUG:
log(f"{channel} added to channels list.")
def extend(self, channels: Iterable[Channel]) -> None:
"""Extend the list with `channels`."""
super().extend(channels)
if app.settings.DEBUG:
log(f"{channels} added to channels list.")
def remove(self, channel: Channel) -> None:
"""Remove `channel` from the list."""
super().remove(channel)
if app.settings.DEBUG:
log(f"{channel} removed from channels list.")
async def prepare(self) -> None:
"""Fetch data from sql & return; preparing to run the server."""
log("Fetching channels from sql.", Ansi.LCYAN)
for row in await channels_repo.fetch_many():
self.append(
Channel(
name=row["name"],
topic=row["topic"],
read_priv=Privileges(row["read_priv"]),
write_priv=Privileges(row["write_priv"]),
auto_join=row["auto_join"] == 1,
),
)
class Matches(list[Match | None]):
"""The currently active multiplayer matches on the server."""
def __init__(self) -> None:
MAX_MATCHES = 64 # TODO: refactor this out of existence
super().__init__([None] * MAX_MATCHES)
def __iter__(self) -> Iterator[Match | None]:
return super().__iter__()
def __repr__(self) -> str:
return f'[{", ".join(match.name for match in self if match)}]'
def get_free(self) -> int | None:
"""Return the first free match id from `self`."""
for idx, match in enumerate(self):
if match is None:
return idx
return None
def remove(self, match: Match | None) -> None:
"""Remove `match` from the list."""
for i, _m in enumerate(self):
if match is _m:
self[i] = None
break
if app.settings.DEBUG:
log(f"{match} removed from matches list.")
class Players(list[Player]):
"""The currently active players on the server."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def __iter__(self) -> Iterator[Player]:
return super().__iter__()
def __contains__(self, player: object) -> bool:
# allow us to either pass in the player
# obj, or the player name as a string.
if isinstance(player, str):
return player in (player.name for player in self)
else:
return super().__contains__(player)
def __repr__(self) -> str:
return f'[{", ".join(map(repr, self))}]'
@property
def ids(self) -> set[int]:
"""Return a set of the current ids in the list."""
return {p.id for p in self}
@property
def staff(self) -> set[Player]:
"""Return a set of the current staff online."""
return {p for p in self if p.priv & Privileges.STAFF}
@property
def restricted(self) -> set[Player]:
"""Return a set of the current restricted players."""
return {p for p in self if not p.priv & Privileges.UNRESTRICTED}
@property
def unrestricted(self) -> set[Player]:
"""Return a set of the current unrestricted players."""
return {p for p in self if p.priv & Privileges.UNRESTRICTED}
def enqueue(self, data: bytes, immune: Sequence[Player] = []) -> None:
"""Enqueue `data` to all players, except for those in `immune`."""
for player in self:
if player not in immune:
player.enqueue(data)
def get(
self,
token: str | None = None,
id: int | None = None,
name: str | None = None,
) -> Player | None:
"""Get a player by token, id, or name from cache."""
for player in self:
if token is not None:
if player.token == token:
return player
elif id is not None:
if player.id == id:
return player
elif name is not None:
if player.safe_name == make_safe_name(name):
return player
return None
async def get_sql(
self,
id: int | None = None,
name: str | None = None,
) -> Player | None:
"""Get a player by token, id, or name from sql."""
# try to get from sql.
player = await users_repo.fetch_one(
id=id,
name=name,
fetch_all_fields=True,
)
if player is None:
return None
clan_id: int | None = None
clan_priv: ClanPrivileges | None = None
if player["clan_id"] != 0:
clan_id = player["clan_id"]
clan_priv = ClanPrivileges(player["clan_priv"])
return Player(
id=player["id"],
name=player["name"],
priv=Privileges(player["priv"]),
pw_bcrypt=player["pw_bcrypt"].encode(),
token=Player.generate_token(),
clan_id=clan_id,
clan_priv=clan_priv,
geoloc={
"latitude": 0.0,
"longitude": 0.0,
"country": {
"acronym": player["country"],
"numeric": app.state.services.country_codes[player["country"]],
},
},
silence_end=player["silence_end"],
donor_end=player["donor_end"],
api_key=player["api_key"],
)
async def from_cache_or_sql(
self,
id: int | None = None,
name: str | None = None,
) -> Player | None:
"""Try to get player from cache, or sql as fallback."""
player = self.get(id=id, name=name)
if player is not None:
return player
player = await self.get_sql(id=id, name=name)
if player is not None:
return player
return None
async def from_login(
self,
name: str,
pw_md5: str,
sql: bool = False,
) -> Player | None:
"""Return a player with a given name & pw_md5, from cache or sql."""
player = self.get(name=name)
if not player:
if not sql:
return None
player = await self.get_sql(name=name)
if not player:
return None
assert player.pw_bcrypt is not None
if app.state.cache.bcrypt[player.pw_bcrypt] == pw_md5.encode():
return player
return None
def append(self, player: Player) -> None:
"""Append `p` to the list."""
if player in self:
if app.settings.DEBUG:
log(f"{player} double-added to global player list?")
return
super().append(player)
def remove(self, player: Player) -> None:
"""Remove `p` from the list."""
if player not in self:
if app.settings.DEBUG:
log(f"{player} removed from player list when not online?")
return
super().remove(player)
async def initialize_ram_caches() -> None:
"""Setup & cache the global collections before listening for connections."""
# fetch channels, clans and pools from db
await app.state.sessions.channels.prepare()
bot = await users_repo.fetch_one(id=1)
if bot is None:
raise RuntimeError("Bot account not found in database.")
# create bot & add it to online players
app.state.sessions.bot = Player(
id=1,
name=bot["name"],
priv=Privileges.UNRESTRICTED,
pw_bcrypt=None,
token=Player.generate_token(),
login_time=float(0x7FFFFFFF), # (never auto-dc)
is_bot_client=True,
)
app.state.sessions.players.append(app.state.sessions.bot)
# static api keys
app.state.sessions.api_keys = {
row["api_key"]: row["id"]
for row in await app.state.services.database.fetch_all(
"SELECT id, api_key FROM users WHERE api_key IS NOT NULL",
)
}

552
app/objects/match.py Normal file
View File

@@ -0,0 +1,552 @@
from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Sequence
from datetime import datetime as datetime
from datetime import timedelta as timedelta
from enum import IntEnum
from enum import unique
from typing import TYPE_CHECKING
from typing import TypedDict
import app.packets
import app.settings
import app.state
from app.constants import regexes
from app.constants.gamemodes import GameMode
from app.constants.mods import Mods
from app.objects.beatmap import Beatmap
from app.repositories.tourney_pools import TourneyPool
from app.utils import escape_enum
from app.utils import pymysql_encode
if TYPE_CHECKING:
from asyncio import TimerHandle
from app.objects.channel import Channel
from app.objects.player import Player
MAX_MATCH_NAME_LENGTH = 50
@unique
@pymysql_encode(escape_enum)
class SlotStatus(IntEnum):
open = 1
locked = 2
not_ready = 4
ready = 8
no_map = 16
playing = 32
complete = 64
quit = 128
# has_player = not_ready | ready | no_map | playing | complete
@unique
@pymysql_encode(escape_enum)
class MatchTeams(IntEnum):
neutral = 0
blue = 1
red = 2
"""
# implemented by osu! and send between client/server,
# quite frequently even, but seems useless??
@unique
@pymysql_encode(escape_enum)
class MatchTypes(IntEnum):
standard = 0
powerplay = 1 # literally no idea what this is for
"""
@unique
@pymysql_encode(escape_enum)
class MatchWinConditions(IntEnum):
score = 0
accuracy = 1
combo = 2
scorev2 = 3
@unique
@pymysql_encode(escape_enum)
class MatchTeamTypes(IntEnum):
head_to_head = 0
tag_coop = 1
team_vs = 2
tag_team_vs = 3
class Slot:
"""An individual player slot in an osu! multiplayer match."""
def __init__(self) -> None:
self.player: Player | None = None
self.status = SlotStatus.open
self.team = MatchTeams.neutral
self.mods = Mods.NOMOD
self.loaded = False
self.skipped = False
def empty(self) -> bool:
return self.player is None
def copy_from(self, other: Slot) -> None:
self.player = other.player
self.status = other.status
self.team = other.team
self.mods = other.mods
def reset(self, new_status: SlotStatus = SlotStatus.open) -> None:
self.player = None
self.status = new_status
self.team = MatchTeams.neutral
self.mods = Mods.NOMOD
self.loaded = False
self.skipped = False
class StartingTimers(TypedDict):
start: TimerHandle
alerts: list[TimerHandle]
time: float
class Match:
"""\
An osu! multiplayer match.
Possibly confusing attributes
-----------
_refs: set[`Player`]
A set of players who have access to mp commands in the match.
These can be used with the !mp <addref/rmref/listref> commands.
slots: list[`Slot`]
A list of 16 `Slot` objects representing the match's slots.
starting: dict[str, `TimerHandle`] | None
Used when the match is started with !mp start <seconds>.
It stores both the starting timer, and the chat alert timers.
seed: `int`
The seed used for osu!mania's random mod.
use_pp_scoring: `bool`
Whether pp should be used as a win condition override during scrims.
"""
def __init__(
self,
id: int,
name: str,
password: str,
has_public_history: bool,
map_name: str,
map_id: int,
map_md5: str,
host_id: int,
mode: GameMode,
mods: Mods,
win_condition: MatchWinConditions,
team_type: MatchTeamTypes,
freemods: bool,
seed: int,
chat_channel: Channel,
) -> None:
self.id = id
self.name = name
self.passwd = password
self.has_public_history = has_public_history
self.host_id = host_id
self._refs: set[Player] = set()
self.map_id = map_id
self.map_md5 = map_md5
self.map_name = map_name
self.prev_map_id = 0 # previously chosen map
self.mods = mods
self.mode = mode
self.freemods = freemods
self.chat = chat_channel
self.slots = [Slot() for _ in range(16)]
# self.type = MatchTypes.standard
self.team_type = team_type
self.win_condition = win_condition
self.in_progress = False
self.starting: StartingTimers | None = None
self.seed = seed # used for mania random mod
self.tourney_pool: TourneyPool | None = None
# scrimmage stuff
self.is_scrimming = False
self.match_points: dict[MatchTeams | Player, int] = defaultdict(int)
self.bans: set[tuple[Mods, int]] = set()
self.winners: list[Player | MatchTeams | None] = [] # none for tie
self.winning_pts = 0
self.use_pp_scoring = False # only for scrims
self.tourney_clients: set[int] = set() # player ids
@property
def host(self) -> Player:
player = app.state.sessions.players.get(id=self.host_id)
if player is None:
raise ValueError(
f"Host with id {self.host_id} not found for match {self!r}",
)
return player
@property
def url(self) -> str:
"""The match's invitation url."""
return f"osump://{self.id}/{self.passwd}"
@property
def map_url(self) -> str:
"""The osu! beatmap url for `self`'s map."""
return f"https://osu.{app.settings.DOMAIN}/b/{self.map_id}"
@property
def embed(self) -> str:
"""An osu! chat embed for `self`."""
return f"[{self.url} {self.name}]"
@property
def map_embed(self) -> str:
"""An osu! chat embed for `self`'s map."""
return f"[{self.map_url} {self.map_name}]"
@property
def refs(self) -> set[Player]:
"""Return all players with referee permissions."""
refs = self._refs
if self.host is not None:
refs.add(self.host)
return refs
def __repr__(self) -> str:
return f"<{self.name} ({self.id})>"
def get_slot(self, player: Player) -> Slot | None:
"""Return the slot containing a given player."""
for s in self.slots:
if player is s.player:
return s
return None
def get_slot_id(self, player: Player) -> int | None:
"""Return the slot index containing a given player."""
for idx, s in enumerate(self.slots):
if player is s.player:
return idx
return None
def get_free(self) -> int | None:
"""Return the first unoccupied slot in multi, if any."""
for idx, s in enumerate(self.slots):
if s.status == SlotStatus.open:
return idx
return None
def get_host_slot(self) -> Slot | None:
"""Return the slot containing the host."""
for s in self.slots:
if s.player is not None and s.player is self.host:
return s
return None
def copy(self, m: Match) -> None:
"""Fully copy the data of another match obj."""
self.map_id = m.map_id
self.map_md5 = m.map_md5
self.map_name = m.map_name
self.freemods = m.freemods
self.mode = m.mode
self.team_type = m.team_type
self.win_condition = m.win_condition
self.mods = m.mods
self.name = m.name
def enqueue(
self,
data: bytes,
lobby: bool = True,
immune: Sequence[int] = [],
) -> None:
"""Add data to be sent to all clients in the match."""
self.chat.enqueue(data, immune)
lchan = app.state.sessions.channels.get_by_name("#lobby")
if lobby and lchan and lchan.players:
lchan.enqueue(data)
def enqueue_state(self, lobby: bool = True) -> None:
"""Enqueue `self`'s state to players in the match & lobby."""
# TODO: hmm this is pretty bad, writes twice
# send password only to users currently in the match.
self.chat.enqueue(app.packets.update_match(self, send_pw=True))
lchan = app.state.sessions.channels.get_by_name("#lobby")
if lobby and lchan and lchan.players:
lchan.enqueue(app.packets.update_match(self, send_pw=False))
def unready_players(self, expected: SlotStatus = SlotStatus.ready) -> None:
"""Unready any players in the `expected` state."""
for s in self.slots:
if s.status == expected:
s.status = SlotStatus.not_ready
def reset_players_loaded_status(self) -> None:
"""Reset all players' loaded status."""
for s in self.slots:
s.loaded = False
s.skipped = False
def start(self) -> None:
"""Start the match for all ready players with the map."""
no_map: list[int] = []
for s in self.slots:
# start each player who has the map.
if s.player is not None:
if s.status != SlotStatus.no_map:
s.status = SlotStatus.playing
else:
no_map.append(s.player.id)
self.in_progress = True
self.enqueue(app.packets.match_start(self), immune=no_map, lobby=False)
self.enqueue_state()
def reset_scrim(self) -> None:
"""Reset the current scrim's winning points & bans."""
self.match_points.clear()
self.winners.clear()
self.bans.clear()
async def await_submissions(
self,
was_playing: Sequence[Slot],
) -> tuple[dict[MatchTeams | Player, int], Sequence[Player]]:
"""Await score submissions from all players in completed state."""
scores: dict[MatchTeams | Player, int] = defaultdict(int)
didnt_submit: list[Player] = []
time_waited = 0.0 # allow up to 10s (total, not per player)
ffa = self.team_type in (MatchTeamTypes.head_to_head, MatchTeamTypes.tag_coop)
if self.use_pp_scoring:
win_cond = "pp"
else:
win_cond = ("score", "acc", "max_combo", "score")[self.win_condition]
bmap = await Beatmap.from_md5(self.map_md5)
if not bmap:
# map isn't submitted
return {}, ()
for s in was_playing:
# continue trying to fetch each player's
# scores until they've all been submitted.
while True:
assert s.player is not None
rc_score = s.player.recent_score
max_age = datetime.now() - timedelta(
seconds=bmap.total_length + time_waited + 0.5,
)
if (
rc_score
and rc_score.bmap
and rc_score.bmap.md5 == self.map_md5
and rc_score.server_time > max_age
):
# score found, add to our scores dict if != 0.
score: int = getattr(rc_score, win_cond)
if score:
key: MatchTeams | Player = s.player if ffa else s.team
scores[key] += score
break
# wait 0.5s and try again
await asyncio.sleep(0.5)
time_waited += 0.5
if time_waited > 10:
# inform the match this user didn't
# submit a score in time, and skip them.
didnt_submit.append(s.player)
break
# all scores retrieved, update the match.
return scores, didnt_submit
async def update_matchpoints(self, was_playing: Sequence[Slot]) -> None:
"""\
Determine the winner from `scores`, increment & inform players.
This automatically works with the match settings (such as
win condition, teams, & co-op) to determine the appropriate
winner, and will use any team names included in the match name,
along with the match name (fmt: OWC2020: (Team1) vs. (Team2)).
For the examples, we'll use accuracy as a win condition.
Teams, match title: `OWC2015: (United States) vs. (China)`.
United States takes the point! (293.32% vs 292.12%)
Total Score: United States | 7 - 2 | China
United States takes the match, finishing OWC2015 with a score of 7 - 2!
FFA, the top <=3 players will be listed for the total score.
Justice takes the point! (94.32% [Match avg. 91.22%])
Total Score: Justice - 3 | cmyui - 2 | FrostiDrinks - 2
Justice takes the match, finishing with a score of 4 - 2!
"""
scores, didnt_submit = await self.await_submissions(was_playing)
for player in didnt_submit:
self.chat.send_bot(f"{player} didn't submit a score (timeout: 10s).")
if not scores:
self.chat.send_bot("Scores could not be calculated.")
return None
ffa = self.team_type in (
MatchTeamTypes.head_to_head,
MatchTeamTypes.tag_coop,
)
# all scores are equal, it was a tie.
if len(scores) != 1 and len(set(scores.values())) == 1:
self.winners.append(None)
self.chat.send_bot("The point has ended in a tie!")
return None
# Find the winner & increment their matchpoints.
winner: Player | MatchTeams = max(scores, key=lambda k: scores[k])
self.winners.append(winner)
self.match_points[winner] += 1
msg: list[str] = []
def add_suffix(score: int | float) -> str | int | float:
if self.use_pp_scoring:
return f"{score:.2f}pp"
elif self.win_condition == MatchWinConditions.accuracy:
return f"{score:.2f}%"
elif self.win_condition == MatchWinConditions.combo:
return f"{score}x"
else:
return str(score)
if ffa:
from app.objects.player import Player
assert isinstance(winner, Player)
msg.append(
f"{winner.name} takes the point! ({add_suffix(scores[winner])} "
f"[Match avg. {add_suffix(sum(scores.values()) / len(scores))}])",
)
wmp = self.match_points[winner]
# check if match point #1 has enough points to win.
if self.winning_pts and wmp == self.winning_pts:
# we have a champion, announce & reset our match.
self.is_scrimming = False
self.reset_scrim()
self.bans.clear()
m = f"{winner.name} takes the match! Congratulations!"
else:
# no winner, just announce the match points so far.
# for ffa, we'll only announce the top <=3 players.
m_points = sorted(self.match_points.items(), key=lambda x: x[1])
m = f"Total Score: {' | '.join([f'{k.name} - {v}' for k, v in m_points])}"
msg.append(m)
del m
else: # teams
assert isinstance(winner, MatchTeams)
r_match = regexes.TOURNEY_MATCHNAME.match(self.name)
if r_match:
match_name = r_match["name"]
team_names = {
MatchTeams.blue: r_match["T1"],
MatchTeams.red: r_match["T2"],
}
else:
match_name = self.name
team_names = {MatchTeams.blue: "Blue", MatchTeams.red: "Red"}
# teams are binary, so we have a loser.
if winner is MatchTeams.blue:
loser = MatchTeams.red
else:
loser = MatchTeams.blue
# from match name if available, else blue/red.
wname = team_names[winner]
lname = team_names[loser]
# scores from the recent play
# (according to win condition)
ws = add_suffix(scores[winner])
ls = add_suffix(scores[loser])
# total win/loss score in the match.
wmp = self.match_points[winner]
lmp = self.match_points[loser]
# announce the score for the most recent play.
msg.append(f"{wname} takes the point! ({ws} vs. {ls})")
# check if the winner has enough match points to win the match.
if self.winning_pts and wmp == self.winning_pts:
# we have a champion, announce & reset our match.
self.is_scrimming = False
self.reset_scrim()
msg.append(
f"{wname} takes the match, finishing {match_name} "
f"with a score of {wmp} - {lmp}! Congratulations!",
)
else:
# no winner, just announce the match points so far.
msg.append(f"Total Score: {wname} | {wmp} - {lmp} | {lname}")
if didnt_submit:
self.chat.send_bot(
"If you'd like to perform a rematch, "
"please use the `!mp rematch` command.",
)
for line in msg:
self.chat.send_bot(line)

8
app/objects/models.py Normal file
View File

@@ -0,0 +1,8 @@
from __future__ import annotations
from pydantic import BaseModel
class OsuBeatmapRequestForm(BaseModel):
Filenames: list[str]
Ids: list[int]

1017
app/objects/player.py Normal file

File diff suppressed because it is too large Load Diff

453
app/objects/score.py Normal file
View File

@@ -0,0 +1,453 @@
from __future__ import annotations
import functools
import hashlib
from datetime import datetime
from enum import IntEnum
from enum import unique
from pathlib import Path
from typing import TYPE_CHECKING
import app.state
import app.usecases.performance
import app.utils
from app.constants.clientflags import ClientFlags
from app.constants.gamemodes import GameMode
from app.constants.mods import Mods
from app.objects.beatmap import Beatmap
from app.repositories import scores as scores_repo
from app.usecases.performance import ScoreParams
from app.utils import escape_enum
from app.utils import pymysql_encode
if TYPE_CHECKING:
from app.objects.player import Player
BEATMAPS_PATH = Path.cwd() / ".data/osu"
@unique
class Grade(IntEnum):
# NOTE: these are implemented in the opposite order
# as osu! to make more sense with <> operators.
N = 0
F = 1
D = 2
C = 3
B = 4
A = 5
S = 6 # S
SH = 7 # HD S
X = 8 # SS
XH = 9 # HD SS
@classmethod
@functools.cache
def from_str(cls, s: str) -> Grade:
return {
"xh": Grade.XH,
"x": Grade.X,
"sh": Grade.SH,
"s": Grade.S,
"a": Grade.A,
"b": Grade.B,
"c": Grade.C,
"d": Grade.D,
"f": Grade.F,
"n": Grade.N,
}[s.lower()]
def __format__(self, format_spec: str) -> str:
if format_spec == "stats_column":
return f"{self.name.lower()}_count"
else:
raise ValueError(f"Invalid format specifier {format_spec}")
@unique
@pymysql_encode(escape_enum)
class SubmissionStatus(IntEnum):
# TODO: make a system more like bancho's?
FAILED = 0
SUBMITTED = 1
BEST = 2
def __repr__(self) -> str:
return {
self.FAILED: "Failed",
self.SUBMITTED: "Submitted",
self.BEST: "Best",
}[self]
class Score:
"""\
Server side representation of an osu! score; any gamemode.
Possibly confusing attributes
-----------
bmap: `Beatmap | None`
A beatmap obj representing the osu map.
player: `Player | None`
A player obj of the player who submitted the score.
grade: `Grade`
The letter grade in the score.
rank: `int`
The leaderboard placement of the score.
perfect: `bool`
Whether the score is a full-combo.
time_elapsed: `int`
The total elapsed time of the play (in milliseconds).
client_flags: `int`
osu!'s old anticheat flags.
prev_best: `Score | None`
The previous best score before this play was submitted.
NOTE: just because a score has a `prev_best` attribute does
mean the score is our best score on the map! the `status`
value will always be accurate for any score.
"""
def __init__(self) -> None:
# TODO: check whether the reamining Optional's should be
self.id: int | None = None
self.bmap: Beatmap | None = None
self.player: Player | None = None
self.mode: GameMode
self.mods: Mods
self.pp: float
self.sr: float
self.score: int
self.max_combo: int
self.acc: float
# TODO: perhaps abstract these differently
# since they're mode dependant? feels weird..
self.n300: int
self.n100: int # n150 for taiko
self.n50: int
self.nmiss: int
self.ngeki: int
self.nkatu: int
self.grade: Grade
self.passed: bool
self.perfect: bool
self.status: SubmissionStatus
self.client_time: datetime
self.server_time: datetime
self.time_elapsed: int
self.client_flags: ClientFlags
self.client_checksum: str
self.rank: int | None = None
self.prev_best: Score | None = None
def __repr__(self) -> str:
# TODO: i really need to clean up my reprs
try:
assert self.bmap is not None
return (
f"<{self.acc:.2f}% {self.max_combo}x {self.nmiss}M "
f"#{self.rank} on {self.bmap.full_name} for {self.pp:,.2f}pp>"
)
except:
return super().__repr__()
"""Classmethods to fetch a score object from various data types."""
@classmethod
async def from_sql(cls, score_id: int) -> Score | None:
"""Create a score object from sql using its scoreid."""
rec = await scores_repo.fetch_one(score_id)
if rec is None:
return None
s = cls()
s.id = rec["id"]
s.bmap = await Beatmap.from_md5(rec["map_md5"])
s.player = await app.state.sessions.players.from_cache_or_sql(id=rec["userid"])
s.sr = 0.0 # TODO
s.pp = rec["pp"]
s.score = rec["score"]
s.max_combo = rec["max_combo"]
s.mods = Mods(rec["mods"])
s.acc = rec["acc"]
s.n300 = rec["n300"]
s.n100 = rec["n100"]
s.n50 = rec["n50"]
s.nmiss = rec["nmiss"]
s.ngeki = rec["ngeki"]
s.nkatu = rec["nkatu"]
s.grade = Grade.from_str(rec["grade"])
s.perfect = rec["perfect"] == 1
s.status = SubmissionStatus(rec["status"])
s.passed = s.status != SubmissionStatus.FAILED
s.mode = GameMode(rec["mode"])
s.server_time = rec["play_time"]
s.time_elapsed = rec["time_elapsed"]
s.client_flags = ClientFlags(rec["client_flags"])
s.client_checksum = rec["online_checksum"]
if s.bmap:
s.rank = await s.calculate_placement()
return s
@classmethod
def from_submission(cls, data: list[str]) -> Score:
"""Create a score object from an osu! submission string."""
s = cls()
""" parse the following format
# 0 online_checksum
# 1 n300
# 2 n100
# 3 n50
# 4 ngeki
# 5 nkatu
# 6 nmiss
# 7 score
# 8 max_combo
# 9 perfect
# 10 grade
# 11 mods
# 12 passed
# 13 gamemode
# 14 play_time # yyMMddHHmmss
# 15 osu_version + (" " * client_flags)
"""
s.client_checksum = data[0]
s.n300 = int(data[1])
s.n100 = int(data[2])
s.n50 = int(data[3])
s.ngeki = int(data[4])
s.nkatu = int(data[5])
s.nmiss = int(data[6])
s.score = int(data[7])
s.max_combo = int(data[8])
s.perfect = data[9] == "True"
s.grade = Grade.from_str(data[10])
s.mods = Mods(int(data[11]))
s.passed = data[12] == "True"
s.mode = GameMode.from_params(int(data[13]), s.mods)
s.client_time = datetime.strptime(data[14], "%y%m%d%H%M%S")
s.client_flags = ClientFlags(data[15].count(" ") & ~4)
s.server_time = datetime.now()
return s
def compute_online_checksum(
self,
osu_version: str,
osu_client_hash: str,
storyboard_checksum: str,
) -> str:
"""Validate the online checksum of the score."""
assert self.player is not None
assert self.bmap is not None
return hashlib.md5(
"chickenmcnuggets{0}o15{1}{2}smustard{3}{4}uu{5}{6}{7}{8}{9}{10}{11}Q{12}{13}{15}{14:%y%m%d%H%M%S}{16}{17}".format(
self.n100 + self.n300,
self.n50,
self.ngeki,
self.nkatu,
self.nmiss,
self.bmap.md5,
self.max_combo,
self.perfect,
self.player.name,
self.score,
self.grade.name,
int(self.mods),
self.passed,
self.mode.as_vanilla,
self.client_time,
osu_version, # 20210520
osu_client_hash,
storyboard_checksum,
# yyMMddHHmmss
).encode(),
).hexdigest()
"""Methods to calculate internal data for a score."""
async def calculate_placement(self) -> int:
assert self.bmap is not None
if self.mode >= GameMode.RELAX_OSU:
scoring_metric = "pp"
score = self.pp
else:
scoring_metric = "score"
score = self.score
num_better_scores: int | None = await app.state.services.database.fetch_val(
"SELECT COUNT(*) AS c FROM scores s "
"INNER JOIN users u ON u.id = s.userid "
"WHERE s.map_md5 = :map_md5 AND s.mode = :mode "
"AND s.status = 2 AND u.priv & 1 "
f"AND s.{scoring_metric} > :score",
{
"map_md5": self.bmap.md5,
"mode": self.mode,
"score": score,
},
column=0, # COUNT(*)
)
assert num_better_scores is not None
return num_better_scores + 1
def calculate_performance(self, beatmap_id: int) -> tuple[float, float]:
"""Calculate PP and star rating for our score."""
mode_vn = self.mode.as_vanilla
score_args = ScoreParams(
mode=mode_vn,
mods=int(self.mods),
combo=self.max_combo,
ngeki=self.ngeki,
n300=self.n300,
nkatu=self.nkatu,
n100=self.n100,
n50=self.n50,
nmiss=self.nmiss,
)
result = app.usecases.performance.calculate_performances(
osu_file_path=str(BEATMAPS_PATH / f"{beatmap_id}.osu"),
scores=[score_args],
)
return result[0]["performance"]["pp"], result[0]["difficulty"]["stars"]
async def calculate_status(self) -> None:
"""Calculate the submission status of a submitted score."""
assert self.player is not None
assert self.bmap is not None
recs = await scores_repo.fetch_many(
user_id=self.player.id,
map_md5=self.bmap.md5,
mode=self.mode,
status=SubmissionStatus.BEST,
)
if recs:
rec = recs[0]
# we have a score on the map.
# save it as our previous best score.
self.prev_best = await Score.from_sql(rec["id"])
assert self.prev_best is not None
# if our new score is better, update
# both of our score's submission statuses.
# NOTE: this will be updated in sql later on in submission
if self.pp > rec["pp"]:
self.status = SubmissionStatus.BEST
self.prev_best.status = SubmissionStatus.SUBMITTED
else:
self.status = SubmissionStatus.SUBMITTED
else:
# this is our first score on the map.
self.status = SubmissionStatus.BEST
def calculate_accuracy(self) -> float:
"""Calculate the accuracy of our score."""
mode_vn = self.mode.as_vanilla
if mode_vn == 0: # osu!
total = self.n300 + self.n100 + self.n50 + self.nmiss
if total == 0:
return 0.0
return (
100.0
* ((self.n300 * 300.0) + (self.n100 * 100.0) + (self.n50 * 50.0))
/ (total * 300.0)
)
elif mode_vn == 1: # osu!taiko
total = self.n300 + self.n100 + self.nmiss
if total == 0:
return 0.0
return 100.0 * ((self.n100 * 0.5) + self.n300) / total
elif mode_vn == 2: # osu!catch
total = self.n300 + self.n100 + self.n50 + self.nkatu + self.nmiss
if total == 0:
return 0.0
return 100.0 * (self.n300 + self.n100 + self.n50) / total
elif mode_vn == 3: # osu!mania
total = (
self.n300 + self.n100 + self.n50 + self.ngeki + self.nkatu + self.nmiss
)
if total == 0:
return 0.0
if self.mods & Mods.SCOREV2:
return (
100.0
* (
(self.n50 * 50.0)
+ (self.n100 * 100.0)
+ (self.nkatu * 200.0)
+ (self.n300 * 300.0)
+ (self.ngeki * 305.0)
)
/ (total * 305.0)
)
return (
100.0
* (
(self.n50 * 50.0)
+ (self.n100 * 100.0)
+ (self.nkatu * 200.0)
+ ((self.n300 + self.ngeki) * 300.0)
)
/ (total * 300.0)
)
else:
raise Exception(f"Invalid vanilla mode {mode_vn}")
""" Methods for updating a score. """
async def increment_replay_views(self) -> None:
# TODO: move replay views to be per-score rather than per-user
assert self.player is not None
# TODO: apparently cached stats don't store replay views?
# need to refactor that to be able to use stats_repo here
await app.state.services.database.execute(
f"UPDATE stats "
"SET replay_views = replay_views + 1 "
"WHERE id = :user_id AND mode = :mode",
{"user_id": self.player.id, "mode": self.mode},
)