mirror of
https://github.com/nihilvux/bancho.py.git
synced 2025-09-16 10:38:39 -07:00
255 lines
7.4 KiB
Python
255 lines
7.4 KiB
Python
from __future__ import annotations
|
|
|
|
import ctypes
|
|
import inspect
|
|
import os
|
|
import socket
|
|
import sys
|
|
from collections.abc import Callable
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
from typing import Any
|
|
from typing import TypedDict
|
|
from typing import TypeVar
|
|
|
|
import httpx
|
|
import pymysql
|
|
|
|
import app.settings
|
|
from app.logging import Ansi
|
|
from app.logging import log
|
|
|
|
if TYPE_CHECKING:
|
|
from app.repositories.users import User
|
|
|
|
# Generic type variable used by the class-decorator helpers in this module
# (e.g. pymysql_encode).
T = TypeVar("T")

# Root directory for all persistent server data: ".data" inside the current
# working directory, resolved once at import time.
DATA_PATH = Path.cwd().joinpath(".data")
# Directory holding the client-facing achievement (medal) images.
ACHIEVEMENTS_ASSETS_PATH = DATA_PATH.joinpath("assets", "medals", "client")
# File used as the server's default avatar image.
DEFAULT_AVATAR_PATH = DATA_PATH.joinpath("avatars", "default.jpg")
|
|
|
|
|
|
def make_safe_name(name: str) -> str:
    """Return the canonical "safe" form of a user name.

    The name is lower-cased and every space becomes an underscore; all other
    characters are left untouched (no SQL escaping is performed).
    """
    lowered = name.lower()
    return "_".join(lowered.split(" "))
|
|
|
|
|
|
def determine_highest_ranking_clan_member(members: list[User]) -> User:
    """Return the clan member with the highest ``clan_priv`` value.

    When several members share the highest value, the one that appears first
    in *members* is returned (the same result a stable descending sort gives).

    Raises:
        ValueError: if *members* is empty.
    """
    if not members:
        # an explicit error instead of a StopIteration leaking out of next()
        raise ValueError(
            "cannot determine the highest ranking member of an empty clan",
        )

    # max() is a single O(n) pass and returns the first maximal element.
    return max(members, key=lambda member: member["clan_priv"])
|
|
|
|
|
|
def _achievement_image_filenames() -> list[str]:
    """Return the file name of every achievement image, at 1x and @2x resolution."""
    filenames: list[str] = []

    for resolution in ("", "@2x"):
        for mode in ("osu", "taiko", "fruits", "mania"):
            # only osu!std has 9 & 10 star pass/fc medals.
            for star_rating in range(1, 1 + (10 if mode == "osu" else 8)):
                filenames.append(f"{mode}-skill-pass-{star_rating}{resolution}.png")
                filenames.append(f"{mode}-skill-fc-{star_rating}{resolution}.png")

        for combo in (500, 750, 1000, 2000):
            filenames.append(f"osu-combo-{combo}{resolution}.png")

        for mod in (
            "suddendeath",
            "hidden",
            "perfect",
            "hardrock",
            "doubletime",
            "flashlight",
            "easy",
            "nofail",
            "nightcore",
            "halftime",
            "spunout",
        ):
            filenames.append(f"all-intro-{mod}{resolution}.png")

    return filenames


def _download_achievement_images_osu(achievements_path: Path) -> bool:
    """Download all used achievement images (one by one, from osu!).

    Each image is written to ``achievements_path / <filename>``.

    Returns:
        True once every image has been saved. False as soon as one download
        fails, either through a non-200 response or a transport error such as
        no network or a timeout. Images saved before the failure are left on disk.
    """
    log("Downloading achievement images from osu!.", Ansi.LCYAN)

    # a single client reuses its connection pool across all requests
    with httpx.Client() as client:
        for ach in _achievement_image_filenames():
            try:
                resp = client.get(f"https://assets.ppy.sh/medals/client/{ach}")
            except httpx.HTTPError:
                # report failure to the caller instead of crashing startup
                return False

            if resp.status_code != 200:
                return False

            log(f"Saving achievement: {ach}", Ansi.LCYAN)
            (achievements_path / ach).write_bytes(resp.content)

    return True
