mirror of
https://github.com/nihilvux/bancho.py.git
synced 2025-09-17 11:07:54 -07:00
Add files via upload
This commit is contained in:
254
app/utils.py
Normal file
254
app/utils.py
Normal file
@@ -0,0 +1,254 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import inspect
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Any
|
||||
from typing import TypedDict
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
import pymysql
|
||||
|
||||
import app.settings
|
||||
from app.logging import Ansi
|
||||
from app.logging import log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.repositories.users import User
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
DATA_PATH = Path.cwd() / ".data"
|
||||
ACHIEVEMENTS_ASSETS_PATH = DATA_PATH / "assets/medals/client"
|
||||
DEFAULT_AVATAR_PATH = DATA_PATH / "avatars/default.jpg"
|
||||
|
||||
|
||||
def make_safe_name(name: str) -> str:
|
||||
"""Return a name safe for usage in sql."""
|
||||
return name.lower().replace(" ", "_")
|
||||
|
||||
|
||||
def determine_highest_ranking_clan_member(members: list[User]) -> User:
|
||||
return next(iter(sorted(members, key=lambda m: m["clan_priv"], reverse=True)))
|
||||
|
||||
|
||||
def _download_achievement_images_osu(achievements_path: Path) -> bool:
|
||||
"""Download all used achievement images (one by one, from osu!)."""
|
||||
achs: list[str] = []
|
||||
|
||||
for resolution in ("", "@2x"):
|
||||
for mode in ("osu", "taiko", "fruits", "mania"):
|
||||
# only osu!std has 9 & 10 star pass/fc medals.
|
||||
for star_rating in range(1, 1 + (10 if mode == "osu" else 8)):
|
||||
achs.append(f"{mode}-skill-pass-{star_rating}{resolution}.png")
|
||||
achs.append(f"{mode}-skill-fc-{star_rating}{resolution}.png")
|
||||
|
||||
for combo in (500, 750, 1000, 2000):
|
||||
achs.append(f"osu-combo-{combo}{resolution}.png")
|
||||
|
||||
for mod in (
|
||||
"suddendeath",
|
||||
"hidden",
|
||||
"perfect",
|
||||
"hardrock",
|
||||
"doubletime",
|
||||
"flashlight",
|
||||
"easy",
|
||||
"nofail",
|
||||
"nightcore",
|
||||
"halftime",
|
||||
"spunout",
|
||||
):
|
||||
achs.append(f"all-intro-{mod}{resolution}.png")
|
||||
|
||||
log("Downloading achievement images from osu!.", Ansi.LCYAN)
|
||||
|
||||
for ach in achs:
|
||||
resp = httpx.get(f"https://assets.ppy.sh/medals/client/{ach}")
|
||||
if resp.status_code != 200:
|
||||
return False
|
||||
|
||||
log(f"Saving achievement: {ach}", Ansi.LCYAN)
|
||||
(achievements_path / ach).write_bytes(resp.content)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def download_achievement_images(achievements_path: Path) -> None:
|
||||
"""Download all used achievement images (using the best available source)."""
|
||||
|
||||
# download individual files from the official osu! servers
|
||||
downloaded = _download_achievement_images_osu(achievements_path)
|
||||
|
||||
if downloaded:
|
||||
log("Downloaded all achievement images.", Ansi.LGREEN)
|
||||
else:
|
||||
# TODO: make the code safe in this state
|
||||
log("Failed to download achievement images.", Ansi.LRED)
|
||||
achievements_path.rmdir()
|
||||
|
||||
# allow passthrough (don't hard crash).
|
||||
# the server will *mostly* work in this state.
|
||||
pass
|
||||
|
||||
|
||||
def download_default_avatar(default_avatar_path: Path) -> None:
|
||||
"""Download an avatar to use as the server's default."""
|
||||
resp = httpx.get("https://i.cmyui.xyz/U24XBZw-4wjVME-JaEz3.png")
|
||||
|
||||
if resp.status_code != 200:
|
||||
log("Failed to fetch default avatar.", Ansi.LRED)
|
||||
return
|
||||
|
||||
log("Downloaded default avatar.", Ansi.LGREEN)
|
||||
default_avatar_path.write_bytes(resp.content)
|
||||
|
||||
|
||||
def has_internet_connectivity(timeout: float = 1.0) -> bool:
|
||||
"""Check for an active internet connection."""
|
||||
COMMON_DNS_SERVERS = (
|
||||
# Cloudflare
|
||||
"1.1.1.1",
|
||||
"1.0.0.1",
|
||||
# Google
|
||||
"8.8.8.8",
|
||||
"8.8.4.4",
|
||||
)
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as client:
|
||||
client.settimeout(timeout)
|
||||
for host in COMMON_DNS_SERVERS:
|
||||
try:
|
||||
client.connect((host, 53))
|
||||
except OSError:
|
||||
continue
|
||||
else:
|
||||
return True
|
||||
|
||||
# all connections failed
|
||||
return False
|
||||
|
||||
|
||||
class FrameInfo(TypedDict):
|
||||
function: str
|
||||
filename: str
|
||||
lineno: int
|
||||
charno: int
|
||||
locals: dict[str, str]
|
||||
|
||||
|
||||
def get_appropriate_stacktrace() -> list[FrameInfo]:
|
||||
"""Return information of all frames related to cmyui_pkg and below."""
|
||||
stack = inspect.stack()[1:]
|
||||
for idx, frame in enumerate(stack):
|
||||
if frame.function == "run":
|
||||
break
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
return [
|
||||
{
|
||||
"function": frame.function,
|
||||
"filename": Path(frame.filename).name,
|
||||
"lineno": frame.lineno,
|
||||
"charno": frame.index or 0,
|
||||
"locals": {k: repr(v) for k, v in frame.frame.f_locals.items()},
|
||||
}
|
||||
# reverse for python-like stacktrace
|
||||
# ordering; puts the most recent
|
||||
# call closest to the command line
|
||||
for frame in reversed(stack[:idx])
|
||||
]
|
||||
|
||||
|
||||
def pymysql_encode(
|
||||
conv: Callable[[Any, dict[object, object] | None], str],
|
||||
) -> Callable[[type[T]], type[T]]:
|
||||
"""Decorator to allow for adding to pymysql's encoders."""
|
||||
|
||||
def wrapper(cls: type[T]) -> type[T]:
|
||||
pymysql.converters.encoders[cls] = conv
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def escape_enum(
|
||||
val: Any,
|
||||
_: dict[object, object] | None = None,
|
||||
) -> str: # used for ^
|
||||
return str(int(val))
|
||||
|
||||
|
||||
def ensure_persistent_volumes_are_available() -> None:
|
||||
# create /.data directory
|
||||
DATA_PATH.mkdir(exist_ok=True)
|
||||
|
||||
# create /.data/... subdirectories
|
||||
for sub_dir in ("avatars", "logs", "osu", "osr", "ss"):
|
||||
subdir = DATA_PATH / sub_dir
|
||||
subdir.mkdir(exist_ok=True)
|
||||
|
||||
# download achievement images from osu!
|
||||
if not ACHIEVEMENTS_ASSETS_PATH.exists():
|
||||
ACHIEVEMENTS_ASSETS_PATH.mkdir(parents=True)
|
||||
download_achievement_images(ACHIEVEMENTS_ASSETS_PATH)
|
||||
|
||||
# download a default avatar image for new users
|
||||
if not DEFAULT_AVATAR_PATH.exists():
|
||||
download_default_avatar(DEFAULT_AVATAR_PATH)
|
||||
|
||||
|
||||
def is_running_as_admin() -> bool:
|
||||
try:
|
||||
return os.geteuid() == 0 # type: ignore[attr-defined, no-any-return, unused-ignore]
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return ctypes.windll.shell32.IsUserAnAdmin() == 1 # type: ignore[attr-defined, no-any-return, unused-ignore]
|
||||
except AttributeError:
|
||||
raise Exception(
|
||||
f"{sys.platform} is not currently supported on bancho.py, please create a github issue!",
|
||||
)
|
||||
|
||||
|
||||
def display_startup_dialog() -> None:
|
||||
"""Print any general information or warnings to the console."""
|
||||
if app.settings.DEVELOPER_MODE:
|
||||
log("running in advanced mode", Ansi.LYELLOW)
|
||||
if app.settings.DEBUG:
|
||||
log("running in debug mode", Ansi.LMAGENTA)
|
||||
|
||||
# running on root/admin grants the software potentally dangerous and
|
||||
# unnecessary power over the operating system and is not advised.
|
||||
if is_running_as_admin():
|
||||
log(
|
||||
"It is not recommended to run bancho.py as root/admin, especially in production."
|
||||
+ (
|
||||
" You are at increased risk as developer mode is enabled."
|
||||
if app.settings.DEVELOPER_MODE
|
||||
else ""
|
||||
),
|
||||
Ansi.LYELLOW,
|
||||
)
|
||||
|
||||
if not has_internet_connectivity():
|
||||
log("No internet connectivity detected", Ansi.LYELLOW)
|
||||
|
||||
|
||||
def has_jpeg_headers_and_trailers(data_view: memoryview) -> bool:
|
||||
return data_view[:4] == b"\xff\xd8\xff\xe0" and data_view[6:11] == b"JFIF\x00"
|
||||
|
||||
|
||||
def has_png_headers_and_trailers(data_view: memoryview) -> bool:
|
||||
return (
|
||||
data_view[:8] == b"\x89PNG\r\n\x1a\n"
|
||||
and data_view[-8:] == b"\x49END\xae\x42\x60\x82"
|
||||
)
|
Reference in New Issue
Block a user