|
|
|
|
|
|
def download_achievement_images(achievements_path: Path) -> None:
    """Download all used achievement images (using the best available source).

    On failure, the images saved so far are deleted and the directory itself
    is removed. This lets a later run retry: ensure_persistent_volumes_are_available
    only downloads when the directory does not exist. Failures are logged,
    never raised, because the server will *mostly* work without these images.
    """
    # download individual files from the official osu! servers
    downloaded = _download_achievement_images_osu(achievements_path)

    if downloaded:
        log("Downloaded all achievement images.", Ansi.LGREEN)
        return

    log("Failed to download achievement images.", Ansi.LRED)

    # Path.rmdir() only removes *empty* directories. Any images saved before
    # the failure must be deleted first, otherwise rmdir() raises OSError and
    # hard-crashes startup.
    try:
        for leftover in achievements_path.iterdir():
            leftover.unlink()
        achievements_path.rmdir()
    except OSError as exc:
        log(f"Failed to clean up {achievements_path}: {exc}", Ansi.LRED)

    # allow passthrough (don't hard crash).
    # the server will *mostly* work in this state.
|
|
|
|
|
|
def download_default_avatar(default_avatar_path: Path) -> None:
    """Download an avatar to use as the server's default.

    The fetched bytes are written to *default_avatar_path*. A transport error
    or a non-200 response is logged and otherwise ignored, and nothing is
    written in that case.
    """
    # NOTE(review): the source URL is a .png; callers that pass a .jpg path
    # will store PNG bytes under a .jpg name. Confirm this is intended.
    try:
        resp = httpx.get("https://i.cmyui.xyz/U24XBZw-4wjVME-JaEz3.png")
    except httpx.HTTPError:
        # e.g. no network / DNS failure / timeout: best-effort, don't crash
        log("Failed to fetch default avatar.", Ansi.LRED)
        return

    if resp.status_code != 200:
        log("Failed to fetch default avatar.", Ansi.LRED)
        return

    default_avatar_path.write_bytes(resp.content)
    # only report success once the file has actually been written
    log("Downloaded default avatar.", Ansi.LGREEN)
|
|
|
|
|
|
def has_internet_connectivity(timeout: float = 1.0) -> bool:
    """Check for an active internet connection.

    Attempts a TCP connection to port 53 of several well-known public DNS
    resolvers and returns True as soon as one attempt succeeds.

    Args:
        timeout: seconds to wait for each individual connection attempt.

    Returns:
        True if any resolver accepted the connection, otherwise False.
    """
    COMMON_DNS_SERVERS = (
        # Cloudflare
        "1.1.1.1",
        "1.0.0.1",
        # Google
        "8.8.8.8",
        "8.8.4.4",
    )

    for host in COMMON_DNS_SERVERS:
        # Use a fresh socket for every attempt. After a failed or timed-out
        # connect(), a socket's state is unspecified and it cannot reliably
        # be reused, so sharing one socket made every fallback host fail
        # immediately.
        try:
            with socket.create_connection((host, 53), timeout=timeout):
                return True
        except OSError:
            continue

    # all connections failed
    return False
|
|
|
|
|
|
# Serialisable description of one stack frame, as built by
# get_appropriate_stacktrace(). Keys:
#   function - name of the function executing in the frame
#   filename - base name of the frame's source file
#   lineno   - line number currently executing in the frame
#   charno   - index of the current line within the frame's source context
#              (0 when unavailable)
#   locals   - the frame's local variables, each rendered with repr()
FrameInfo = TypedDict(
    "FrameInfo",
    {
        "function": str,
        "filename": str,
        "lineno": int,
        "charno": int,
        "locals": dict[str, str],
    },
)
|
|
|
|
|
|
def get_appropriate_stacktrace() -> list[FrameInfo]:
    """Return information about the frames between the caller and ``run``.

    The call stack is walked outwards from this function's caller until the
    innermost frame whose function is named ``run`` is found. The frames
    strictly inside it are returned: the ``run`` frame and this function's
    own frame are both excluded.

    Returns:
        One entry per frame, ordered outermost-first like a Python traceback.
        Each entry holds the function name, the file's base name, the line
        number, the context index and repr() of every local.

    Raises:
        Exception: if no frame named ``run`` is on the call stack.
    """
    stack = inspect.stack()[1:]  # [1:] drops this function's own frame

    for idx, frame in enumerate(stack):
        if frame.function == "run":
            break
    else:
        raise Exception("no frame named 'run' was found on the call stack")

    return [
        {
            "function": frame.function,
            "filename": Path(frame.filename).name,
            "lineno": frame.lineno,
            "charno": frame.index or 0,
            "locals": {k: repr(v) for k, v in frame.frame.f_locals.items()},
        }
        # reverse for python-like stacktrace
        # ordering; puts the most recent
        # call closest to the command line
        for frame in reversed(stack[:idx])
    ]
|
|
|
|
|
|
def pymysql_encode(
    conv: Callable[[Any, dict[object, object] | None], str],
) -> Callable[[type[T]], type[T]]:
    """Build a class decorator that registers *conv* as pymysql's encoder for that class.

    Applying the decorator adds an entry keyed by the class to
    ``pymysql.converters.encoders``. The class itself is returned unchanged.
    """

    def _register(klass: type[T]) -> type[T]:
        encoder_table = pymysql.converters.encoders
        encoder_table[klass] = conv
        return klass

    return _register
|
|
|
|
|
|
def escape_enum(
    val: Any,
    _: dict[object, object] | None = None,
) -> str:  # used for ^ (converter for pymysql_encode)
    """Encode *val* as the decimal text of its integer value.

    Matches the converter signature expected by pymysql_encode. The second
    argument is accepted only for that signature and is ignored.
    """
    as_integer = int(val)
    return f"{as_integer}"
|
|
|
|
|
|
def ensure_persistent_volumes_are_available() -> None:
    """Create the on-disk data directories and fetch any missing assets.

    Creates DATA_PATH and its fixed subdirectories if they are absent. Achievement
    images are downloaded only when their directory does not exist yet, and the
    default avatar is downloaded only when its file is missing.
    """
    # /.data first, then each of its subdirectories (parent before children)
    required_dirs = [DATA_PATH]
    required_dirs.extend(
        DATA_PATH / name for name in ("avatars", "logs", "osu", "osr", "ss")
    )
    for directory in required_dirs:
        directory.mkdir(exist_ok=True)

    # achievement images from osu!, only on first setup
    if not ACHIEVEMENTS_ASSETS_PATH.exists():
        ACHIEVEMENTS_ASSETS_PATH.mkdir(parents=True)
        download_achievement_images(ACHIEVEMENTS_ASSETS_PATH)

    # avatar shown for new users
    if not DEFAULT_AVATAR_PATH.exists():
        download_default_avatar(DEFAULT_AVATAR_PATH)
|
|
|
|
|
|
def is_running_as_admin() -> bool:
    """Report whether the process runs as root (POSIX) or administrator (Windows).

    ``os.geteuid`` is tried first. If it does not exist, the Windows shell32
    ``IsUserAnAdmin`` API is consulted instead.

    Raises:
        Exception: if neither mechanism is available on this platform.
    """
    try:
        effective_uid = os.geteuid()  # type: ignore[attr-defined, unused-ignore]
    except AttributeError:
        pass  # no geteuid on this platform; fall through to the Windows check
    else:
        return effective_uid == 0

    try:
        admin_flag = ctypes.windll.shell32.IsUserAnAdmin()  # type: ignore[attr-defined, unused-ignore]
    except AttributeError:
        raise Exception(
            f"{sys.platform} is not currently supported on bancho.py, please create a github issue!",
        )
    return admin_flag == 1  # type: ignore[no-any-return, unused-ignore]
|
|
|
|
|
|
def display_startup_dialog() -> None:
    """Print any general information or warnings to the console.

    Messages, in order: developer mode, debug mode, running with root/admin
    privileges (stronger wording when developer mode is also on), and missing
    internet connectivity.
    """
    if app.settings.DEVELOPER_MODE:
        log("running in advanced mode", Ansi.LYELLOW)
    if app.settings.DEBUG:
        log("running in debug mode", Ansi.LMAGENTA)

    # running on root/admin grants the software potentially dangerous and
    # unnecessary power over the operating system and is not advised.
    if is_running_as_admin():
        warning = "It is not recommended to run bancho.py as root/admin, especially in production."
        if app.settings.DEVELOPER_MODE:
            warning += " You are at increased risk as developer mode is enabled."
        log(warning, Ansi.LYELLOW)

    if not has_internet_connectivity():
        log("No internet connectivity detected", Ansi.LYELLOW)
|
|
|
|
|
|
def has_jpeg_headers_and_trailers(data_view: memoryview) -> bool:
    """Return True if *data_view* starts like a JFIF-encoded JPEG.

    It must begin with the SOI marker followed by an APP0 marker
    (FF D8 FF E0), and bytes 6-10 must hold the "JFIF\\0" identifier.
    """
    # NOTE(review): despite the name, no trailer (EOI, FF D9) is checked, and
    # JPEGs with an Exif APP1 header (FF E1) are rejected. Confirm intended.
    soi_and_app0 = b"\xff\xd8\xff\xe0"
    jfif_identifier = b"JFIF\x00"

    if data_view[:4] != soi_and_app0:
        return False
    return data_view[6:11] == jfif_identifier
|
|
|
|
|
|
def has_png_headers_and_trailers(data_view: memoryview) -> bool:
    """Return True if *data_view* starts with the PNG signature and ends with an IEND chunk.

    The final 8 bytes must be the IEND chunk type followed by its fixed CRC.
    """
    png_signature = b"\x89PNG\r\n\x1a\n"
    iend_type_and_crc = b"\x49END\xae\x42\x60\x82"

    starts_ok = data_view[:8] == png_signature
    return starts_ok and data_view[-8:] == iend_type_and_crc
|