commit 988afb33e3c5a5585f8fdd547b94dc0416c9e3a4 Author: Avanpost Date: Tue Nov 5 14:57:38 2024 +0000 first commit diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..2cfe03c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,26 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: unconfirmed bug +assignees: '' + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +A small code snippet showing the error + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +**Library version** +Access this info via `pip show next.py` + +**Additional context** +Add any other context about the problem here. diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml new file mode 100644 index 0000000..fb24d64 --- /dev/null +++ b/.github/workflows/pyright.yml @@ -0,0 +1,38 @@ +on: [push, pull_request] +name: pyright +jobs: + pyright-type-checking: + strategy: + matrix: + version: ["3.9", "3.10", "3.11"] + + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.version }} + - run: pip install .[speedups,docs] + - uses: jakebailey/pyright-action@v1 + with: + python-version: ${{ matrix.version }} + working-directory: next + + pyright-type-completeness: + strategy: + matrix: + version: ["3.9", "3.10", "3.11"] + + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.version }} + - run: pip install .[speedups,docs] + - uses: jakebailey/pyright-action@v1 + with: + python-version: ${{ matrix.version }} + working-directory: next + verify-types: next + ignore-external: true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5695ae4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +.venv +**/__pycache__ +test.py +dist +*.egg-info +docs/_build +.vscode +.env +.mypy_cache +build diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 0000000..561fb58 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,16 @@ +version: 2 + +sphinx: + configuration: docs/conf.py + +python: + install: + - method: pip + path: . + extra_requirements: + - docs + +build: + tools: + python: "3.9" + os: "ubuntu-22.04" diff --git a/Justfile b/Justfile new file mode 100644 index 0000000..3925b3d --- /dev/null +++ b/Justfile @@ -0,0 +1,20 @@ +set dotenv-load := true + +test: + python test.py + +build: + rm -rf dist/* + python -m build + +upload: + python -m twine upload dist/* + +lint: + pyright . + +coverage: + pyright --ignoreexternal --verifytypes next + +docs: + cd docs && make html diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..dcc6624 --- /dev/null +++ b/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2021-present Zomatree + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..d30a012 --- /dev/null +++ b/README.md @@ -0,0 +1,40 @@ +# next.py + +An async library to interact with the https://next.avanpost20.ru API. + +You can join the support server [here](https://app.avanpost20.ru/invite/Testers) and find the library's documentation [here](https://nextpy.readthedocs.io/en/latest/). + +## Installing + +You can use `pip` to install next.py. It differs slightly depending on what OS/Distro you use. + +On Windows +``` +py -m pip install -U next-api-py # -U to update +``` + +On macOS and Linux +``` +python3 -m pip install -U next-api-py +``` + +## Example + +More examples can be found in the [examples folder](https://github.com/avanpost200/next.py/blob/master/examples). + +```py +import next +import asyncio + +class Client(next.Client): + async def on_message(self, message: next.Message): + if message.content == "hello": + await message.channel.send("hi how are you") + +async def main(): + async with next.utils.client_session() as session: + client = Client(session, "BOT TOKEN HERE") + await client.start() + +asyncio.run(main()) +``` diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d4bb2cb --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 0000000..3e1cc00 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,289 @@ +.. currentmodule:: next + +API Reference +=============== + + +.. autoclass:: Client + :members: + :inherited-members: + +.. autoclass:: Asset + :members: + :inherited-members: + +.. autoclass:: PartialAsset + :members: + :inherited-members: + +.. autoclass:: Channel + :members: + :inherited-members: + +.. autoclass:: ServerChannel + :members: + :inherited-members: + +.. autoclass:: SavedMessageChannel + :members: + :inherited-members: + +.. autoclass:: DMChannel + :members: + :inherited-members: + +.. autoclass:: GroupDMChannel + :members: + :inherited-members: + +.. autoclass:: TextChannel + :members: + :inherited-members: + +.. autoclass:: VoiceChannel + :members: + :inherited-members: + +.. autoclass:: Embed + :members: + :inherited-members: + +.. autoclass:: WebsiteEmbed + :members: + :inherited-members: + +.. autoclass:: ImageEmbed + :members: + :inherited-members: + +.. autoclass:: TextEmbed + :members: + :inherited-members: + +.. autoclass:: NoneEmbed + :members: + :inherited-members: + +.. autoclass:: SendableEmbed + :members: + :inherited-members: + +.. autoclass:: File + :members: + :inherited-members: + +.. autoclass:: Member + :members: + :inherited-members: + +.. autoclass:: Message + :members: + :inherited-members: + +.. autoclass:: MessageReply + :members: + :inherited-members: + +.. autoclass:: Masquerade + :members: + :inherited-members: + +.. autoclass:: Messageable + :members: + :inherited-members: + +.. autoclass:: Permissions + :members: + :inherited-members: + +.. autoclass:: UserPermissions + :members: + :inherited-members: + +.. autoclass:: PermissionsOverwrite + :members: + :inherited-members: + +.. autoclass:: Role + :members: + :inherited-members: + +.. autoclass:: Server + :members: + :inherited-members: + +.. autoclass:: ServerBan + :members: + :inherited-members: + +.. autoclass:: Category + :members: + :inherited-members: + +.. autoclass:: SystemMessages + :members: + :inherited-members: + +.. autoclass:: User + :members: + :inherited-members: + +.. autonamedtuple:: Relation + +.. autonamedtuple:: Status + +.. autoclass:: UserBadges + :members: + +.. autoclass:: UserProfile + :members: + +.. autoclass:: Invite + :members: + +.. autoclass:: Emoji + :members: + +.. autoclass:: MessageInteractions + :members: + +Enums +====== + +The api uses enums to say what variant of something is, +these represent those enums + +All enums subclass `aenum.Enum`. + +.. class:: ChannelType + + Specifies the type of channel. + + .. attribute:: saved_message + + A private channel only you can access. + .. attribute:: direct_message + + A private direct message channel between you and another user + .. attribute:: group + + A private group channel for messages between a group of users + .. attribute:: text_channel + + A text channel in a server + .. attribute:: voice_channel + + A voice only channel + +.. class:: PresenceType + + Specifies what a users presence is + + .. attribute:: busy + + The user is busy and wont receive notification + .. attribute:: idle + + The user is idle + .. attribute:: invisible + + The user is invisible, you will never receive this, instead they will appear offline + .. attribute:: online + + The user is online + + .. attribute:: offline + + The user is offline or invisible + +.. class:: RelationshipType + + Specifies the relationship between two users + + .. attribute:: blocked + + You have blocked them + .. attribute:: blocked_other + + They have blocked you + .. attribute:: friend + + You are friends with them + .. attribute:: incoming_friend_request + + They are sending you a friend request + .. attribute:: none + + You have no relationship with them + .. attribute:: outgoing_friend_request + + You are sending them a friend request + + .. attribute:: user + + That user is yourself + +.. class:: AssetType + + Specifies the type of asset + + .. attribute:: image + + The asset is an image + .. attribute:: video + + The asset is a video + .. attribute:: text + + The asset is a text file + .. attribute:: audio + + The asset is an audio file + .. attribute:: file + + The asset is a generic file + +.. class:: SortType + + The sort type for a message search + + .. attribute:: latest + + Sort by the latest message + .. attribute:: oldest + + Sort by the oldest message + .. attribute:: relevance + + Sort by the relevance of the message + +.. class:: EmbedType + + The type of embed + + .. attribute:: website + + The embed is a website + .. attribute:: image + + The embed is an image + .. attribute:: text + + The embed is text + .. attribute:: video + + The embed is a video + .. attribute:: unknown + + The embed is unknown + +Utils +====== + +.. currentmodule:: next.utils + +A collection a utility functions and classes to aid in making your bot + +.. autofunction:: get + +.. autofunction:: client_session diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..2224dd6 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,65 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. + +import os +import sys + +import sphinx_nameko_theme + +sys.path.insert(0, os.path.abspath('..')) + +import next + +# -- Project information ----------------------------------------------------- + +project = 'Next.py' +copyright = '2024-present, Avanpost' +author = 'Avanpost' +version = ".".join(map(str, next.__version__)) + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.napoleon", + "sphinx.ext.autodoc", + "sphinx_toolbox.installation", + "sphinx_toolbox.more_autodoc.autonamedtuple" +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +add_module_names = False + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. + +html_theme = 'nameko' +html_theme_path = [sphinx_nameko_theme.get_html_theme_path()] + + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +autodoc_typehints = "none" diff --git a/docs/ext/commands/api.rst b/docs/ext/commands/api.rst new file mode 100644 index 0000000..60e26e4 --- /dev/null +++ b/docs/ext/commands/api.rst @@ -0,0 +1,110 @@ +.. currentmodule:: next + +API Reference +=============== + + +CommandsClient +~~~~~~~~~~~~~~~ +.. autoclass:: next.ext.commands.CommandsClient + :members: + +Context +~~~~~~~~ +.. autoclass:: next.ext.commands.Context + :members: + +Command +~~~~~~~~ +.. autoclass:: next.ext.commands.Command + :members: + +Cog +~~~~ +.. autoclass:: next.ext.commands.Cog + :members: + +command +~~~~~~~~ +.. autodecorator:: next.ext.commands.command + +check +~~~~~~ +.. autodecorator:: next.ext.commands.check + +is_bot_owner +~~~~~~~~~~~~~ +.. autodecorator:: next.ext.commands.is_bot_owner + +is_server_owner +~~~~~~~~~~~~~~~~ +.. autodecorator:: next.ext.commands.is_server_owner + + +Exceptions +=========== + +CommandError +~~~~~~~~~~~~~ +.. autoexception:: next.ext.commands.CommandError + :members: + +CommandNotFound +~~~~~~~~~~~~~~~~ +.. autoexception:: next.ext.commands.CommandNotFound + :members: + +NoClosingQuote +~~~~~~~~~~~~~~~ +.. autoexception:: next.ext.commands.NoClosingQuote + :members: + +CheckError +~~~~~~~~~~~ +.. autoexception:: next.ext.commands.CheckError + :members: + +NotBotOwner +~~~~~~~~~~~~ +.. autoexception:: next.ext.commands.NotBotOwner + :members: + +NotServerOwner +~~~~~~~~~~~~~~~ +.. autoexception:: next.ext.commands.NotServerOwner + :members: + +ServerOnly +~~~~~~~~~~~ +.. autoexception:: next.ext.commands.ServerOnly + :members: + +ConverterError +~~~~~~~~~~~~~~~ +.. autoexception:: next.ext.commands.ConverterError + :members: + +InvalidLiteralArgument +~~~~~~~~~~~~~~~~~~~~~~~ +.. autoexception:: next.ext.commands.InvalidLiteralArgument + :members: + +BadBoolArgument +~~~~~~~~~~~~~~~~ +.. autoexception:: next.ext.commands.BadBoolArgument + :members: + +CategoryConverterError +~~~~~~~~~~~~~~~~~~~~~~~ +.. autoexception:: next.ext.commands.CategoryConverterError + :members: + +UserConverterError +~~~~~~~~~~~~~~~~~~~ +.. autoexception:: next.ext.commands.UserConverterError + :members: + +MemberConverterError +~~~~~~~~~~~~~~~~~~~~~ +.. autoexception:: next.ext.commands.MemberConverterError + :members: diff --git a/docs/ext/commands/index.rst b/docs/ext/commands/index.rst new file mode 100644 index 0000000..007b11d --- /dev/null +++ b/docs/ext/commands/index.rst @@ -0,0 +1,9 @@ +.. next_ext_commands: + +``next.ext.commands`` - Command Framework +============================================ + +.. toctree:: + :maxdepth: 1 + + api diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..ab2ea37 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,14 @@ +Welcome to Next.py's documentation! +====================================== + +.. toctree:: + :maxdepth: 1 + + api + +Extensions +----------- +.. toctree:: + :maxdepth: 1 + + ext/commands/index.rst diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..2119f51 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/examples/basic.py b/examples/basic.py new file mode 100644 index 0000000..de8d058 --- /dev/null +++ b/examples/basic.py @@ -0,0 +1,16 @@ +import asyncio +import aiohttp +import next + + +class Client(next.Client): + async def on_message(self, message: next.Message): + if message.content == "hello": + await message.channel.send("hi how are you") + +async def main(): + async with aiohttp.ClientSession() as session: + client = Client(session, "BOT TOKEN HERE") + await client.start() + +asyncio.run(main()) diff --git a/examples/commands.py b/examples/commands.py new file mode 100644 index 0000000..bba5f6e --- /dev/null +++ b/examples/commands.py @@ -0,0 +1,22 @@ +import asyncio + +import aiohttp + +import next +from next.ext import commands + + +class Client(commands.CommandsClient): + async def get_prefix(self, message: next.Message): + return "!" + + @commands.command() + async def ping(self, ctx: commands.Context): + await ctx.send("pong") + +async def main(): + async with aiohttp.ClientSession() as session: + client = Client(session, "BOT TOKEN HERE") + await client.start() + +asyncio.run(main()) diff --git a/next/__init__.py b/next/__init__.py new file mode 100644 index 0000000..3a66c24 --- /dev/null +++ b/next/__init__.py @@ -0,0 +1,22 @@ +from . import utils as utils +from . import types as types +from .asset import * +from .category import * +from .channel import * +from .client import * +from .embed import * +from .emoji import * +from .enums import * +from .errors import * +from .file import * +from .flags import * +from .invite import * +from .member import * +from .message import * +from .messageable import * +from .permissions import * +from .role import * +from .server import * +from .user import * + +__version__ = "0.2.0" diff --git a/next/asset.py b/next/asset.py new file mode 100644 index 0000000..4d74c28 --- /dev/null +++ b/next/asset.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import mimetypes +from typing import TYPE_CHECKING + +from .enums import AssetType +from .utils import Ulid + +if TYPE_CHECKING: + from io import IOBase + + from .state import State + from .types import File as FilePayload + + +__all__ = ("Asset", "PartialAsset") + +class Asset(Ulid): + """Represents a file on next + + Attributes + ----------- + id: :class:`str` + The id of the asset + tag: :class:`str` + The tag of the asset, this corresponds to where the asset is used + size: :class:`int` + Amount of bytes in the file + filename: :class:`str` + The name of the file + height: Optional[:class:`int`] + The height of the file if it is an image or video + width: Optional[:class:`int`] + The width of the file if it is an image or video + content_type: :class:`str` + The content type of the file + type: :class:`AssetType` + The type of asset it is + url: :class:`str` + The asset's url + """ + __slots__ = ("state", "id", "tag", "size", "filename", "content_type", "width", "height", "type", "url") + + def __init__(self, data: FilePayload, state: State): + self.state: State = state + + self.id: str = data['_id'] + self.tag: str = data['tag'] + self.size: int = data['size'] + self.filename: str = data['filename'] + + metadata = data['metadata'] + self.height: int | None + self.width: int | None + + if metadata["type"] == "Image" or metadata["type"] == "Video": # cannot use `in` because type narrowing will not happen + self.height = metadata["height"] + self.width = metadata["width"] + else: + self.height = None + self.width = None + + self.content_type: str | None = data["content_type"] + self.type: AssetType = AssetType(metadata["type"]) + + base_url = self.state.api_info["features"]["autumn"]["url"] + self.url: str = f"{base_url}/{self.tag}/{self.id}" + + async def read(self) -> bytes: + """Reads the files content into bytes""" + return await self.state.http.request_file(self.url) + + async def save(self, fp: IOBase) -> None: + """Reads the files content and saves it to a file + + Parameters + ----------- + fp: IOBase + The file to write to + """ + fp.write(await self.read()) + +class PartialAsset(Asset): + """Partial asset for when we get limited data about the asset + + Attributes + ----------- + id: :class:`str` + The id of the asset, this will always be ``"0"`` + size: :class:`int` + Amount of bytes in the file, this will always be ``0`` + filename: :class:`str` + The name of the file, this be always be ``""`` + height: Optional[:class:`int`] + The height of the file if it is an image or video, this will always be ``None`` + width: Optional[:class:`int`] + The width of the file if it is an image or video, this will always be ``None`` + content_type: Optional[:class:`str`] + The content type of the file, this is guessed from the url's file extension if it has one + type: :class:`AssetType` + The type of asset it is, this always be ``AssetType.file`` + """ + + def __init__(self, url: str, state: State): + self.state: State = state + self.id: str = "0" + self.size: int = 0 + self.filename: str = "" + self.height: int | None = None + self.width: int | None = None + self.content_type: str | None = mimetypes.guess_extension(url) + self.type: AssetType = AssetType.file + self.url: str = url diff --git a/next/category.py b/next/category.py new file mode 100644 index 0000000..cdf746c --- /dev/null +++ b/next/category.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .utils import Ulid + +if TYPE_CHECKING: + from .channel import Channel + from .state import State + from .types import Category as CategoryPayload + +__all__ = ("Category",) + +class Category(Ulid): + """Represents a category in a server that stores channels. + + Attributes + ----------- + name: :class:`str` + The name of the category + id: :class:`str` + The id of the category + channel_ids: list[:class:`str`] + The ids of channels that are inside the category + """ + + def __init__(self, data: CategoryPayload, state: State): + self.state: State = state + self.name: str = data["title"] + self.id: str = data["id"] + self.channel_ids: list[str] = data["channels"] + + @property + def channels(self) -> list[Channel]: + """Returns a list of channels that the category contains""" + return [self.state.get_channel(channel_id) for channel_id in self.channel_ids] diff --git a/next/channel.py b/next/channel.py new file mode 100644 index 0000000..06459ca --- /dev/null +++ b/next/channel.py @@ -0,0 +1,422 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional, Union + +from .asset import Asset +from .enums import ChannelType +from .messageable import Messageable +from .permissions import Permissions, PermissionsOverwrite +from .utils import Missing, Ulid + +if TYPE_CHECKING: + from .message import Message + from .role import Role + from .server import Server + from .state import State + from .types import Channel as ChannelPayload + from .types import DMChannel as DMChannelPayload + from .types import File as FilePayload + from .types import GroupDMChannel as GroupDMChannelPayload + from .types import Overwrite as OverwritePayload + from .types import SavedMessages as SavedMessagesPayload + from .types import ServerChannel as ServerChannelPayload + from .types import TextChannel as TextChannelPayload + from .user import User + +__all__ = ("DMChannel", "GroupDMChannel", "SavedMessageChannel", "TextChannel", "VoiceChannel", "Channel", "ServerChannel") + +class EditableChannel: + __slots__ = () + + state: State + id: str + + async def edit(self, **kwargs: Any) -> None: + """Edits the channel + + Passing ``None`` to the parameters that accept it will remove them. + + Parameters + ----------- + name: str + The new name for the channel + description: Optional[str] + The new description for the channel + owner: User + The new owner for the group dm channel + icon: Optional[File] + The new icon for the channel + nsfw: bool + Sets whether the channel is nsfw or not + """ + remove: list[str] = [] + + if kwargs.get("icon", Missing) == None: + remove.append("Icon") + elif kwargs.get("description", Missing) == None: + remove.append("Description") + + if icon := kwargs.get("icon"): + asset = await self.state.http.upload_file(icon, "icons") + kwargs["icon"] = asset["id"] + + if owner := kwargs.get("owner"): + kwargs["owner"] = owner.id + + await self.state.http.edit_channel(self.id, remove, kwargs) + +class Channel(Ulid): + """Base class for all channels + + Attributes + ----------- + id: :class:`str` + The id of the channel + channel_type: ChannelType + The type of the channel + server_id: Optional[:class:`str`] + The server id of the chanel, if any + """ + __slots__ = ("state", "id", "channel_type", "server_id") + + def __init__(self, data: ChannelPayload, state: State): + self.state: State = state + self.id: str = data["_id"] + self.channel_type: ChannelType = ChannelType(data["channel_type"]) + self.server_id: Optional[str] = None + + async def _get_channel_id(self) -> str: + return self.id + + def _update(self, **_: Any) -> None: + pass + + async def delete(self) -> None: + """Deletes or closes the channel""" + await self.state.http.close_channel(self.id) + + @property + def server(self) -> Server: + """:class:`Server` The server this voice channel belongs too + + Raises + ------- + :class:`LookupError` + Raises if the channel is not part of a server + """ + if not self.server_id: + raise LookupError + + return self.state.get_server(self.server_id) + + @property + def mention(self) -> str: + """:class:`str`: Returns a string that allows you to mention the given channel.""" + return f"<#{self.id}>" + + +class SavedMessageChannel(Channel, Messageable): + """The Saved Message Channel""" + def __init__(self, data: SavedMessagesPayload, state: State): + super().__init__(data, state) + +class DMChannel(Channel, Messageable): + """A DM channel + + Attributes + ----------- + last_message_id: Optional[:class:`str`] + The id of the last message in this channel, if any + """ + + __slots__ = ("last_message_id", "recipient_ids") + + def __init__(self, data: DMChannelPayload, state: State): + super().__init__(data, state) + + self.recipient_ids: list[str] = data["recipients"] + self.last_message_id: str | None = data.get("last_message_id") + + @property + def recipients(self) -> tuple[User, User]: + a, b = self.recipient_ids + + return (self.state.get_user(a), self.state.get_user(b)) + + @property + def recipient(self) -> User: + if self.recipient_ids[0] != self.state.user_id: + user_id = self.recipient_ids[0] + else: + user_id = self.recipient_ids[1] + + return self.state.get_user(user_id) + + @property + def last_message(self) -> Message: + """Gets the last message from the channel, shorthand for `client.get_message(channel.last_message_id)` + + Returns + -------- + :class:`Message` the last message in the channel + """ + + if not self.last_message_id: + raise LookupError + + return self.state.get_message(self.last_message_id) + +class GroupDMChannel(Channel, Messageable, EditableChannel): + """A group DM channel + + Attributes + ----------- + recipients: list[:class:`User`] + The recipients of the group dm channel + name: :class:`str` + The name of the group dm channel + owner: :class:`User` + The user who created the group dm channel + icon: Optional[:class:`Asset`] + The icon of the group dm channel + permissions: :class:`ChannelPermissions` + The permissions of the users inside the group dm channel + description: Optional[:class:`str`] + The description of the channel, if any + last_message_id: Optional[:class:`str`] + The id of the last message in this channel, if any + """ + + __slots__ = ("recipient_ids", "name", "owner_id", "permissions", "icon", "description", "last_message_id") + + def __init__(self, data: GroupDMChannelPayload, state: State): + super().__init__(data, state) + self.recipient_ids: list[str] = data["recipients"] + self.name: str = data["name"] + self.owner_id: str = data["owner"] + self.description: str | None = data.get("description") + self.last_message_id: str | None = data.get("last_message_id") + + self.icon: Asset | None + + if icon := data.get("icon"): + self.icon = Asset(icon, state) + else: + self.icon = None + + self.permissions: Permissions = Permissions(data.get("permissions", 0)) + + def _update(self, *, name: Optional[str] = None, recipients: Optional[list[str]] = None, description: Optional[str] = None) -> None: + if name is not None: + self.name = name + + if recipients is not None: + self.recipient_ids = recipients + + if description is not None: + self.description = description + + @property + def recipients(self) -> list[User]: + return [self.state.get_user(user_id) for user_id in self.recipient_ids] + + @property + def owner(self) -> User: + return self.state.get_user(self.owner_id) + + async def set_default_permissions(self, permissions: Permissions) -> None: + """Sets the default permissions for a group. + Parameters + ----------- + permissions: :class:`ChannelPermissions` + The new default group permissions + """ + await self.state.http.set_group_channel_default_permissions(self.id, permissions.value) + + @property + def last_message(self) -> Message: + """Gets the last message from the channel, shorthand for `client.get_message(channel.last_message_id)` + + Returns + -------- + :class:`Message` the last message in the channel + """ + + if not self.last_message_id: + raise LookupError + + return self.state.get_message(self.last_message_id) + +class ServerChannel(Channel): + """Base class for all guild channels + + Attributes + ----------- + server_id: :class:`str` + The id of the server this text channel belongs to + name: :class:`str` + The name of the text channel + description: Optional[:class:`str`] + The description of the channel, if any + nsfw: bool + Sets whether the channel is nsfw or not + default_permissions: :class:`ChannelPermissions` + The default permissions for all users in the text channel + """ + def __init__(self, data: ServerChannelPayload, state: State): + super().__init__(data, state) + + self.server_id: Optional[str] = data["server"] + self.name: str = data["name"] + self.description: Optional[str] = data.get("description") + self.nsfw: bool = data.get("nsfw", False) + self.active: bool = False + self.default_permissions: PermissionsOverwrite = PermissionsOverwrite._from_overwrite(data.get("default_permissions", {"a": 0, "d": 0})) + + permissions: dict[str, PermissionsOverwrite] = {} + + for role_name, overwrite_data in data.get("role_permissions", {}).items(): + overwrite = PermissionsOverwrite._from_overwrite(overwrite_data) + permissions[role_name] = overwrite + + self.permissions: dict[str, PermissionsOverwrite] = permissions + + self.icon: Asset | None + + if icon := data.get("icon"): + self.icon = Asset(icon, state) + else: + self.icon = None + + async def set_default_permissions(self, permissions: PermissionsOverwrite) -> None: + """Sets the default permissions for the channel. + Parameters + ----------- + permissions: :class:`ChannelPermissions` + The new default channel permissions + """ + allow, deny = permissions.to_pair() + await self.state.http.set_guild_channel_default_permissions(self.id, allow.value, deny.value) + + async def set_role_permissions(self, role: Role, permissions: PermissionsOverwrite) -> None: + """Sets the permissions for a role in the channel. + Parameters + ----------- + permissions: :class:`ChannelPermissions` + The new channel permissions + """ + allow, deny = permissions.to_pair() + + await self.state.http.set_guild_channel_role_permissions(self.id, role.id, allow.value, deny.value) + + def _update(self, *, name: Optional[str] = None, description: Optional[str] = None, icon: Optional[FilePayload] = None, nsfw: Optional[bool] = None, active: Optional[bool] = None, role_permissions: Optional[dict[str, OverwritePayload]] = None, default_permissions: Optional[OverwritePayload] = None): + if name is not None: + self.name = name + + if description is not None: + self.description = description + + if icon is not None: + self.icon = Asset(icon, self.state) + + if nsfw is not None: + self.nsfw = nsfw + + if active is not None: + self.active = active + + if role_permissions is not None: + permissions = {} + + for role_name, overwrite_data in role_permissions.items(): + overwrite = PermissionsOverwrite._from_overwrite(overwrite_data) + permissions[role_name] = overwrite + + self.permissions = permissions + + if default_permissions is not None: + self.default_permissions = PermissionsOverwrite._from_overwrite(default_permissions) + +class TextChannel(ServerChannel, Messageable, EditableChannel): + """A text channel + + Subclasses :class:`ServerChannel` and :class:`Messageable` + + Attributes + ----------- + name: :class:`str` + The name of the text channel + server_id: :class:`str` + The id of the server this text channel belongs to + last_message_id: Optional[:class:`str`] + The id of the last message in this channel, if any + default_permissions: :class:`ChannelPermissions` + The default permissions for all users in the text channel + role_permissions: dict[:class:`str`, :class:`ChannelPermissions`] + A dictionary of role id's to the permissions of that role in the text channel + icon: Optional[:class:`Asset`] + The icon of the text channel, if any + description: Optional[:class:`str`] + The description of the channel, if any + """ + + __slots__ = ("name", "description", "last_message_id", "default_permissions", "icon", "overwrites") + + def __init__(self, data: TextChannelPayload, state: State): + super().__init__(data, state) + + self.last_message_id: str | None = data.get("last_message_id") + + async def _get_channel_id(self) -> str: + return self.id + + @property + def last_message(self) -> Message: + """Gets the last message from the channel, shorthand for `client.get_message(channel.last_message_id)` + + Returns + -------- + :class:`Message` the last message in the channel + """ + + if not self.last_message_id: + raise LookupError + + return self.state.get_message(self.last_message_id) + +class VoiceChannel(ServerChannel, EditableChannel): + """A voice channel + + Subclasses :class:`ServerChannel` + + Attributes + ----------- + name: :class:`str` + The name of the voice channel + server_id: :class:`str` + The id of the server this voice channel belongs to + last_message_id: Optional[:class:`str`] + The id of the last message in this channel, if any + default_permissions: :class:`ChannelPermissions` + The default permissions for all users in the voice channel + role_permissions: dict[:class:`str`, :class:`ChannelPermissions`] + A dictionary of role id's to the permissions of that role in the voice channel + icon: Optional[:class:`Asset`] + The icon of the voice channel, if any + description: Optional[:class:`str`] + The description of the channel, if any + """ + +def channel_factory(data: ChannelPayload, state: State) -> Union[DMChannel, GroupDMChannel, SavedMessageChannel, TextChannel, VoiceChannel]: + if data["channel_type"] == "SavedMessages": + return SavedMessageChannel(data, state) + elif data["channel_type"] == "DirectMessage": + return DMChannel(data, state) + elif data["channel_type"] == "Group": + return GroupDMChannel(data, state) + elif data["channel_type"] == "TextChannel": + return TextChannel(data, state) + elif data["channel_type"] == "VoiceChannel": + return VoiceChannel(data, state) + else: + raise Exception diff --git a/next/client.py b/next/client.py new file mode 100644 index 0000000..8c1c39a --- /dev/null +++ b/next/client.py @@ -0,0 +1,565 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Literal, Optional, TypeVar, Union, cast, overload +from typing_extensions import ParamSpec + +import aiohttp + +from .errors import NextError +from .channel import (DMChannel, GroupDMChannel, SavedMessageChannel, + TextChannel, VoiceChannel, channel_factory) +from .http import HttpClient +from .invite import Invite +from .message import Message +from .state import State +from .utils import Missing, Ulid +from .websocket import WebsocketHandler +from .emoji import Emoji +from .server import Server +from .user import User + +try: + import ujson as json +except ImportError: + import json + +if TYPE_CHECKING: + from .channel import Channel + from .file import File + from .types import ApiInfo + + import next + +__all__ = ("Client",) + +logger: logging.Logger = logging.getLogger("next") + +P = ParamSpec("P") +R = TypeVar("R") + +class Client: + """The client for interacting with next + + Parameters + ----------- + session: :class:`aiohttp.ClientSession` + The aiohttp session to use for http request and the websocket + token: :class:`str` + The bots token + api_url: :class:`str` + The api url for the next instance you are connecting to, by default it uses the offical instance hosted at next.avanpost20.ru + max_messages: :class:`int` + The max amount of messages stored in the cache, by default this is 5k + """ + + def __init__(self, session: aiohttp.ClientSession, token: str, *, api_url: str = "https://api.avanpost20.ru", max_messages: int = 5000, bot: bool = True): + self.session: aiohttp.ClientSession = session + self.token: str = token + self.api_url: str = api_url + self.max_messages: int = max_messages + self.bot: bool = bot + + self.api_info: ApiInfo + self.http: HttpClient + self.state: State + self.websocket: WebsocketHandler + + self.temp_listeners: dict[str, list[tuple[Callable[..., bool], asyncio.Future[Any]]]] = {} + self.listeners: dict[str, list[Callable[..., Coroutine[Any, Any, Any]]]] = {} + + super().__init__() + + def dispatch(self, event: str, *args: Any) -> None: + """Dispatch an event, this is typically used for testing and internals. + + Parameters + ---------- + event: class:`str` + The name of the event to dispatch, not including `on_` + args: :class:`Any` + The arguments passed to the event + """ + + if temp_listeners := self.temp_listeners.get(event, None): + for check, future in temp_listeners: + if check(*args): + if len(args) == 1: + future.set_result(args[0]) + else: + future.set_result(args) + + self.temp_listeners[event] = [(c, f) for c, f in temp_listeners if not f.done()] + + for listener in self.listeners.get(event, []): + asyncio.create_task(listener(*args)) + + if func := getattr(self, f"on_{event}", None): + asyncio.create_task(func(*args)) + + async def get_api_info(self) -> ApiInfo: + async with self.session.get(self.api_url) as resp: + text = await resp.text() + + try: + return json.loads(text) + except: + raise NextError(f"Cant fetch api info:\n{text}") + + async def start(self, *, reconnect: bool = True) -> None: + """Starts the client""" + api_info = await self.get_api_info() + + self.api_info = api_info + self.http = HttpClient(self.session, self.token, self.api_url, self.api_info, self.bot) + self.state = State(self.http, api_info, self.max_messages) + self.websocket = WebsocketHandler(self.session, self.token, api_info["ws"], self.dispatch, self.state) + + await self.websocket.start(reconnect) + + async def stop(self) -> None: + await self.websocket.websocket.close() + + def get_user(self, id: str) -> User: + """Gets a user from the cache + + Parameters + ----------- + id: :class:`str` + The id of the user + + Returns + -------- + :class:`User` + The user + """ + return self.state.get_user(id) + + def get_channel(self, id: str) -> Channel: + """Gets a channel from the cache + + Parameters + ----------- + id: :class:`str` + The id of the channel + + Returns + -------- + :class:`Channel` + The channel + """ + return self.state.get_channel(id) + + def get_server(self, id: str) -> Server: + """Gets a server from the cache + + Parameters + ----------- + id: :class:`str` + The id of the server + + Returns + -------- + :class:`Server` + The server + """ + return self.state.get_server(id) + + async def wait_for(self, event: str, *, check: Optional[Callable[..., bool]] = None, timeout: Optional[float] = None) -> Any: + """Waits for an event + + Parameters + ----------- + event: :class:`str` + The name of the event to wait for, without the `on_` + check: Optional[Callable[..., :class:`bool`]] + A function that says what event to wait_for, this function takes the same parameters as the event you are waiting for and should return a bool saying if that is the event you want + timeout: Optional[:class:`float`] + Time in seconds to wait for the event. By default it waits forever + + Raises + ------- + asyncio.TimeoutError + If timeout is provided and it was reached + + Returns + -------- + Any + The parameters of the event + """ + if not check: + check = lambda *_: True + + future = asyncio.get_running_loop().create_future() + self.temp_listeners.setdefault(event, []).append((check, future)) + + return await asyncio.wait_for(future, timeout) + + def listen(self, name: str | None = None) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: + """Registers a listener for an event, multiple listeners can be registered to the same event without conflict + + Parameters + ----------- + name: Optional[:class:`str`] + The name of the event to register this under, this defaults to the function's name + """ + def inner(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: + nonlocal name + + if not name: + if not func.__name__.startswith("on_"): + raise NextError("listener name must begin with `on_`") + + name = func.__name__[3:] + + self.listeners.setdefault(name, []).append(func) + return func + + return inner + + @overload + def remove_listener(self, func: Callable[P, Coroutine[Any, Any, R]], *, event: str = ...) -> Callable[..., Coroutine[Any, Any, R]] | None: + ... + + @overload + def remove_listener(self, func: Callable[P, Coroutine[Any, Any, Any]], *, event: None = ...) -> None: + ... + + def remove_listener(self, func: Callable[P, Coroutine[Any, Any, R]], *, event: str | None = None) -> Callable[..., Coroutine[Any, Any, R]] | None: + """Removes a listener registered, if the `event` parameter is passed, the listener will only be removed from that event, this can be used if the same listener is registed to multiple events at once. + + Parameters + ----------- + func: Callable + The function for the listener to be removed + event: Optional[:class:`str`] + The name of the event to remove this from, passing `None` will make this remove the listener from all events this is registered under + """ + if event is None: + for listeners in self.listeners.values(): + try: + listeners.remove(func) + except ValueError: + pass + + else: + try: + self.listeners[event].remove(func) + return func + except ValueError: + pass + + @property + def user(self) -> User: + """:class:`User` the user corrasponding to the client""" + user = self.websocket.user + + assert user + return user + + @property + def users(self) -> list[User]: + """list[:class:`User`] All users the client can see""" + return list(self.state.users.values()) + + @property + def servers(self) -> list[Server]: + """list[:class:'Server'] All servers the client can see""" + return list(self.state.servers.values()) + + @property + def global_emojis(self) -> list[Emoji]: + return self.state.global_emojis + + async def fetch_user(self, user_id: str) -> User: + """Fetchs a user + + Parameters + ----------- + user_id: :class:`str` + The id of the user you are fetching + + Returns + -------- + :class:`User` + The user with the matching id + """ + payload = await self.http.fetch_user(user_id) + return User(payload, self.state) + + async def fetch_dm_channels(self) -> list[Union[DMChannel, GroupDMChannel]]: + """Fetchs all dm channels the client has made + + Returns + -------- + list[Union[:class:`DMChanel`, :class:`GroupDMChannel`]] + A list of :class:`DMChannel` or :class`GroupDMChannel` + """ + channel_payloads = await self.http.fetch_dm_channels() + return cast(list[Union[DMChannel, GroupDMChannel]], [channel_factory(payload, self.state) for payload in channel_payloads]) + + async def fetch_channel(self, channel_id: str) -> Union[DMChannel, GroupDMChannel, SavedMessageChannel, TextChannel, VoiceChannel]: + """Fetches a channel + + Parameters + ----------- + channel_id: :class:`str` + The id of the channel + + Returns + -------- + Union[:class:`DMChannel`, :class:`GroupDMChannel`, :class:`SavedMessageChannel`, :class:`TextChannel`, :class:`VoiceChannel`] + The channel with the matching id + """ + payload = await self.http.fetch_channel(channel_id) + + return channel_factory(payload, self.state) + + async def fetch_server(self, server_id: str) -> Server: + """Fetchs a server + + Parameters + ----------- + server_id: :class:`str` + The id of the server you are fetching + + Returns + -------- + :class:`Server` + The server with the matching id + """ + payload = await self.http.fetch_server(server_id) + + return Server(payload, self.state) + + async def fetch_invite(self, code: str) -> Invite: + """Fetchs an invite + + Parameters + ----------- + code: :class:`str` + The code of the invite you are fetching + + Returns + -------- + :class:`Invite` + The invite with the matching code + """ + payload = await self.http.fetch_invite(code) + + return Invite(payload, code, self.state) + + def get_message(self, message_id: str) -> Message: + """Gets a message from the cache + + Parameters + ----------- + message_id: :class:`str` + The id of the message you are getting + + Returns + -------- + :class:`Message` + The message with the matching id + + Raises + ------- + LookupError + This raises if the message is not found in the cache + """ + for message in self.state.messages: + if message.id == message_id: + return message + + raise LookupError + + async def edit_self(self, **kwargs: Any) -> None: + """Edits the client's own user + + Parameters + ----------- + avatar: Optional[:class:`File`] + The avatar to change to, passing in ``None`` will remove the avatar + """ + if kwargs.get("avatar", Missing) is None: + del kwargs["avatar"] + remove = ["Avatar"] + else: + remove = None + + await self.state.http.edit_self(remove, kwargs) + + async def edit_status(self, **kwargs: Any) -> None: + """Edits the client's own status + + Parameters + ----------- + presence: :class:`PresenceType` + The presence to change to + text: Optional[:class:`str`] + The text to change the status to, passing in ``None`` will remove the status + """ + if kwargs.get("text", Missing) is None: + del kwargs["text"] + remove = ["StatusText"] + else: + remove = None + + if presence := kwargs.get("presence"): + kwargs["presence"] = presence.value + + await self.state.http.edit_self(remove, {"status": kwargs}) + + async def edit_profile(self, **kwargs: Any) -> None: + """Edits the client's own profile + + Parameters + ----------- + content: Optional[:class:`str`] + The new content for the profile, passing in ``None`` will remove the profile content + background: Optional[:class:`File`] + The new background for the profile, passing in ``None`` will remove the profile background + """ + remove: list[str] = [] + + if kwargs.get("content", Missing) is None: + del kwargs["content"] + remove.append("ProfileContent") + + if kwargs.get("background", Missing) is None: + del kwargs["background"] + remove.append("ProfileBackground") + + await self.state.http.edit_self(remove, {"profile": kwargs}) + + async def fetch_emoji(self, emoji_id: str) -> Emoji: + """Fetches an emoji + + Parameters + ----------- + emoji_id: str + The id of the emoji + + Returns + -------- + :class:`Emoji` + The emoji with the corrasponding id + """ + + emoji = await self.state.http.fetch_emoji(emoji_id) + + return Emoji(emoji, self.state) + + async def upload_file(self, file: File, tag: Literal['attachments', 'avatars', 'backgrounds', 'icons', 'banners', 'emojis']) -> Ulid: + """Uploads a file to next + + Parameters + ----------- + file: :class:`File` + The file to upload + tag: :class:`str` + The type of file to upload, this should a string of either `'attachments'`, `'avatars'`, `'backgrounds'`, `'icons'`, `'banners'` or `'emojis'` + Returns + -------- + :class:`Ulid` + The id of the file that was uploaded + """ + asset = await self.http.upload_file(file, tag) + + ulid = Ulid() + ulid.id = asset["id"] + + return ulid + + # events + + async def on_ready(self) -> None: + pass + + async def on_message(self, message: next.Message) -> None: + pass + + async def on_raw_message_update(self, payload: next.types.MessageUpdateEventPayload) -> None: + pass + + async def on_message_update(self, before: next.Message, after: next.Message) -> None: + pass + + async def on_raw_message_delete(self, payload: next.types.MessageDeleteEventPayload) -> None: + pass + + async def on_message_delete(self, message: next.Message) -> None: + pass + + async def on_channel_create(self, channel: next.Channel) -> None: + pass + + async def on_channel_update(self, before: next.Channel, after: next.Channel) -> None: + pass + + async def on_channel_delete(self, channel: next.Channel) -> None: + pass + + async def on_typing_start(self, channel: next.Channel, user: next.User) -> None: + pass + + async def on_typing_stop(self, channel: next.Channel, user: next.User) -> None: + pass + + async def on_server_update(self, before: next.Server, after: next.Server) -> None: + pass + + async def on_server_delete(self, server: next.Server) -> None: + pass + + async def on_server_join(self, server: next.Server) -> None: + pass + + async def on_member_update(self, before: next.Member, after: next.Member) -> None: + pass + + async def on_member_join(self, member: next.Member) -> None: + pass + + async def on_member_leave(self, member: next.Member) -> None: + pass + + async def on_role_create(self, role: next.Role) -> None: + pass + + async def on_role_update(self, before: next.Role, after: next.Role) -> None: + pass + + async def on_role_delete(self, role: next.Role) -> None: + pass + + async def on_user_update(self, before: next.User, after: next.User) -> None: + pass + + async def on_user_relationship_update(self, user: next.User, before: next.RelationshipType, after: next.RelationshipType) -> None: + pass + + async def on_raw_reaction_add(self, payload: next.types.MessageReactEventPayload) -> None: + pass + + async def on_reaction_add(self, message: next.Message, user: next.User, emoji_id: str) -> None: + pass + + async def on_raw_reaction_remove(self, payload: next.types.MessageUnreactEventPayload) -> None: + pass + + async def on_reaction_remove(self, message: next.Message, user: next.User, emoji_id: str) -> None: + pass + + async def on_raw_reaction_clear(self, payload: next.types.MessageRemoveReactionEventPayload) -> None: + pass + + async def on_reaction_clear(self, message: next.Message, user: next.User, emoji_id: str) -> None: + pass + + async def raw_bulk_message_delete(self, payload: next.types.BulkMessageDeleteEventPayload) -> None: + pass + + async def bulk_message_delete(self, messages: list[next.Message]) -> None: + pass diff --git a/next/embed.py b/next/embed.py new file mode 100644 index 0000000..8fd8096 --- /dev/null +++ b/next/embed.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, TypedDict, Union + +from typing_extensions import NotRequired, Unpack + +from next.types.embed import WebsiteSpecial + +from .asset import Asset +from .enums import EmbedType + +if TYPE_CHECKING: + from .state import State + from .types import Embed as EmbedPayload + from .types import ImageEmbed as ImageEmbedPayload + from .types import SendableEmbed as SendableEmbedPayload + from .types import TextEmbed as TextEmbedPayload + from .types import WebsiteEmbed as WebsiteEmbedPayload + from .types import JanuaryImage, JanuaryVideo + +__all__ = ("Embed", "WebsiteEmbed", "ImageEmbed", "TextEmbed", "NoneEmbed", "to_embed", "SendableEmbed") + +class WebsiteEmbed: + type = EmbedType.website + + def __init__(self, embed: WebsiteEmbedPayload): + self.url: str | None = embed.get("url") + self.special: WebsiteSpecial | None = embed.get("special") + self.title: str | None = embed.get("title") + self.description: str | None = embed.get("description") + self.image: JanuaryImage | None = embed.get("image") + self.video: JanuaryVideo | None = embed.get("video") + self.site_name: str | None = embed.get("site_name") + self.icon_url: str | None = embed.get("icon_url") + self.colour: str | None = embed.get("colour") + +class ImageEmbed: + type: EmbedType = EmbedType.image + + def __init__(self, image: ImageEmbedPayload): + self.url: str = image.get("url") + self.width: int = image.get("width") + self.height: int = image.get("height") + self.size: str = image.get("size") + +class TextEmbed: + type: EmbedType = EmbedType.text + + def __init__(self, embed: TextEmbedPayload, state: State): + self.icon_url: str | None = embed.get("icon_url") + self.url: str | None = embed.get("url") + self.title: str | None = embed.get("title") + self.description: str | None = embed.get("description") + + self.media: Asset | None + + if media := embed.get("media"): + self.media = Asset(media, state) + else: + self.media = None + + self.colour: str | None = embed.get("colour") + +class NoneEmbed: + type: EmbedType = EmbedType.none + +Embed = Union[WebsiteEmbed, ImageEmbed, TextEmbed, NoneEmbed] + +def to_embed(payload: EmbedPayload, state: State) -> Embed: + if payload["type"] == "Website": + return WebsiteEmbed(payload) + elif payload["type"] == "Image": + return ImageEmbed(payload) + elif payload["type"] == "Text": + return TextEmbed(payload, state) + else: + return NoneEmbed() + +class EmbedParameters(TypedDict): + title: NotRequired[str] + description: NotRequired[str] + media: NotRequired[str] + icon_url: NotRequired[str] + colour: NotRequired[str] + url: NotRequired[str] + +class SendableEmbed: + """ + Represents an embed that can be sent in a message, you will never receive this, you will receive :class:`Embed`. + + Attributes + ----------- + title: Optional[:class:`str`] + The title of the embed + + description: Optional[:class:`str`] + The description of the embed + + media: Optional[:class:`str`] + The file inside the embed, this is the ID of the file, you can use :meth:`Client.upload_file` to get an ID. + + icon_url: Optional[:class:`str`] + The url of the icon url + + colour: Optional[:class:`str`] + The embed's accent colour, this is any valid `CSS color `_ + + url: Optional[:class:`str`] + URL for hyperlinking the embed's title + """ + def __init__(self, **attrs: Unpack[EmbedParameters]): + self.title: Optional[str] = None + self.description: Optional[str] = None + self.media: Optional[str] = None + self.icon_url: Optional[str] = None + self.colour: Optional[str] = None + self.url: Optional[str] = None + + for key, value in attrs.items(): + setattr(self, key, value) + + def to_dict(self) -> SendableEmbedPayload: + """Converts the embed to a dictionary which next accepts + + Returns + -------- + :class:`dict[str, Any]` + The embed + """ + output: SendableEmbedPayload = {"type": "Text"} + + if title := self.title: + output["title"] = title + + if description := self.description: + output["description"] = description + + if media := self.media: + output["media"] = media + + if icon_url := self.icon_url: + output["icon_url"] = icon_url + + if colour := self.colour: + output["colour"] = colour + + if url := self.url: + output["url"] = url + + return output diff --git a/next/emoji.py b/next/emoji.py new file mode 100644 index 0000000..f2e4ac6 --- /dev/null +++ b/next/emoji.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .utils import Ulid + +if TYPE_CHECKING: + from .server import Server + from .state import State + from .types import Emoji as EmojiPayload + +__all__ = ("Emoji",) + +class Emoji(Ulid): + """Represents a custom emoji. + + Attributes + ----------- + id: :class:`str` + The id of the emoji + author_id: :class:`str` + The id of the of user who created the emoji + name: :class:`str` + The name of the emoji + animated: :class:`bool` + Whether or not the emoji is animated + nsfw: :class:`bool` + Whether or not the emoji is nsfw + server_id: Optional[:class:`str`] + The server id this emoji belongs to, if any + """ + def __init__(self, payload: EmojiPayload, state: State): + self.state: State = state + + self.id: str = payload["_id"] + self.author_id: str = payload["creator_id"] + self.name: str = payload["name"] + self.animated: bool = payload.get("animated", False) + self.nsfw: bool = payload.get("nsfw", False) + self.server_id: str | None = payload["parent"].get("id") + + async def delete(self) -> None: + """Deletes the emoji.""" + await self.state.http.delete_emoji(self.id) + + @property + def server(self) -> Server: + """Returns the server this emoji is part of + + Returns + -------- + :class:`Server` + The Server this emoji is part of + """ + return self.state.get_server(self.server_id) # type: ignore diff --git a/next/enums.py b/next/enums.py new file mode 100644 index 0000000..6473d4e --- /dev/null +++ b/next/enums.py @@ -0,0 +1,59 @@ +from typing import TYPE_CHECKING + +# typing does not understand aenum so I am pretending its stdlib enum while type checking + +if TYPE_CHECKING: + import enum +else: + import aenum as enum + + +__all__ = ( + "ChannelType", + "PresenceType", + "RelationshipType", + "AssetType", + "SortType", + "EmbedType" +) + +class ChannelType(enum.Enum): + saved_messages = "SavedMessages" + direct_message = "DirectMessage" + group = "Group" + text_channel = "TextChannel" + voice_channel = "VoiceChannel" + +class PresenceType(enum.Enum): + busy = "Busy" + idle = "Idle" + invisible = "Invisible" + online = "Online" + focus = "Focus" + +class RelationshipType(enum.Enum): + blocked = "Blocked" + blocked_other = "BlockedOther" + friend = "Friend" + incoming_friend_request = "Incoming" + none = "None" + outgoing_friend_request = "Outgoing" + user = "User" + +class AssetType(enum.Enum): + image = "Image" + video = "Video" + text = "Text" + audio = "Audio" + file = "File" + +class SortType(enum.Enum): + latest = "Latest" + oldest = "Oldest" + relevance = "Relevance" + +class EmbedType(enum.Enum): + website = "Website" + image = "Image" + text = "Text" + none = "None" diff --git a/next/errors.py b/next/errors.py new file mode 100644 index 0000000..26e3ef1 --- /dev/null +++ b/next/errors.py @@ -0,0 +1,26 @@ +__all__ = ( + "NextError", + "HTTPError", + "ServerError", + "FeatureDisabled", + "AutumnDisabled", + "Forbidden", +) + +class NextError(Exception): + "Base exception for next" + +class HTTPError(NextError): + "Base exception for http errors" + +class ServerError(NextError): + "Internal server error" + +class FeatureDisabled(NextError): + "Base class for any feature disabled errors" + +class AutumnDisabled(FeatureDisabled): + "The autumn feature is disabled" + +class Forbidden(HTTPError): + "Missing permissions" diff --git a/next/ext/__init__.py b/next/ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/next/ext/commands/__init__.py b/next/ext/commands/__init__.py new file mode 100644 index 0000000..93ea2af --- /dev/null +++ b/next/ext/commands/__init__.py @@ -0,0 +1,10 @@ +from .checks import * +from .client import * +from .cog import * +from .command import * +from .context import * +from .converters import * +from .cooldown import * +from .errors import * +from .group import * +from .help import * diff --git a/next/ext/commands/checks.py b/next/ext/commands/checks.py new file mode 100644 index 0000000..be482fd --- /dev/null +++ b/next/ext/commands/checks.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from typing import Any, Callable, Coroutine, Union, cast +from typing_extensions import TypeVar + +import next + +from .command import Command +from .context import Context +from .errors import (MissingPermissionsError, NotBotOwner, NotServerOwner, + ServerOnly) +from .utils import ClientT_D + +__all__ = ("check", "Check", "is_bot_owner", "is_server_owner", "has_permissions", "has_channel_permissions") + +T = TypeVar("T", Callable[..., Any], Command, default=Command) + +Check = Callable[[Context[ClientT_D]], Union[Any, Coroutine[Any, Any, Any]]] + +def check(check: Check[ClientT_D]) -> Callable[[T], T]: + """A decorator for adding command checks + + Parameters + ----------- + check: Callable[[Context], Union[Any, Coroutine[Any, Any, Any]]] + The function to be called, must take one parameter, context and optionally be a coroutine, the return value denoating whether the check should pass or fail + """ + def inner(func: T) -> T: + if isinstance(func, Command): + command = cast(Command[ClientT_D], func) # cant verify generic at runtime so must cast + command.checks.append(check) + else: + checks = getattr(func, "_checks", []) + checks.append(check) + func._checks = checks # type: ignore + + return func + + return inner + +def is_bot_owner() -> Callable[[T], T]: + """A command check for limiting the command to only the bot's owner""" + @check + def inner(context: Context[ClientT_D]): + if user_id := context.client.user.owner_id: + if context.author.id == user_id: + return True + else: + if context.author.id == context.client.user.id: + return True + + raise NotBotOwner + + return inner + +def is_server_owner() -> Callable[[T], T]: + """A command check for limiting the command to only a server's owner""" + @check + def inner(context: Context[ClientT_D]) -> bool: + if not context.server_id: + raise ServerOnly + + if context.author.id == context.server.owner_id: + return True + + raise NotServerOwner + + return inner + +def has_permissions(**permissions: bool) -> Callable[[T], T]: + @check + def inner(context: Context[ClientT_D]) -> bool: + author = context.author + + if not author.has_permissions(**permissions): + raise MissingPermissionsError(permissions) + + return True + + return inner + +def has_channel_permissions(**permissions: bool) -> Callable[[T], T]: + @check + def inner(context: Context[ClientT_D]) -> bool: + author = context.author + + if not isinstance(author, next.Member): + raise ServerOnly + + if not author.has_channel_permissions(context.channel, **permissions): + raise MissingPermissionsError(permissions) + + return True + + return inner diff --git a/next/ext/commands/client.py b/next/ext/commands/client.py new file mode 100644 index 0000000..a67cb88 --- /dev/null +++ b/next/ext/commands/client.py @@ -0,0 +1,389 @@ +from __future__ import annotations + +import sys +import traceback +from importlib import import_module +from typing import (TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeVar, Union, + overload, runtime_checkable) + +from typing_extensions import Self + +import next + +if TYPE_CHECKING: + from .help import HelpCommand + +from .cog import Cog +from .command import Command +from .context import Context +from .errors import CheckError, CommandNotFound, MissingSetup +from .view import StringView + +__all__ = ( + "CommandsMeta", + "CommandsClient" +) + +V = TypeVar("V") +T = TypeVar("T") + +@runtime_checkable +class ExtensionProtocol(Protocol): + @staticmethod + def setup(client: CommandsClient) -> None: + raise NotImplementedError + +class CommandsMeta(type): + _commands: list[Command[Any]] + + def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any]) -> Any: + commands: list[Command[Any]] = [] + self = super().__new__(cls, name, bases, attrs) + + for base in reversed(self.__mro__): + for value in base.__dict__.values(): + if isinstance(value, Command) and value.parent is None: + commands.append(value) + + self._commands = commands + + return self + + +class CaseInsensitiveDict(dict[str, V]): + def __setitem__(self, key: str, value: V) -> None: + super().__setitem__(key.casefold(), value) + + def __getitem__(self, key: str) -> V: + return super().__getitem__(key.casefold()) + + def __contains__(self, key: object) -> bool: + if isinstance(key, str): + return super().__contains__(key.casefold()) + else: + return False + + @overload + def get(self, key: str) -> V | None: + ... + + @overload + def get(self, key: str, default: V | T) -> V | T: + ... + + def get(self, key: str, default: Optional[T] = None) -> V | T | None: + return super().get(key.casefold(), default) + + def __delitem__(self, key: str) -> None: + super().__delitem__(key.casefold()) + + +class CommandsClient(next.Client, metaclass=CommandsMeta): + """Main class that adds commands, this class should be subclassed along with `next.Client`.""" + + _commands: list[Command[Self]] + + def __init__(self, *args: Any, help_command: Union[HelpCommand[Self], None, next.utils._Missing] = next.utils.Missing, case_insensitive: bool = False, **kwargs: Any): + from .help import DefaultHelpCommand, HelpCommandImpl + + self.all_commands: dict[str, Command[Self]] = {} if not case_insensitive else CaseInsensitiveDict() + self.cogs: dict[str, Cog[Self]] = {} + self.extensions: dict[str, ExtensionProtocol] = {} + + for command in self._commands: + self.all_commands[command.name] = command + + for alias in command.aliases: + self.all_commands[alias] = command + + self.help_command: HelpCommand[Self] | None + + if help_command is not None: + self.help_command = help_command or DefaultHelpCommand[Self]() + self.add_command(HelpCommandImpl(self)) + else: + self.help_command = None + + super().__init__(*args, **kwargs) + + @property + def commands(self) -> list[Command[Self]]: + """Gets all commands registered + + Returns + -------- + list[:class:`Command`] + The registered commands + """ + return list(set(self.all_commands.values())) + + async def get_prefix(self, message: next.Message) -> Union[str, list[str]]: + """Overwrite this function to set the prefix used for commands, this function is called for every message. + + Parameters + ----------- + message: :class:`Message` + The message that was sent + + Returns + -------- + Union[:class:`str`, list[:class:`str`]] + The prefix(s) for the commands + """ + raise NotImplementedError + + def get_command(self, name: str) -> Command[Self]: + """Gets a command. + + Parameters + ----------- + name: :class:`str` + The name or alias of the command + + Returns + -------- + :class:`Command` + The command with the name + """ + return self.all_commands[name] + + def add_command(self, command: Command[Self]) -> None: + """Adds a command, this is typically only used for dynamic commands, you should use the `commands.command` decorator for most usecases. + + Parameters + ----------- + name: :class:`str` + The name or alias of the command + command: :class:`Command` + The command to be added + """ + self.all_commands[command.name] = command + + for alias in command.aliases: + self.all_commands[alias] = command + + def remove_command(self, name: str) -> Optional[Command[Self]]: + """Removes a command. + + Parameters + ----------- + name: :class:`str` + The name or alias of the command + + Returns + -------- + Optional[:class:`Command`] + The command that was removed + """ + command = self.all_commands.pop(name, None) + + if command is not None: + for alias in command.aliases: + self.all_commands.pop(alias, None) + + return command + + def get_view(self, message: next.Message) -> type[StringView]: + """Returns the StringView class to use, this can be overwritten to customize how arguments are parsed + + Returns + -------- + type[:class:`StringView`] + The string view class to use + """ + return StringView + + def get_context(self, message: next.Message) -> type[Context[Self]]: + """Returns the Context class to use, this can be overwritten to add extra features to context + + Returns + -------- + type[:class:`Context`] + The context class to use + """ + return Context[Self] + + async def process_commands(self, message: next.Message) -> Any: + """Processes commands, if you overwrite `Client.on_message` you should manually call this function inside the event. + + Parameters + ----------- + message: :class:`Message` + The message to process commands on + + Returns + -------- + Any + The return of the command, if any + """ + content = message.content + + prefixes = await self.get_prefix(message) + + if isinstance(prefixes, str): + prefixes = [prefixes] + + for prefix in prefixes: + if content.startswith(prefix): + content = content[len(prefix):] + break + else: + return + + if not content: + return + + view = self.get_view(message)(content) + + try: + command_name = view.get_next_word() + except StopIteration: + return + + context_cls = self.get_context(message) + + try: + command = self.get_command(command_name) + except KeyError: + context = context_cls(None, command_name, view, message, self) + return self.dispatch("command_error", context, CommandNotFound(command_name)) + + context = context_cls(command, command_name, view, message, self) + + try: + self.dispatch("command", context) + + if not await self.global_check(context): + raise CheckError(f"the global check for the command failed") + + if not await context.can_run(): + raise CheckError(f"the check(s) for the command failed") + + output = await context.invoke() + self.dispatch("after_command_invoke", context, output) + + return output + except Exception as e: + await command._error_handler(command.cog or self, context, e) + self.dispatch("command_error", context, e) + + async def on_command_error(self, ctx: Context[Self], error: Exception, /) -> None: + traceback.print_exception(type(error), error, error.__traceback__) + + def on_message(self, message: next.Message) -> Coroutine[Any, Any, Any]: + return self.process_commands(message) + + async def global_check(self, context: Context[Self]) -> bool: + """A global check that stops commands from running on certain criteria. + + Parameters + ----------- + context: :class:`Context` + The context for the invokation of the command + + Returns + -------- + :class:`bool` represents if the command should run or not + """ + + return True + + def add_cog(self, cog: Cog[Self]) -> None: + """Adds a cog, this cog must subclass `Cog`. + + Parameters + ----------- + cog: :class:`Cog` + The cog to be added + """ + cog._inject(self) + + def remove_cog(self, cog_name: str) -> Cog[Self]: + """Removes a cog. + + Parameters + ----------- + cog_name: :class:`str` + The name of the cog to be removed + + Returns + -------- + :class:`Cog` + The cog that was removed + """ + cog = self.cogs.pop(cog_name) + cog._uninject(self) + + return cog + + def load_extension(self, name: str) -> None: + """Loads an extension, this takes a module name and runs the setup function inside of it. + + Parameters + ----------- + name: :class:`str` + The name of the extension to be loaded + """ + extension = import_module(name) + + if not isinstance(extension, ExtensionProtocol): + raise MissingSetup(f"'{extension}' is missing a setup function") + + self.extensions[name] = extension + extension.setup(self) + + def unload_extension(self, name: str) -> None: + """Unloads an extension, this takes a module name and runs the teardown function inside of it. + + Parameters + ----------- + name: :class:`str` + The name of the extension to be unloaded + """ + extension = self.extensions.pop(name) + + del sys.modules[name] + + if teardown := getattr(extension, "teardown", None): + teardown(self) + + def reload_extension(self, name: str) -> None: + """Reloads an extension, this will unload and reload the extension. + + Parameters + ----------- + name: :class:`str` + The name of the extension to be reloaded + """ + self.unload_extension(name) + self.load_extension(name) + + def get_cog(self, name: str) -> Cog[Self]: + """Gets a cog. + + Parameters + ----------- + name: :class:`str` + The name of the cog to get + + Returns + -------- + :class:`Cog` + The cog that was requested + """ + return self.cogs[name] + + def get_extension(self, name: str) -> ExtensionProtocol: + """Gets an extension. + + Parameters + ----------- + name: :class:`str` + The name of the extension to get + + Returns + -------- + :class:`ExtensionProtocol` + The extension that was requested + """ + return self.extensions[name] diff --git a/next/ext/commands/cog.py b/next/ext/commands/cog.py new file mode 100644 index 0000000..0842187 --- /dev/null +++ b/next/ext/commands/cog.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +from typing import Any, Callable, Coroutine, Generic, Optional, TypeVar +from typing_extensions import ParamSpec + +from next.errors import NextError + +from .command import Command +from .utils import ClientT_D + +P = ParamSpec("P") +R = TypeVar("R") + +__all__ = ("Cog", "CogMeta") + +class CogMeta(type): + _cog_commands: list[Command[Any]] + _cog_listeners: dict[str, list[str]] + qualified_name: str + + def __new__(cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any], *, qualified_name: Optional[str] = None, extras: dict[str, Any] | None = None) -> Any: + commands: list[Command[Any]] = [] + listeners: dict[str, list[str]] = {} + + self = super().__new__(cls, name, bases, attrs) + extras = extras or {} + + for base in reversed(self.__mro__): + for key, value in base.__dict__.items(): + if isinstance(value, Command): + for extra_key, extra_value in extras.items(): + setattr(value, extra_key, extra_value) + + commands.append(value) + + elif event_name := getattr(value, "__listener_name", None): + listeners.setdefault(event_name, []).append(key) + + self._cog_commands = commands + self._cog_listeners = listeners + self.qualified_name = qualified_name or name + return self + +class Cog(Generic[ClientT_D], metaclass=CogMeta): + _cog_commands: list[Command[ClientT_D]] + _cog_listeners: dict[str, list[str]] + qualified_name: str + + def cog_load(self) -> None: + """A special method that is called when the cog gets loaded.""" + pass + + def cog_unload(self) -> None: + """A special method that is called when the cog gets removed.""" + pass + + def _inject(self, client: ClientT_D) -> None: + client.cogs[self.qualified_name] = self + + for command in self._cog_commands: + command.cog = self + + if command.parent is None: + client.add_command(command) + + for key, listeners in self._cog_listeners.items(): + for listener_name in listeners: + client.listeners.setdefault(key, []).append(getattr(self, listener_name)) + + self.cog_load() + + def _uninject(self, client: ClientT_D) -> None: + for name, command in client.all_commands.copy().items(): + if command in self._cog_commands: + del client.all_commands[name] + + for key, listeners in self._cog_listeners.items(): + for listener_name in listeners: + client.listeners[key].remove(getattr(self, listener_name)) + + self.cog_unload() + + @property + def commands(self) -> list[Command[ClientT_D]]: + return self._cog_commands + + @staticmethod + def listen(name: str | None = None) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]]: + def inner(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: + if not func.__name__.startswith("on_"): + raise NextError("event name must start with `on_`") + + setattr(func, "__listener_name", name or func.__name__[3:]) + return func + + return inner diff --git a/next/ext/commands/command.py b/next/ext/commands/command.py new file mode 100644 index 0000000..451b808 --- /dev/null +++ b/next/ext/commands/command.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import inspect +import traceback +from contextlib import suppress +from typing import (TYPE_CHECKING, Annotated, Any, Callable, Coroutine, + Generic, Literal, Optional, Union, get_args, get_origin) +from typing_extensions import ParamSpec +import sys + +if sys.version_info >= (3, 10): + from types import UnionType + + UnionTypes = (Union, UnionType) +else: + UnionTypes = (Union,) + +from ...utils import maybe_coroutine + +from .errors import CommandOnCooldown, InvalidLiteralArgument, UnionConverterError +from .utils import ClientT_Co_D, evaluate_parameters, ClientT_Co +from .cooldown import BucketType, CooldownMapping + +if TYPE_CHECKING: + from .checks import Check + from .cog import Cog + from .context import Context + from .group import Group + +__all__: tuple[str, ...] = ( + "Command", + "command" +) + +NoneType: type[None] = type(None) +P = ParamSpec("P") + +class Command(Generic[ClientT_Co_D]): + """Class for holding info about a command. + + Parameters + ----------- + callback: Callable[..., Coroutine[Any, Any, Any]] + The callback for the command + name: :class:`str` + The name of the command + aliases: list[:class:`str`] + The aliases of the command + parent: Optional[:class:`Group`] + The parent of the command if this command is a subcommand + cog: Optional[:class:`Cog`] + The cog the command is apart of. + usage: Optional[:class:`str`] + The usage string for the command + checks: Optional[list[Callable]] + The list of checks the command has + cooldown: Optional[:class:`Cooldown`] + The cooldown for the command to restrict how often the command can be used + description: Optional[:class:`str`] + The commands description if it has one + hidden: :class:`bool` + Whether or not the command should be hidden from the help command + """ + __slots__ = ("callback", "name", "aliases", "signature", "checks", "parent", "_error_handler", "cog", "description", "usage", "parameters", "hidden", "cooldown", "cooldown_bucket") + + def __init__( + self, + callback: Callable[..., Coroutine[Any, Any, Any]], + name: str, + *, + aliases: list[str] | None = None, + usage: Optional[str] = None, + checks: list[Check[ClientT_Co_D]] | None = None, + cooldown: Optional[CooldownMapping] | None = None, + bucket: Optional[BucketType | Callable[[Context[ClientT_Co_D]], Coroutine[Any, Any, str]]] = None, + description: str | None = None, + hidden: bool = False, + ): + self.callback: Callable[..., Coroutine[Any, Any, Any]] = callback + self.name: str = name + self.aliases: list[str] = aliases or [] + self.usage: str | None = usage + self.signature: inspect.Signature = inspect.signature(self.callback) + self.parameters: list[inspect.Parameter] = evaluate_parameters(self.signature.parameters.values(), getattr(callback, "__globals__", {})) + self.checks: list[Check[ClientT_Co_D]] = checks or getattr(callback, "_checks", []) + self.cooldown: CooldownMapping | None = cooldown or getattr(callback, "_cooldown", None) + self.cooldown_bucket: BucketType | Callable[[Context[ClientT_Co_D]], Coroutine[Any, Any, str]] = bucket or getattr(callback, "_bucket", BucketType.default) + self.parent: Optional[Group[ClientT_Co_D]] = None + self.cog: Optional[Cog[ClientT_Co_D]] = None + self._error_handler: Callable[[Any, Context[ClientT_Co_D], Exception], Coroutine[Any, Any, Any]] = type(self)._default_error_handler + self.description: str | None = description or callback.__doc__ + self.hidden: bool = hidden + + async def invoke(self, context: Context[ClientT_Co_D], *args: Any, **kwargs: Any) -> Any: + """Runs the command and calls the error handler if the command errors. + + Parameters + ----------- + context: :class:`Context` + The context for the command + args: list[:class:`str`] + The arguments for the command + """ + try: + return await self.callback(self.cog or context.client, context, *args, **kwargs) + except Exception as err: + return await self._error_handler(self.cog or context.client, context, err) + + def __call__(self, context: Context[ClientT_Co_D], *args: Any, **kwargs: Any) -> Any: + return self.invoke(context, *args, **kwargs) + + def error(self, func: Callable[..., Coroutine[Any, Any, Any]]) -> Callable[..., Coroutine[Any, Any, Any]]: + """Sets the error handler for the command. + + Parameters + ----------- + func: Callable[..., Coroutine[Any, Any, Any]] + The function for the error handler + + Example + -------- + .. code-block:: python3 + + @mycommand.error + async def mycommand_error(self, ctx, error): + await ctx.send(str(error)) + + """ + self._error_handler = func + return func + + async def _default_error_handler(self, ctx: Context[ClientT_Co_D], error: Exception): + traceback.print_exception(type(error), error, error.__traceback__) + + @classmethod + async def handle_origin(cls, context: Context[ClientT_Co_D], origin: Any, annotation: Any, arg: str) -> Any: + if origin in UnionTypes: + for converter in get_args(annotation): + try: + return await cls.convert_argument(arg, converter, context) + except: + if converter is NoneType: + context.view.undo() + return None + + raise UnionConverterError(arg) + + elif origin is Annotated: + annotated_args = get_args(annotation) + + if annotated_args[1] == "_next_greedy_marker": + real_annotation = get_args(annotated_args[0])[0] + converted_args: list[Any] = [] + + converted_args.append(await cls.convert_argument(arg, real_annotation, context)) + + for arg in context.view: + try: + converted_args.append(await cls.convert_argument(arg, real_annotation, context)) + except: + context.view.undo() + break + + return converted_args + else: + return await cls.convert_argument(arg, annotated_args[1], context) + + elif origin is Literal: + if arg in get_args(annotation): + return arg + else: + raise InvalidLiteralArgument(arg) + + @classmethod + async def convert_argument(cls, arg: str, annotation: Any, context: Context[ClientT_Co_D]) -> Any: + if annotation is not inspect.Signature.empty: + if annotation is str: # no converting is needed - its already a string + return arg + + origin: Any + if origin := get_origin(annotation): + return await cls.handle_origin(context, origin, annotation, arg) + else: + return await maybe_coroutine(annotation, arg, context) + else: + return arg + + async def parse_arguments(self, context: Context[ClientT_Co_D]) -> None: + # please pr if you can think of a better way to do this + + for parameter in self.parameters[2:]: + if parameter.kind == parameter.KEYWORD_ONLY: + try: + arg = await self.convert_argument(context.view.get_rest(), parameter.annotation, context) + except StopIteration: + if parameter.default is not parameter.empty: + arg = parameter.default + + elif is_optional(parameter.annotation): + arg = None + + else: + raise + + context.kwargs[parameter.name] = arg + + elif parameter.kind == parameter.VAR_POSITIONAL: + with suppress(StopIteration): + while True: + context.args.append(await self.convert_argument(context.view.get_next_word(), parameter.annotation, context)) + + elif parameter.kind == parameter.POSITIONAL_OR_KEYWORD: + try: + rest = context.view.get_next_word() + arg = await self.convert_argument(rest, parameter.annotation, context) + except StopIteration: + if parameter.default is not parameter.empty: + arg = parameter.default + + elif is_optional(parameter.annotation): + arg = None + + else: + raise + + context.args.append(arg) + + async def run_cooldown(self, context: Context[ClientT_Co_D]) -> None: + if mapping := self.cooldown: + if isinstance(self.cooldown_bucket, BucketType): + key = self.cooldown_bucket.resolve(context) + else: + key = await self.cooldown_bucket(context) + + cooldown = mapping.get_bucket(key) + + if retry_after := cooldown.update_cooldown(): + raise CommandOnCooldown(retry_after) + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} name=\"{self.name}\">" + + @property + def short_description(self) -> Optional[str]: + """Returns the first line of the description or None if there is no description.""" + if self.description: + return self.description.split("\n")[0] + + def get_usage(self) -> str: + """Returns the usage string for the command.""" + if self.usage: + return self.usage + + parents: list[str] = [] + + if self.parent: + parent = self.parent + + while parent: + parents.append(parent.name) + parent = parent.parent + + parameters: list[str] = [] + + for parameter in self.parameters[2:]: + if parameter.kind == parameter.POSITIONAL_OR_KEYWORD: + if parameter.default is not parameter.empty: + parameters.append(f"[{parameter.name}]") + else: + parameters.append(f"<{parameter.name}>") + elif parameter.kind == parameter.KEYWORD_ONLY: + if parameter.default is not parameter.empty: + parameters.append(f"[{parameter.name}]") + else: + parameters.append(f"<{parameter.name}...>") + elif parameter.kind == parameter.VAR_POSITIONAL: + parameters.append(f"[{parameter.name}...]") + + return f"{' '.join(parents[::-1])} {self.name} {' '.join(parameters)}" + +def is_optional(arg: Any) -> bool: + return get_origin(arg) in UnionTypes and any(arg is NoneType for arg in get_args(arg)) + +def command( + *, + name: Optional[str] = None, + aliases: Optional[list[str]] = None, + cls: type[Command[ClientT_Co]] = Command, + usage: Optional[str] = None +) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Command[ClientT_Co]]: + """A decorator that turns a function into a :class:`Command`.n + + Parameters + ----------- + name: Optional[:class:`str`] + The name of the command, this defaults to the functions name + aliases: Optional[list[:class:`str`]] + The aliases of the command, defaults to no aliases + cls: type[:class:`Command`] + The class used for creating the command, this defaults to :class:`Command` but can be used to use a custom command subclass + usage: Optional[:class:`str`] + The signature for how the command should be called + + Returns + -------- + Callable[Callable[..., Coroutine], :class:`Command`] + A function that takes the command callback and returns a :class:`Command` + """ + def inner(func: Callable[..., Coroutine[Any, Any, Any]]) -> Command[ClientT_Co]: + return cls(func, name or func.__name__, aliases=aliases or [], usage=usage) + + return inner diff --git a/next/ext/commands/context.py b/next/ext/commands/context.py new file mode 100644 index 0000000..5ac248c --- /dev/null +++ b/next/ext/commands/context.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, Optional + +import next +from next.utils import maybe_coroutine + +from .command import Command +from .group import Group +from .utils import ClientT_Co_D + +if TYPE_CHECKING: + from .view import StringView + from next.state import State + +__all__ = ( + "Context", +) + +class Context(next.Messageable, Generic[ClientT_Co_D]): + """Stores metadata the commands execution. + + Attributes + ----------- + command: Optional[:class:`Command`] + The command, this can be `None` when no command was found and the error handler is being executed + invoked_with: :class:`str` + The command name that was used, this can be an alias, the commands name or a command that doesnt exist + message: :class:`Message` + The message that was sent to invoke the command + channel: :class:`Messageable` + The channel the command was invoked in + server_id: Optional[:class:`Server`] + The server the command was invoked in + author: Union[:class:`Member`, :class:`User`] + The user or member that invoked the commad, will be :class:`User` in DMs + args: list[:class:`str`] + The positional arguments being passed to the command + kwargs: dict[:class:`str`, Any] + The keyword arguments being passed to the command + client: :class:`CommandsClient` + The next client + """ + __slots__ = ("command", "invoked_with", "args", "message", "channel", "author", "view", "kwargs", "state", "client", "server_id") + + async def _get_channel_id(self) -> str: + return self.channel.id + + def __init__(self, command: Optional[Command[ClientT_Co_D]], invoked_with: str, view: StringView, message: next.Message, client: ClientT_Co_D): + self.command: Command[ClientT_Co_D] | None = command + self.invoked_with: str = invoked_with + self.view: StringView = view + self.message: next.Message = message + self.client: ClientT_Co_D = client + self.args: list[Any] = [] + self.kwargs: dict[str, Any] = {} + self.server_id: str | None = message.server_id + self.channel: next.TextChannel | next.GroupDMChannel | next.DMChannel | next.SavedMessageChannel = message.channel + self.author: next.Member | next.User = message.author + self.state: State = message.state + + @property + def server(self) -> next.Server: + """:class:`Server` The server this context belongs too + + Raises + ------- + :class:`LookupError` + Raises if the context is not from a server + """ + if not self.server_id: + raise LookupError + + return self.state.get_server(self.server_id) + + async def invoke(self) -> Any: + """Invokes the command. + + .. note:: If the command is `None`, this function will do nothing. + + Parameters + ----------- + args: list[:class:`str`] + The args being passed to the command + """ + + if command := self.command: + if isinstance(command, Group): + try: + subcommand_name = self.view.get_next_word() + except StopIteration: + pass + else: + if subcommand := command.subcommands.get(subcommand_name): + self.command = command = subcommand + return await self.invoke() + + self.view.undo() + + await command.run_cooldown(self) + await command.parse_arguments(self) + return await command.invoke(self, *self.args, **self.kwargs) + + async def can_run(self, command: Optional[Command[ClientT_Co_D]] = None) -> bool: + """Runs all of the commands checks, and returns true if all of them pass""" + command = command or self.command + + return all([await maybe_coroutine(check, self) for check in (command.checks if command else [])]) + + async def send_help(self, argument: Command[Any] | Group[Any] | ClientT_Co_D | None = None) -> None: + argument = argument or self.client + + command = self.client.get_command("help") + await command.invoke(self, argument) diff --git a/next/ext/commands/converters.py b/next/ext/commands/converters.py new file mode 100644 index 0000000..aac7c44 --- /dev/null +++ b/next/ext/commands/converters.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import re +from typing import TYPE_CHECKING, Annotated, TypeVar + +from next import Category, Channel, Member, User, utils + +from .context import Context +from .errors import (BadBoolArgument, CategoryConverterError, + ChannelConverterError, MemberConverterError, ServerOnly, + UserConverterError) + +if TYPE_CHECKING: + from .client import CommandsClient + +T = TypeVar("T") + +__all__: tuple[str, ...] = ("bool_converter", "category_converter", "channel_converter", "user_converter", "member_converter", "IntConverter", "BoolConverter", "CategoryConverter", "UserConverter", "MemberConverter", "ChannelConverter", "Greedy") + +channel_regex: re.Pattern[str] = re.compile("<#([A-z0-9]{26})>") +user_regex: re.Pattern[str] = re.compile("<@([A-z0-9]{26})>") + +ClientT = TypeVar("ClientT", bound="CommandsClient") + +def bool_converter(arg: str, _: Context[ClientT]) -> bool: + lowered = arg.lower() + if lowered in ("yes", "true", "ye", "y", "1", "on", "enable"): + return True + elif lowered in ("no", "false", "n", "f", "0", "off", "disabled"): + return False + else: + raise BadBoolArgument(lowered) + +def category_converter(arg: str, context: Context[ClientT]) -> Category: + if not context.server_id: + raise ServerOnly + + try: + return context.server.get_category(arg) + except LookupError: + try: + return utils.get(context.server.categories, name=arg) + except LookupError: + raise CategoryConverterError(arg) + +def channel_converter(arg: str, context: Context[ClientT]) -> Channel: + if not context.server_id: + raise ServerOnly + + if (match := channel_regex.match(arg)): + arg = match.group(1) + + try: + return context.server.get_channel(arg) + except LookupError: + try: + return utils.get(context.server.channels, name=arg) + except LookupError: + raise ChannelConverterError(arg) + +def user_converter(arg: str, context: Context[ClientT]) -> User: + if (match := user_regex.match(arg)): + arg = match.group(1) + + try: + return context.client.get_user(arg) + except LookupError: + try: + parts = arg.split("#") + + if len(parts) == 1: + return ( + utils.get(context.client.users, original_name=arg) + or utils.get(context.client.users, display_name=arg) + ) + elif len(parts) == 2: + return ( + utils.get(context.client.users, original_name=parts[0], discriminator=parts[1]) + or utils.get(context.client.users, display_name=parts[0], discriminator=parts[1]) + ) + else: + raise LookupError + + except LookupError: + raise UserConverterError(arg) + +def member_converter(arg: str, context: Context[ClientT]) -> Member: + if not context.server_id: + raise ServerOnly + + if (match := user_regex.match(arg)): + arg = match.group(1) + + try: + return context.server.get_member(arg) + except LookupError: + try: + parts = arg.split("#") + + if len(parts) == 1: + return ( + utils.get(context.server.members, original_name=arg) + or utils.get(context.server.members, display_name=arg) + ) + elif len(parts) == 2: + return ( + utils.get(context.server.members, original_name=parts[0], discriminator=parts[1]) + or utils.get(context.server.members, display_name=parts[0], discriminator=parts[1]) + ) + else: + raise LookupError + + except LookupError: + raise MemberConverterError(arg) + +def int_converter(arg: str, context: Context[ClientT]) -> int: + return int(arg) + +IntConverter = Annotated[int, int_converter] +BoolConverter = Annotated[bool, bool_converter] +CategoryConverter = Annotated[Category, category_converter] +UserConverter = Annotated[User, user_converter] +MemberConverter = Annotated[Member, member_converter] +ChannelConverter = Annotated[Channel, channel_converter] + +Greedy = Annotated[list[T], "_next_greedy_marker"] \ No newline at end of file diff --git a/next/ext/commands/cooldown.py b/next/ext/commands/cooldown.py new file mode 100644 index 0000000..781cb73 --- /dev/null +++ b/next/ext/commands/cooldown.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Any, Callable, Coroutine, TypeVar, cast + +from .errors import ServerOnly + +if TYPE_CHECKING: + from enum import Enum + + from .context import Context + from .utils import ClientT_Co_D, ClientT_Co +else: + from aenum import Enum + +__all__ = ("Cooldown", "CooldownMapping", "BucketType", "cooldown") + +T = TypeVar("T") + +class Cooldown: + """Represent a single cooldown for a single key + + Parameters + ----------- + rate: :class:`int` + How many times it can be used + per: :class:`int` + How long the window is before the ratelimit resets + """ + + def __init__(self, rate: int, per: int): + self.rate: int = rate + self.per: int = per + self.window: float = 0.0 + self.tokens: int = rate + self.last: float = 0.0 + + def get_tokens(self, current: float | None) -> int: + current = current or time.time() + + if current > (self.window + self.per): + return self.rate + else: + return self.tokens + + def update_cooldown(self) -> float | None: + current = time.time() + + self.last = current + + self.tokens = self.get_tokens(current) + + if self.tokens == 0: + return self.per - (current - self.window) + + self.tokens -= 1 + + if self.tokens == 0: + self.window = current + + return None + +class CooldownMapping: + """Holds all cooldowns for every key""" + def __init__(self, rate: int, per: int): + self.rate = rate + self.per = per + self.cache: dict[str, Cooldown] = {} + + def verify_cache(self) -> None: + current = time.time() + self.cache = {k: v for k, v in self.cache.items() if current < (v.last + v.per)} + + def get_bucket(self, key: str) -> Cooldown: + self.verify_cache() + + if not (rl := self.cache.get(key)): + self.cache[key] = rl = Cooldown(self.rate, self.per) + + return rl + +class BucketType(Enum): + default = 0 + user = 1 + server = 2 + channel = 3 + member = 4 + + def resolve(self, context: Context[ClientT_Co_D]) -> str: + if self == BucketType.default: + return f"{context.author.id}{context.channel.id}" + + elif self == BucketType.user: + return context.author.id + + elif self == BucketType.server: + if id := context.server_id: + return id + + raise ServerOnly + + elif self == BucketType.channel: + return context.channel.id + + else: # BucketType.member + if server_id := context.server_id: + return f"{context.author.id}{server_id}" + + raise ServerOnly + +def cooldown(rate: int, per: int, *, bucket: BucketType | Callable[[Context[ClientT_Co]], Coroutine[Any, Any, str]] = BucketType.default) -> Callable[[T], T]: + """Adds a cooldown to a command + + Parameters + ----------- + rate: :class:`int` + How many times it can be used + per: :class:`int` + How long the window is before the ratelimit resets + bucket: Optional[Union[:class:`BucketType`, Callable[[Context], str]]] + Controls how the key is generated for the cooldowns + + Examples + -------- + .. code-block:: python + @commands.command() + @commands.cooldown(1, 5) + async def ping(self, ctx: Context): + await ctx.send("Pong") + """ + def inner(func: T) -> T: + from .command import Command + + if isinstance(func, Command): + command = cast(Command[ClientT_Co], func) # cant verify generic at runtime so must cast + command.cooldown = CooldownMapping(rate, per) + command.cooldown_bucket = bucket + else: + func._cooldown = CooldownMapping(rate, per) # type: ignore + func._bucket = bucket # type: ignore + + return func + + return inner \ No newline at end of file diff --git a/next/ext/commands/errors.py b/next/ext/commands/errors.py new file mode 100644 index 0000000..cd4b7cf --- /dev/null +++ b/next/ext/commands/errors.py @@ -0,0 +1,114 @@ +from next import NextError + +__all__ = ( + "CommandError", + "CommandNotFound", + "NoClosingQuote", + "CheckError", + "NotBotOwner", + "NotServerOwner", + "ServerOnly", + "ConverterError", + "InvalidLiteralArgument", + "BadBoolArgument", + "CategoryConverterError", + "ChannelConverterError", + "UserConverterError", + "MemberConverterError", + "MissingSetup", + "CommandOnCooldown" +) + +class CommandError(NextError): + """base error for all command's related errors""" + +class CommandNotFound(CommandError): + """Raised when a command isnt found. + + Parameters + ----------- + command_name: :class:`str` + The name of the command that wasnt found + """ + __slots__ = ("command_name",) + + def __init__(self, command_name: str): + self.command_name: str = command_name + +class NoClosingQuote(CommandError): + """Raised when there is no closing quote for a command argument""" + +class CheckError(CommandError): + """Raised when a check fails for a command""" + +class NotBotOwner(CheckError): + """Raised when the `is_bot_owner` check fails""" + +class NotServerOwner(CheckError): + """Raised when the `is_server_owner` check fails""" + +class ServerOnly(CheckError): + """Raised when a check requires the command to be ran in a server""" + +class MissingPermissionsError(CheckError): + """Raised when a check requires permissions the user does not have + + Attributes + ----------- + permissions: :class:`dict[str, bool]` + The permissions which the user did not have + """ + + def __init__(self, permissions: dict[str, bool]): + self.permissions = permissions + +class ConverterError(CommandError): + """Base class for all converter errors""" + +class InvalidLiteralArgument(ConverterError): + """Raised when the argument is not a valid literal argument""" + +class BadBoolArgument(ConverterError): + """Raised when the bool converter fails""" + +class CategoryConverterError(ConverterError): + """Raised when the Category conveter fails""" + def __init__(self, argument: str): + self.argument = argument + +class ChannelConverterError(ConverterError): + """Raised when the Channel conveter fails""" + def __init__(self, argument: str): + self.argument = argument + +class UserConverterError(ConverterError): + """Raised when the Category conveter fails""" + def __init__(self, argument: str): + self.argument = argument + +class MemberConverterError(ConverterError): + """Raised when the Category conveter fails""" + def __init__(self, argument: str): + self.argument = argument + +class UnionConverterError(ConverterError): + """Raised when all converters in a union fails""" + def __init__(self, argument: str): + self.argument = argument + +class MissingSetup(CommandError): + """Raised when an extension is missing the `setup` function""" + +class CommandOnCooldown(CommandError): + """Raised when a command is on cooldown + + Attributes + ----------- + retry_after: :class:`float` + How long the user must wait until the cooldown resets + """ + + __slots__ = ("retry_after",) + + def __init__(self, retry_after: float): + self.retry_after: float = retry_after \ No newline at end of file diff --git a/next/ext/commands/group.py b/next/ext/commands/group.py new file mode 100644 index 0000000..a72a878 --- /dev/null +++ b/next/ext/commands/group.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +from typing import Any, Callable, Coroutine, Optional + +from .command import Command +from .utils import ClientT_Co_D, ClientT_D + + +__all__ = ( + "Group", + "group" +) + +class Group(Command[ClientT_Co_D]): + """Class for holding info about a group command. + + Parameters + ----------- + callback: Callable[..., Coroutine[Any, Any, Any]] + The callback for the group command + name: :class:`str` + The name of the command + aliases: list[:class:`str`] + The aliases of the group command + subcommands: dict[:class:`str`, :class:`Command`] + The group's subcommands. + """ + + __slots__: tuple[str, ...] = ("subcommands",) + + def __init__(self, callback: Callable[..., Coroutine[Any, Any, Any]], name: str, aliases: list[str]): + self.subcommands: dict[str, Command[ClientT_Co_D]] = {} + super().__init__(callback, name, aliases=aliases) + + def command(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Command[ClientT_Co_D]] = Command[ClientT_Co_D]) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Command[ClientT_Co_D]]: + """A decorator that turns a function into a :class:`Command` and registers the command as a subcommand. + + Parameters + ----------- + name: Optional[:class:`str`] + The name of the command, this defaults to the functions name + aliases: Optional[list[:class:`str`]] + The aliases of the command, defaults to no aliases + cls: type[:class:`Command`] + The class used for creating the command, this defaults to :class:`Command` but can be used to use a custom command subclass + + Returns + -------- + Callable[Callable[..., Coroutine], :class:`Command`] + A function that takes the command callback and returns a :class:`Command` + """ + def inner(func: Callable[..., Coroutine[Any, Any, Any]]): + command = cls(func, name or func.__name__, aliases=aliases or []) + command.parent = self + self.subcommands[command.name] = command + + for alias in command.aliases: + self.subcommands[alias] = command + + return command + + return inner + + def group(self, *, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: Optional[type[Group[ClientT_Co_D]]] = None) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Group[ClientT_Co_D]]: + """A decorator that turns a function into a :class:`Group` and registers the command as a subcommand + + Parameters + ----------- + name: Optional[:class:`str`] + The name of the group command, this defaults to the functions name + aliases: Optional[list[:class:`str`]] + The aliases of the group command, defaults to no aliases + cls: type[:class:`Group`] + The class used for creating the command, this defaults to :class:`Group` but can be used to use a custom group subclass + + Returns + -------- + Callable[Callable[..., Coroutine], :class:`Group`] + A function that takes the command callback and returns a :class:`Group` + """ + cls = cls or type(self) + + def inner(func: Callable[..., Coroutine[Any, Any, Any]]): + command = cls(func, name or func.__name__, aliases or []) + command.parent = self + self.subcommands[command.name] = command + + for alias in command.aliases: + self.subcommands[alias] = command + + return command + + return inner + + def __repr__(self) -> str: + return f"" + + @property + def commands(self) -> list[Command[ClientT_Co_D]]: + """Gets all commands registered + + Returns + -------- + list[:class:`Command`] + The registered commands + """ + return list(set(self.subcommands.values())) + + def get_command(self, name: str) -> Command[ClientT_Co_D]: + """Gets a command. + + Parameters + ----------- + name: :class:`str` + The name or alias of the command + + Returns + -------- + :class:`Command` + The command with the name + """ + return self.subcommands[name] + + def add_command(self, command: Command[ClientT_Co_D]) -> None: + """Adds a command, this is typically only used for dynamic commands, you should use the `commands.command` decorator for most usecases. + + Parameters + ----------- + name: :class:`str` + The name or alias of the command + command: :class:`Command` + The command to be added + """ + self.subcommands[command.name] = command + + for alias in command.aliases: + self.subcommands[alias] = command + + def remove_command(self, name: str) -> Optional[Command[ClientT_Co_D]]: + """Removes a command. + + Parameters + ----------- + name: :class:`str` + The name or alias of the command + + Returns + -------- + Optional[:class:`Command`] + The command that was removed + """ + command = self.subcommands.pop(name, None) + + if command is not None: + for alias in command.aliases: + self.subcommands.pop(alias, None) + + return command + +def group(*, name: Optional[str] = None, aliases: Optional[list[str]] = None, cls: type[Group[ClientT_D]] = Group) -> Callable[[Callable[..., Coroutine[Any, Any, Any]]], Group[ClientT_D]]: + """A decorator that turns a function into a :class:`Group` + + Parameters + ----------- + name: Optional[:class:`str`] + The name of the group command, this defaults to the functions name + aliases: Optional[list[:class:`str`]] + The aliases of the group command, defaults to no aliases + cls: type[:class:`Group`] + The class used for creating the command, this defaults to :class:`Group` but can be used to use a custom group subclass + + Returns + -------- + Callable[Callable[..., Coroutine], :class:`Group`] + A function that takes the command callback and returns a :class:`Group` + """ + + def inner(func: Callable[..., Coroutine[Any, Any, Any]]): + return cls(func, name or func.__name__, aliases or []) + + return inner diff --git a/next/ext/commands/help.py b/next/ext/commands/help.py new file mode 100644 index 0000000..a6379df --- /dev/null +++ b/next/ext/commands/help.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Generic, Optional, TypedDict, Union, cast + +from typing_extensions import NotRequired + +from .cog import Cog +from .command import Command +from .context import Context +from .group import Group +from .utils import ClientT_Co_D, ClientT_D + +from next import File, Message, Messageable, MessageReply, SendableEmbed + +if TYPE_CHECKING: + from .cog import Cog + +__all__ = ("MessagePayload", "HelpCommand", "DefaultHelpCommand", "help_command_impl") + + +class MessagePayload(TypedDict): + content: str + embed: NotRequired[SendableEmbed] + embeds: NotRequired[list[SendableEmbed]] + attachments: NotRequired[list[File]] + replies: NotRequired[list[MessageReply]] + +class HelpCommand(ABC, Generic[ClientT_Co_D]): + @abstractmethod + async def create_global_help(self, context: Context[ClientT_Co_D], commands: dict[Optional[Cog[ClientT_Co_D]], list[Command[ClientT_Co_D]]]) -> Union[str, SendableEmbed, MessagePayload]: + raise NotImplementedError + + @abstractmethod + async def create_command_help(self, context: Context[ClientT_Co_D], command: Command[ClientT_Co_D]) -> Union[str, SendableEmbed, MessagePayload]: + raise NotImplementedError + + @abstractmethod + async def create_group_help(self, context: Context[ClientT_Co_D], group: Group[ClientT_Co_D]) -> Union[str, SendableEmbed, MessagePayload]: + raise NotImplementedError + + @abstractmethod + async def create_cog_help(self, context: Context[ClientT_Co_D], cog: Cog[ClientT_Co_D]) -> Union[str, SendableEmbed, MessagePayload]: + raise NotImplementedError + + async def send_help_command(self, context: Context[ClientT_Co_D], message_payload: MessagePayload) -> Message: + return await context.send(**message_payload) + + async def filter_commands(self, context: Context[ClientT_Co_D], commands: list[Command[ClientT_Co_D]]) -> list[Command[ClientT_Co_D]]: + filtered: list[Command[ClientT_Co_D]] = [] + + for command in commands: + if command.hidden: + continue + + try: + if await context.can_run(command): + filtered.append(command) + except Exception: + pass + + return filtered + + async def group_commands(self, context: Context[ClientT_Co_D], commands: list[Command[ClientT_Co_D]]) -> dict[Optional[Cog[ClientT_Co_D]], list[Command[ClientT_Co_D]]]: + cogs: dict[Optional[Cog[ClientT_Co_D]], list[Command[ClientT_Co_D]]] = {} + + for command in commands: + cogs.setdefault(command.cog, []).append(command) + + return cogs + + async def handle_message(self, context: Context[ClientT_Co_D], message: Message) -> None: + pass + + async def get_channel(self, context: Context) -> Messageable: + return context + + @abstractmethod + async def handle_no_command_found(self, context: Context[ClientT_Co_D], name: str) -> Union[str, SendableEmbed, MessagePayload]: + raise NotImplementedError + +class DefaultHelpCommand(HelpCommand[ClientT_Co_D]): + def __init__(self, default_cog_name: str = "No Cog"): + self.default_cog_name = default_cog_name + + async def create_global_help(self, context: Context[ClientT_Co_D], commands: dict[Optional[Cog[ClientT_Co_D]], list[Command[ClientT_Co_D]]]) -> Union[str, SendableEmbed, MessagePayload]: + lines = ["```"] + + for cog, cog_commands in commands.items(): + cog_lines: list[str] = [] + cog_lines.append(f"{cog.qualified_name if cog else self.default_cog_name}:") + + for command in cog_commands: + cog_lines.append(f" {command.name} - {command.short_description or 'No description'}") + + lines.append("\n".join(cog_lines)) + + lines.append("```") + return "\n".join(lines) + + async def create_cog_help(self, context: Context[ClientT_Co_D], cog: Cog[ClientT_Co_D]) -> Union[str, SendableEmbed, MessagePayload]: + lines = ["```"] + + lines.append(f"{cog.qualified_name}:") + + for command in cog.commands: + lines.append(f" {command.name} - {command.short_description or 'No description'}") + + lines.append("```") + return "\n".join(lines) + + async def create_command_help(self, context: Context[ClientT_Co_D], command: Command[ClientT_Co_D]) -> Union[str, SendableEmbed, MessagePayload]: + lines = ["```"] + + lines.append(f"{command.name}:") + lines.append(f" Usage: {command.get_usage()}") + + if command.aliases: + lines.append(f" Aliases: {', '.join(command.aliases)}") + + + if command.description: + lines.append(command.description) + + lines.append("```") + return "\n".join(lines) + + async def create_group_help(self, context: Context[ClientT_Co_D], group: Group[ClientT_Co_D]) -> Union[str, SendableEmbed, MessagePayload]: + lines = ["```"] + + lines.append(f"{group.name}:") + lines.append(f" Usage: {group.get_usage()}") + + if group.aliases: + lines.append(f" Aliases: {', '.join(group.aliases)}") + + if group.description: + lines.append(group.description) + + for command in group.commands: + lines.append(f" {command.name} - {command.short_description or 'No description'}") + + lines.append("```") + return "\n".join(lines) + + async def handle_no_command_found(self, context: Context[ClientT_Co_D], name: str) -> str: + return f"Command `{name}` not found." + +class HelpCommandImpl(Command[ClientT_Co_D]): + def __init__(self, client: ClientT_Co_D): + self.client = client + + async def callback(_: Union[ClientT_Co_D, Cog[ClientT_Co_D]], context: Context[ClientT_Co_D], *args: str) -> None: + await help_command_impl(context.client, context, *args) + + super().__init__(callback=callback, name="help", aliases=[]) + self.description: str | None = "Shows help for a command, cog or the entire bot" + + +async def help_command_impl(client: ClientT_D, context: Context[ClientT_D], *arguments: str) -> None: + help_command = client.help_command + + if not help_command: + return + + filtered_commands = await help_command.filter_commands(context, client.commands) + commands = await help_command.group_commands(context, filtered_commands) + + if not arguments: + payload = await help_command.create_global_help(context, commands) + + else: + parent: ClientT_D | Group[ClientT_D] = client + + for param in arguments: + try: + command = parent.get_command(param) + except LookupError: + try: + cog = client.get_cog(param) + except LookupError: + payload = await help_command.handle_no_command_found(context, param) + else: + payload = await help_command.create_cog_help(context, cog) + finally: + break + + if isinstance(command, Group): + command = cast(Group[ClientT_D], command) + parent = command + else: + payload = await help_command.create_command_help(context, command) + break + else: + + if TYPE_CHECKING: + command = cast(Command[ClientT_D], ...) + + if isinstance(command, Group): + payload = await help_command.create_group_help(context, command) + else: + payload = await help_command.create_command_help(context, command) + + if TYPE_CHECKING: + payload = cast(MessagePayload, ...) + + msg_payload: MessagePayload + + if isinstance(payload, str): + msg_payload = {"content": payload} + elif isinstance(payload, SendableEmbed): + msg_payload = {"embed": payload, "content": " "} + else: + msg_payload = payload + + message = await help_command.send_help_command(context, msg_payload) + await help_command.handle_message(context, message) diff --git a/next/ext/commands/utils.py b/next/ext/commands/utils.py new file mode 100644 index 0000000..1d940c7 --- /dev/null +++ b/next/ext/commands/utils.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from inspect import Parameter +from typing import TYPE_CHECKING, Any, Iterable + +from typing_extensions import TypeVar + +if TYPE_CHECKING: + from .client import CommandsClient + from .context import Context + + +__all__ = ("evaluate_parameters",) + +ClientT_Co = TypeVar("ClientT_Co", bound="CommandsClient", covariant=True) +ClientT_D = TypeVar("ClientT_D", bound="CommandsClient", default="CommandsClient") +ClientT_Co_D = TypeVar("ClientT_Co_D", bound="CommandsClient", default="CommandsClient", covariant=True) +ContextT = TypeVar("ContextT", bound="Context", default="Context") + +def evaluate_parameters(parameters: Iterable[Parameter], globals: dict[str, Any]) -> list[Parameter]: + new_parameters: list[Parameter] = [] + + for parameter in parameters: + if parameter.annotation is not parameter.empty: + if isinstance(parameter.annotation, str): + parameter = parameter.replace(annotation=eval(parameter.annotation, globals)) + + new_parameters.append(parameter) + + return new_parameters diff --git a/next/ext/commands/view.py b/next/ext/commands/view.py new file mode 100644 index 0000000..5f732d5 --- /dev/null +++ b/next/ext/commands/view.py @@ -0,0 +1,62 @@ +from typing import Iterator +from typing_extensions import Self + +from .errors import NoClosingQuote + + +class StringView: + def __init__(self, string: str): + self.value: Iterator[str] = iter(string) + self.temp: str = "" + self.should_undo: bool = False + + def undo(self) -> None: + self.should_undo = True + + def next_char(self) -> str: + return next(self.value) + + def get_rest(self) -> str: + if self.should_undo: + return f"{self.temp} {''.join(self.value)}".rstrip() + # prevent a new space appearing at end if the buffer is depleted + + return "".join(self.value) + + def get_next_word(self) -> str: + if self.should_undo: + self.should_undo = False + return self.temp + + char = self.next_char() + temp: list[str] = [] + + while char == " ": + char = self.next_char() + + if char in ["\"", "'"]: + quote = char + try: + while (char := self.next_char()) != quote: + temp.append(char) + except StopIteration: + raise NoClosingQuote + + else: + temp.append(char) + try: + while (char := self.next_char()) not in " \n": + temp.append(char) + except StopIteration: + pass + + output = "".join(temp) + self.temp = output + + return output + + def __iter__(self) -> Self: + return self + + def __next__(self) -> str: + return self.get_next_word() \ No newline at end of file diff --git a/next/file.py b/next/file.py new file mode 100644 index 0000000..03ddd8e --- /dev/null +++ b/next/file.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import io +from typing import Optional, Union, cast + +__all__ = ("File",) + +class File: + """Respresents a file about to be uploaded to next + + Parameters + ----------- + file: Union[str, bytes] + The name of the file or the content of the file in bytes, text files will be need to be encoded + filename: Optional[str] + The filename of the file when being uploaded, this will default to the name of the file if one exists + spoiler: bool + Determines if the file will be a spoiler, this prefexes the filename with `SPOILER_` + """ + __slots__ = ("f", "spoiler", "filename") + + def __init__(self, file: Union[str, bytes], *, filename: Optional[str] = None, spoiler: bool = False): + self.f: io.BufferedIOBase + + if isinstance(file, str): + self.f = open(file, "rb") + else: + self.f = io.BytesIO(file) + + if filename is None and isinstance(file, str): + filename = cast(Optional[str], self.f.name) + + self.spoiler: bool = spoiler or (bool(filename) and filename.startswith("SPOILER_")) + + if self.spoiler and (filename and not filename.startswith("SPOILER_")): + filename = f"SPOILER_{filename}" + + self.filename: str | None = filename diff --git a/next/flags.py b/next/flags.py new file mode 100644 index 0000000..53f8219 --- /dev/null +++ b/next/flags.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +from typing import Callable, Iterator, Optional, Union, overload + +from typing_extensions import Self + +__all__ = ("Flag", "Flags", "UserBadges") + + +class Flag: + __slots__ = ("flag", "__doc__") + + def __init__(self, func: Callable[[], int]): + self.flag: int = func() + self.__doc__: str | None = func.__doc__ + + @overload + def __get__(self: Self, instance: None, owner: type[Flags]) -> Self: + ... + + @overload + def __get__(self, instance: Flags, owner: type[Flags]) -> bool: + ... + + def __get__(self: Self, instance: Optional[Flags], owner: type[Flags]) -> Union[Self, bool]: + if instance is None: + return self + + return instance._check_flag(self.flag) + + def __set__(self, instance: Flags, value: bool) -> None: + instance._set_flag(self.flag, value) + +class Flags: + FLAG_NAMES: list[str] + + def __init_subclass__(cls) -> None: + cls.FLAG_NAMES = [] + + for name in dir(cls): + value = getattr(cls, name) + + if isinstance(value, Flag): + cls.FLAG_NAMES.append(name) + + def __init__(self, value: int = 0, **flags: bool): + self.value = value + + for k, v in flags.items(): + setattr(self, k, v) + + @classmethod + def _from_value(cls, value: int) -> Self: + self = cls.__new__(cls) + self.value = value + return self + + def _check_flag(self, flag: int) -> bool: + return (self.value & flag) == flag + + def _set_flag(self, flag: int, value: bool) -> None: + if value: + self.value |= flag + else: + self.value &= ~flag + + def __eq__(self, other: Self) -> bool: + return self.value == other.value + + def __ne__(self, other: Self) -> bool: + return not self.__eq__(other) + + def __or__(self, other: Self) -> Self: + return self.__class__._from_value(self.value | other.value) + + def __and__(self, other: Self) -> Self: + return self.__class__._from_value(self.value & other.value) + + def __invert__(self) -> Self: + return self.__class__._from_value(~self.value) + + def __add__(self, other: Self) -> Self: + return self | other + + def __sub__(self, other: Self) -> Self: + return self & ~other + + def __lt__(self, other: Self) -> bool: + return self.value < other.value + + def __gt__(self, other: Self) -> bool: + return self.value > other.value + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} value={self.value}>" + + def __iter__(self) -> Iterator[tuple[str, bool]]: + for name, value in self.__class__.__dict__.items(): + if isinstance(value, Flag): + yield name, self._check_flag(value.flag) + + def __hash__(self) -> int: + return hash(self.value) + +class UserBadges(Flags): + """Contains all user badges""" + + @Flag + def developer(): + """:class:`bool` The developer badge.""" + return 1 << 0 + + @Flag + def translator(): + """:class:`bool` The translator badge.""" + return 1 << 1 + + @Flag + def supporter(): + """:class:`bool` The supporter badge.""" + return 1 << 2 + + @Flag + def responsible_disclosure(): + """:class:`bool` The responsible disclosure badge.""" + return 1 << 3 + + @Flag + def founder(): + """:class:`bool` The founder badge.""" + return 1 << 4 + + @Flag + def platform_moderation(): + """:class:`bool` The platform moderation badge.""" + return 1 << 5 + + @Flag + def active_supporter(): + """:class:`bool` The active supporter badge.""" + return 1 << 6 + + @Flag + def bug_hunter(): + """:class:`bool` The bug hunter badge.""" + return 1 << 7 + + @Flag + def early_adopter(): + """:class:`bool` The early adopter badge.""" + return 1 << 8 + + @Flag + def reserved_relevant_joke_badge_1(): + """:class:`bool` The reserved relevant joke badge 1 badge.""" + return 1 << 9 diff --git a/next/http.py b/next/http.py new file mode 100644 index 0000000..e539a75 --- /dev/null +++ b/next/http.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +from typing import (TYPE_CHECKING, Any, Coroutine, Literal, Optional, TypeVar, + Union, overload) + +import aiohttp +import ulid + + +from .errors import Forbidden, HTTPError, ServerError +from .file import File + +try: + import ujson as _json +except ImportError: + import json as _json + +if TYPE_CHECKING: + import aiohttp + + from .enums import SortType + from .file import File + from .types import Autumn as AutumnPayload + from .types import Emoji as EmojiPayload + from .types import Interactions as InteractionsPayload + from .types import Masquerade as MasqueradePayload + from .types import Member as MemberPayload + from .types import Message as MessagePayload + from .types import SendableEmbed as SendableEmbedPayload + from .types import User as UserPayload + from .types import (Server, ServerBans, TextChannel, UserProfile, VoiceChannel, Member, Invite, ApiInfo, Channel, SavedMessages, + DMChannel, EmojiParent, GetServerMembers, GroupDMChannel, MessageReplyPayload, MessageWithUserData, PartialInvite, CreateRole) + +__all__ = ("HttpClient",) + +T = TypeVar("T") +Request = Coroutine[Any, Any, T] + +class HttpClient: + __slots__ = ("session", "token", "api_url", "api_info", "auth_header") + + def __init__(self, session: aiohttp.ClientSession, token: str, api_url: str, api_info: ApiInfo, bot: bool = True): + self.session: aiohttp.ClientSession = session + self.token: str = token + self.api_url: str = api_url + self.api_info: ApiInfo = api_info + self.auth_header: str = "x-bot-token" if bot else "x-session-token" + + async def request(self, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"], route: str, *, json: Optional[dict[str, Any]] = None, nonce: bool = True, params: Optional[dict[str, Any]] = None) -> Any: + url = f"{self.api_url}{route}" + + kwargs = {} + + headers = { + "User-Agent": "Next.py (https://github.com/avanpost200/next.py)", + self.auth_header: self.token + } + + if json: + headers["Content-Type"] = "application/json" + + if nonce: + json["nonce"] = ulid.new().str # type: ignore + + kwargs["data"] = _json.dumps(json) + + kwargs["headers"] = headers + + if params: + kwargs["params"] = params + + async with self.session.request(method, url, **kwargs) as resp: + text = await resp.text() + if text: + try: + response = _json.loads(await resp.text()) + except ValueError: + raise HTTPError(f"Invalid json response:\n{text}") from None + else: + response = text + + resp_code = resp.status + + if 200 <= resp_code <= 300: + return response + elif resp_code == 401: + raise Forbidden("401: Missing Permissions") + else: + raise HTTPError(resp_code) + + async def upload_file(self, file: File, tag: Literal["attachments", "avatars", "backgrounds", "icons", "banners", "emojis"]) -> AutumnPayload: + url = f"{self.api_info['features']['autumn']['url']}/{tag}" + + headers = { + "User-Agent": "Next.py (https://github.com/avanpost200/next.py)" + } + + form = aiohttp.FormData() + form.add_field("file", file.f.read(), filename=file.filename) + + async with self.session.post(url, data=form, headers=headers) as resp: + response: AutumnPayload = _json.loads(await resp.text()) + + resp_code = resp.status + + if resp_code == 400: + raise HTTPError(response) + elif 500 <= resp_code <= 600: + raise ServerError + else: + return response + + async def send_message(self, channel: str, content: Optional[str], embeds: Optional[list[SendableEmbedPayload]], attachments: Optional[list[File]], replies: Optional[list[MessageReplyPayload]], masquerade: Optional[MasqueradePayload], interactions: Optional[InteractionsPayload]) -> MessagePayload: + json: dict[str, Any] = {} + + if content: + json["content"] = content + + if embeds: + json["embeds"] = embeds + + if attachments: + attachment_ids: list[str] = [] + + for attachment in attachments: + data = await self.upload_file(attachment, "attachments") + attachment_ids.append(data["id"]) + + json["attachments"] = attachment_ids + + if replies: + json["replies"] = replies + + if masquerade: + json["masquerade"] = masquerade + + if interactions: + json["interactions"] = interactions + + return await self.request("POST", f"/channels/{channel}/messages", json=json) + + def edit_message(self, channel: str, message: str, content: Optional[str], embeds: Optional[list[SendableEmbedPayload]] = None) -> Request[None]: + json: dict[str, Any] = {} + + if content is not None: + json["content"] = content + + if embeds is not None: + json["embeds"] = embeds + + return self.request("PATCH", f"/channels/{channel}/messages/{message}", json=json) + + def delete_message(self, channel: str, message: str) -> Request[None]: + return self.request("DELETE", f"/channels/{channel}/messages/{message}") + + def fetch_message(self, channel: str, message: str) -> Request[MessagePayload]: + return self.request("GET", f"/channels/{channel}/messages/{message}") + + @overload + def fetch_messages( + self, + channel: str, + sort: SortType, + *, + limit: Optional[int] = ..., + before: Optional[str] = ..., + after: Optional[str] = ..., + nearby: Optional[str] = ..., + include_users: Literal[False] = ... + ) -> Request[list[MessagePayload]]: + ... + + @overload + def fetch_messages( + self, + channel: str, + sort: SortType, + *, + limit: Optional[int] = ..., + before: Optional[str] = ..., + after: Optional[str] = ..., + nearby: Optional[str] = ..., + include_users: Literal[True] = ... + ) -> Request[MessageWithUserData]: + ... + + def fetch_messages( + self, + channel: str, + sort: SortType, + *, + limit: Optional[int] = None, + before: Optional[str] = None, + after: Optional[str] = None, + nearby: Optional[str] = None, + include_users: bool = False + ) -> Request[Union[list[MessagePayload], MessageWithUserData]]: + + json: dict[str, Any] = {"sort": sort.value, "include_users": str(include_users)} + + if limit: + json["limit"] = limit + + if before: + json["before"] = before + + if after: + json["after"] = after + + if nearby: + json["nearby"] = nearby + + return self.request("GET", f"/channels/{channel}/messages", params=json) + + @overload + def search_messages( + self, + channel: str, + query: str, + *, + limit: Optional[int] = ..., + before: Optional[str] = ..., + after: Optional[str] = ..., + sort: Optional[SortType] = ..., + include_users: Literal[False] = ... + ) -> Request[list[MessagePayload]]: + ... + + @overload + def search_messages( + self, + channel: str, + query: str, + *, + limit: Optional[int] = ..., + before: Optional[str] = ..., + after: Optional[str] = ..., + sort: Optional[SortType] = ..., + include_users: Literal[True] = ... + ) -> Request[MessageWithUserData]: + ... + + def search_messages( + self, + channel: str, + query: str, + *, + limit: Optional[int] = None, + before: Optional[str] = None, + after: Optional[str] = None, + sort: Optional[SortType] = None, + include_users: bool = False + ) -> Request[Union[list[MessagePayload], MessageWithUserData]]: + + json: dict[str, Any] = {"query": query, "include_users": include_users} + + if limit: + json["limit"] = limit + + if before: + json["before"] = before + + if after: + json["after"] = after + + if sort: + json["sort"] = sort.value + + return self.request("POST", f"/channels/{channel}/search", json=json) + + async def request_file(self, url: str) -> bytes: + async with self.session.get(url) as resp: + return await resp.content.read() + + def fetch_user(self, user_id: str) -> Request[UserPayload]: + return self.request("GET", f"/users/{user_id}") + + def fetch_profile(self, user_id: str) -> Request[UserProfile]: + return self.request("GET", f"/users/{user_id}/profile") + + def fetch_default_avatar(self, user_id: str) -> Request[bytes]: + return self.request_file(f"{self.api_url}/users/{user_id}/default_avatar") + + def fetch_dm_channels(self) -> Request[list[Union[DMChannel, GroupDMChannel]]]: + return self.request("GET", "/users/dms") + + def open_dm(self, user_id: str) -> Request[DMChannel | SavedMessages]: + return self.request("GET", f"/users/{user_id}/dm") + + def fetch_channel(self, channel_id: str) -> Request[Channel]: + return self.request("GET", f"/channels/{channel_id}") + + def close_channel(self, channel_id: str) -> Request[None]: + return self.request("DELETE", f"/channels/{channel_id}") + + def fetch_server(self, server_id: str) -> Request[Server]: + return self.request("GET", f"/servers/{server_id}") + + def delete_leave_server(self, server_id: str) -> Request[None]: + return self.request("DELETE", f"/servers/{server_id}") + + @overload + def create_channel(self, server_id: str, channel_type: Literal["Text"], name: str, description: Optional[str]) -> Request[TextChannel]: + ... + + @overload + def create_channel(self, server_id: str, channel_type: Literal["Voice"], name: str, description: Optional[str]) -> Request[VoiceChannel]: + ... + + def create_channel(self, server_id: str, channel_type: Literal["Text", "Voice"], name: str, description: Optional[str]) -> Request[Union[TextChannel, VoiceChannel]]: + payload = { + "type": channel_type, + "name": name + } + + if description: + payload["description"] = description + + return self.request("POST", f"/servers/{server_id}/channels", json=payload) + + def fetch_server_invites(self, server_id: str) -> Request[list[PartialInvite]]: + return self.request("GET", f"/servers/{server_id}/invites") + + def fetch_member(self, server_id: str, member_id: str) -> Request[Member]: + return self.request("GET", f"/servers/{server_id}/members/{member_id}") + + def kick_member(self, server_id: str, member_id: str) -> Request[None]: + return self.request("DELETE", f"/servers/{server_id}/members/{member_id}") + + def fetch_members(self, server_id: str) -> Request[GetServerMembers]: + return self.request("GET", f"/servers/{server_id}/members") + + def ban_member(self, server_id: str, member_id: str, reason: Optional[str]) -> Request[GetServerMembers]: + payload = {"reason": reason} if reason else None + + return self.request("PUT", f"/servers/{server_id}/bans/{member_id}", json=payload, nonce=False) + + def unban_member(self, server_id: str, member_id: str) -> Request[None]: + return self.request("DELETE", f"/servers/{server_id}/bans/{member_id}") + + def fetch_bans(self, server_id: str) -> Request[ServerBans]: + return self.request("GET", f"/servers/{server_id}/bans") + + def create_role(self, server_id: str, name: str) -> Request[CreateRole]: + return self.request("POST", f"/servers/{server_id}/roles", json={"name": name}, nonce=False) + + def delete_role(self, server_id: str, role_id: str) -> Request[None]: + return self.request("DELETE", f"/servers/{server_id}/roles/{role_id}") + + def fetch_invite(self, code: str) -> Request[Invite]: + return self.request("GET", f"/invites/{code}") + + def delete_invite(self, code: str) -> Request[None]: + return self.request("DELETE", f"/invites/{code}") + + def edit_channel(self, channel_id: str, remove: list[str] | None, values: dict[str, Any]) -> Request[None]: + if remove: + values["remove"] = remove + + return self.request("PATCH", f"/channels/{channel_id}", json=values) + + def edit_role(self, server_id: str, role_id: str, remove: list[str] | None, values: dict[str, Any]) -> Request[None]: + if remove: + values["remove"] = remove + + return self.request("PATCH", f"/servers/{server_id}/roles/{role_id}", json=values) + + async def edit_self(self, remove: list[str] | None, values: dict[str, Any]) -> Request[None]: + if remove: + values["remove"] = remove + + if avatar := values.get("avatar"): + asset = await self.upload_file(avatar, "avatars") + values["avatar"] = asset["id"] + + if profile := values.get("profile"): + if background := profile.background(): + asset = await self.upload_file(background, "backgrounds") + profile["background"] = asset["id"] + + return await self.request("PATCH", "/users/@me", json=values) + + def set_guild_channel_default_permissions(self, channel_id: str, allow: int, deny: int) -> Request[None]: + return self.request("PUT", f"/channels/{channel_id}/permissions/default", json={"permissions": {"allow": allow, "deny": deny}}) + + def set_guild_channel_role_permissions(self, channel_id: str, role_id: str, allow: int, deny: int) -> Request[None]: + return self.request("PUT", f"/channels/{channel_id}/permissions/{role_id}", json={"permissions": {"allow": allow, "deny": deny}}) + + def set_group_channel_default_permissions(self, channel_id: str, value: int) -> Request[None]: + return self.request("PUT", f"/channels/{channel_id}/permissions/default", json={"permissions": value}) + + def set_server_role_permissions(self, server_id: str, role_id: str, allow: int, deny: int) -> Request[None]: + return self.request("PUT", f"/servers/{server_id}/permissions/{role_id}", json={"permissions": {"allow": allow, "deny": deny}}) + + def set_server_default_permissions(self, server_id: str, value: int) -> Request[None]: + return self.request("PUT", f"/servers/{server_id}/permissions/default", json={"permissions": value}) + + def add_reaction(self, channel_id: str, message_id: str, emoji: str) -> Request[None]: + return self.request("PUT", f"/channels/{channel_id}/messages/{message_id}/reactions/{emoji}") + + def remove_reaction(self, channel_id: str, message_id: str, emoji: str, user_id: Optional[str], remove_all: bool) -> Request[None]: + parameters: dict[str, str] = {} + + if user_id: + parameters["user_id"] = user_id + + parameters["remove_all"] = "true" if remove_all else "false" + + return self.request("DELETE", f"/channels/{channel_id}/messages/{message_id}/reactions/{emoji}", params=parameters) + + def remove_all_reactions(self, channel_id: str, message_id: str) -> Request[None]: + return self.request("DELETE", f"/channels/{channel_id}/messages/{message_id}/reactions") + + def delete_emoji(self, emoji_id: str) -> Request[None]: + return self.request("DELETE", f"/custom/emoji/{emoji_id}") + + def fetch_emoji(self, emoji_id: str) -> Request[EmojiPayload]: + return self.request("GET", f"/custom/emoji/{emoji_id}") + + async def create_emoji(self, name: str, file: File, nsfw: bool, parent: EmojiParent) -> EmojiPayload: + asset = await self.upload_file(file, "emojis") + + return await self.request("PUT", f"/custom/emoji/{asset['id']}", json={"name": name, "parent": parent, "nsfw": nsfw}) + + def edit_member(self, server_id: str, member_id: str, remove: list[str] | None, values: dict[str, Any]) -> Request[MemberPayload]: + if remove: + values["remove"] = remove + + return self.request("PATCH", f"/servers/{server_id}/members/{member_id}", json=values) + + def delete_messages(self, channel_id: str, messages: list[str]) -> Request[None]: + return self.request("DELETE", f"/channels/{channel_id}/messages/bulk", json={"ids": messages}) diff --git a/next/invite.py b/next/invite.py new file mode 100644 index 0000000..4e55f48 --- /dev/null +++ b/next/invite.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + + +from .asset import Asset +from .utils import Ulid + +if TYPE_CHECKING: + from .state import State + from .channel import Channel + from .server import Server + from .types import Invite as InvitePayload + from .user import User + + +__all__ = ("Invite",) + +class Invite(Ulid): + """Represents a server invite. + + Attributes + ----------- + code: :class:`str` + The code for the invite + id: :class:`str` + Alias for :attr:`code` + server: :class:`Server` + The server this invite is for + channel: :class:`Channel` + The channel this invite is for + user_name: :class:`str` + The name of the user who made the invite + user: Optional[:class:`User`] + The user who made the invite, this is only set if this was fetched via :meth:`Server.fetch_invites` + user_avatar: Optional[:class:`Asset`] + The invite creator's avatar, if any + member_count: :class:`int` + The member count of the server this invite is for + """ + + __slots__ = ("state", "code", "id", "server", "channel", "user_name", "user_avatar", "user", "member_count") + + def __init__(self, data: InvitePayload, code: str, state: State): + self.state: State = state + + self.code: str = code + self.id: str = code + self.server: Server = state.get_server(data["server_id"]) + self.channel: Channel = self.server.get_channel(data["channel_id"]) + + self.user_name: str = data["user_name"] + self.user: User | None = None + + self.user_avatar: Asset | None + + if avatar := data.get("user_avatar"): + self.user_avatar = Asset(avatar, state) + else: + self.user_avatar = None + + self.member_count: int = data["member_count"] + + @staticmethod + def _from_partial(code: str, server: str, creator: str, channel: str, state: State) -> Invite: + invite = Invite.__new__(Invite) + + invite.state = state + invite.code = code + invite.server = state.get_server(server) + invite.channel = state.get_channel(channel) + invite.user = state.get_user(creator) + invite.user_name = invite.user.name + invite.user_avatar = invite.user.avatar + invite.member_count = len(invite.server.members) + + return invite + + async def delete(self) -> None: + """Deletes the invite""" + await self.state.http.delete_invite(self.code) diff --git a/next/member.py b/next/member.py new file mode 100644 index 0000000..533c494 --- /dev/null +++ b/next/member.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any, Optional + + +from .utils import _Missing, Missing, parse_timestamp + +from .asset import Asset +from .permissions import Permissions +from .permissions_calculator import calculate_permissions +from .user import User +from .file import File + +if TYPE_CHECKING: + from .channel import Channel + from .server import Server + from .state import State + from .types import File as FilePayload + from .types import Member as MemberPayload + from .role import Role + +__all__ = ("Member",) + +def flattern_user(member: Member, user: User) -> None: + for attr in user.__flattern_attributes__: + setattr(member, attr, getattr(user, attr)) + +class Member(User): + """Represents a member of a server, subclasses :class:`User` + + Attributes + ----------- + nickname: Optional[:class:`str`] + The nickname of the member if any + roles: list[:class:`Role`] + The roles of the member, ordered by the role's rank in decending order + server: :class:`Server` + The server the member belongs to + guild_avatar: Optional[:class:`Asset`] + The member's guild avatar if any + """ + __slots__ = ("state", "nickname", "roles", "server", "guild_avatar", "joined_at", "current_timeout") + + def __init__(self, data: MemberPayload, server: Server, state: State): + user = state.get_user(data["_id"]["user"]) + + # due to not having a user payload and only a user object we have to manually add all the attributes instead of calling User.__init__ + flattern_user(self, user) + user._members[server.id] = self + + self.state: State = state + + self.guild_avatar: Asset | None + + if avatar := data.get("avatar"): + self.guild_avatar = Asset(avatar, state) + else: + self.guild_avatar = None + + roles = [server.get_role(role_id) for role_id in data.get("roles", [])] + self.roles: list[Role] = sorted(roles, key=lambda role: role.rank, reverse=True) + + self.server: Server = server + self.nickname: str | None = data.get("nickname") + self.joined_at: datetime.datetime = parse_timestamp(data["joined_at"]) + + self.current_timeout: datetime.datetime | None + + if current_timeout := data.get("timeout"): + self.current_timeout = parse_timestamp(current_timeout) + else: + self.current_timeout = None + + @property + def avatar(self) -> Optional[Asset]: + """Optional[:class:`Asset`] The avatar the member is displaying, this includes guild avatars and masqueraded avatar""" + return self.masquerade_avatar or self.guild_avatar or self.original_avatar + + @property + def name(self) -> str: + """:class:`str` The name the user is displaying, this includes (in order) their masqueraded name, display name and orginal name""" + return self.nickname or self.display_name or self.masquerade_name or self.original_name + + @property + def mention(self) -> str: + """:class:`str`: Returns a string that allows you to mention the given member.""" + return f"<@{self.id}>" + + def _update( + self, + *, + nickname: Optional[str] = None, + avatar: Optional[FilePayload] = None, + roles: Optional[list[str]] = None, + timeout: Optional[str | int] = None + ) -> None: + if nickname is not None: + self.nickname = nickname + + if avatar is not None: + self.guild_avatar = Asset(avatar, self.state) + + if roles is not None: + member_roles = [self.server.get_role(role_id) for role_id in roles] + self.roles = sorted(member_roles, key=lambda role: role.rank, reverse=True) + + if timeout is not None: + self.current_timeout = parse_timestamp(timeout) + + async def kick(self) -> None: + """Kicks the member from the server""" + await self.state.http.kick_member(self.server.id, self.id) + + async def ban(self, *, reason: Optional[str] = None) -> None: + """Bans the member from the server + + Parameters + ----------- + reason: Optional[:class:`str`] + The reason for the ban + """ + await self.state.http.ban_member(self.server.id, self.id, reason) + + async def unban(self) -> None: + """Unbans the member from the server""" + await self.state.http.unban_member(self.server.id, self.id) + + async def edit( + self, + *, + nickname: str | None | _Missing = Missing, + roles: list[Role] | None | _Missing = Missing, + avatar: File | None | _Missing = Missing, + timeout: datetime.timedelta | None | _Missing = Missing + ) -> None: + """Edits the member + + Parameters + ----------- + nickname: Union[:class:`str`, :class:`None`] + The new nickname, or :class:`None` to reset it + roles: Union[list[:class:`Role`], :class:`None`] + The new roles for the member, or :class:`None` to clear it + avatar: Union[:class:`File`, :class:`None`] + The new server avatar, or :class:`None` to reset it + timeout: Union[:class:`datetime.timedelta`, :class:`None`] + The new timeout length for the member, or :class:`None` to reset it + """ + remove: list[str] = [] + data: dict[str, Any] = {} + + if nickname is None: + remove.append("Nickname") + elif nickname is not Missing: + data["nickname"] = nickname + + if roles is None: + remove.append("Roles") + elif not isinstance(roles, _Missing): + data["roles"] = [role.id for role in roles] + + if avatar is None: + remove.append("Avatar") + elif not isinstance(avatar, _Missing): + data["avatar"] = (await self.state.http.upload_file(avatar, "avatars"))["id"] + + if timeout is None: + remove.append("Timeout") + elif not isinstance(timeout, _Missing): + data["timeout"] = (datetime.datetime.now(datetime.timezone.utc) + timeout).isoformat() + + await self.state.http.edit_member(self.server.id, self.id, remove, data) + + async def timeout(self, length: datetime.timedelta) -> None: + """Timeouts the member + + Parameters + ----------- + length: :class:`datetime.timedelta` + The length of the timeout + """ + ends_at = datetime.datetime.now(tz=datetime.timezone.utc) + length + + await self.state.http.edit_member(self.server.id, self.id, None, {"timeout": ends_at.isoformat()}) + + def get_permissions(self) -> Permissions: + """Gets the permissions for the member in the server + + Returns + -------- + :class:`Permissions` + The members permissions + """ + return calculate_permissions(self, self.server) + + def get_channel_permissions(self, channel: Channel) -> Permissions: + """Gets the permissions for the member in the server taking into account the channel as well + + Parameters + ----------- + channel: :class:`Channel` + The channel to calculate permissions with + + Returns + -------- + :class:`Permissions` + The members permissions + """ + return calculate_permissions(self, channel) + + def has_permissions(self, **permissions: bool) -> bool: + """Computes if the member has the specified permissions + + Parameters + ----------- + permissions: :class:`bool` + The permissions to check, this also accepted `False` if you need to check if the member does not have the permission + + Returns + -------- + :class:`bool` + Whether or not they have the permissions + """ + calculated_perms = self.get_permissions() + + return all([getattr(calculated_perms, key, False) == value for key, value in permissions.items()]) + + def has_channel_permissions(self, channel: Channel, **permissions: bool) -> bool: + """Computes if the member has the specified permissions, taking into account the channel as well + + Parameters + ----------- + channel: :class:`Channel` + The channel to calculate permissions with + permissions: :class:`bool` + The permissions to check, this also accepted `False` if you need to check if the member does not have the permission + + Returns + -------- + :class:`bool` + Whether or not they have the permissions + """ + calculated_perms = self.get_channel_permissions(channel) + + return all([getattr(calculated_perms, key, False) == value for key, value in permissions.items()]) diff --git a/next/message.py b/next/message.py new file mode 100644 index 0000000..30d904a --- /dev/null +++ b/next/message.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import datetime +from typing import TYPE_CHECKING, Any, Coroutine, Optional, Union + + +from .asset import Asset, PartialAsset +from .channel import DMChannel, GroupDMChannel, TextChannel, SavedMessageChannel +from .embed import Embed, SendableEmbed, to_embed +from .utils import Ulid, parse_timestamp + +if TYPE_CHECKING: + from .server import Server + from .state import State + from .types import Embed as EmbedPayload + from .types import Interactions as InteractionsPayload + from .types import Masquerade as MasqueradePayload + from .types import Message as MessagePayload + from .types import MessageReplyPayload, SystemMessageContent + from .user import User + from .member import Member + +__all__ = ( + "Message", + "MessageReply", + "Masquerade", + "MessageInteractions" +) + +class Message(Ulid): + """Represents a message + + Attributes + ----------- + id: :class:`str` + The id of the message + content: :class:`str` + The content of the message, this will not include system message's content + attachments: list[:class:`Asset`] + The attachments of the message + embeds: list[Union[:class:`WebsiteEmbed`, :class:`ImageEmbed`, :class:`TextEmbed`, :class:`NoneEmbed`]] + The embeds of the message + channel: :class:`Messageable` + The channel the message was sent in + author: Union[:class:`Member`, :class:`User`] + The author of the message, will be :class:`User` in DMs + edited_at: Optional[:class:`datetime.datetime`] + The time at which the message was edited, will be None if the message has not been edited + raw_mentions: list[:class:`str`] + A list of ids of the mentions in this message + replies: list[:class:`Message`] + The message's this message has replied to, this may not contain all the messages if they are outside the cache + reply_ids: list[:class:`str`] + The message's ids this message has replies to + reactions: dict[str, list[:class:`User`]] + The reactions on the message + interactions: Optional[:class:`MessageInteractions`] + The interactions on the message, if any + """ + __slots__ = ("state", "id", "content", "attachments", "embeds", "channel", "author", "edited_at", "replies", "reply_ids", "reactions", "interactions") + + def __init__(self, data: MessagePayload, state: State): + self.state: State = state + + self.id: str = data["_id"] + self.content: str = data.get("content", "") + + self.system_content: SystemMessageContent | None = data.get("system") + + self.attachments: list[Asset] = [Asset(attachment, state) for attachment in data.get("attachments", [])] + self.embeds: list[Embed] = [to_embed(embed, state) for embed in data.get("embeds", [])] + + channel = state.get_channel(data["channel"]) + assert isinstance(channel, (TextChannel, GroupDMChannel, DMChannel, SavedMessageChannel)) + self.channel: TextChannel | GroupDMChannel | DMChannel | SavedMessageChannel = channel + + self.server_id: str | None = self.channel.server_id + + self.raw_mentions: list[str] = data.get("mentions", []) + + if self.system_content: + author_id: str = self.system_content.get("id", data["author"]) + else: + author_id = data["author"] + + if self.server_id: + author = state.get_member(self.server_id, author_id) + + else: + author = state.get_user(author_id) + + self.author: Member | User = author + + if masquerade := data.get("masquerade"): + if name := masquerade.get("name"): + self.author.masquerade_name = name + + if avatar := masquerade.get("avatar"): + self.author.masquerade_avatar = PartialAsset(avatar, state) + + if edited_at := data.get("edited"): + self.edited_at: Optional[datetime.datetime] = parse_timestamp(edited_at) + + self.replies: list[Message] = [] + self.reply_ids: list[str] = [] + + for reply in data.get("replies", []): + try: + message = state.get_message(reply) + self.replies.append(message) + except LookupError: + pass + + self.reply_ids.append(reply) + + reactions = data.get("reactions", {}) + + self.reactions: dict[str, list[User]] = {} + + for emoji, users in reactions.items(): + self.reactions[emoji] = [self.state.get_user(user_id) for user_id in users] + + self.interactions: MessageInteractions | None + + if interactions := data.get("interactions"): + self.interactions = MessageInteractions(reactions=interactions.get("reactions"), restrict_reactions=interactions.get("restrict_reactions", False)) + else: + self.interactions = None + + def _update(self, *, content: Optional[str] = None, embeds: Optional[list[EmbedPayload]] = None, edited: Optional[Union[str, int]] = None): + if content is not None: + self.content = content + + if embeds is not None: + self.embeds = [to_embed(embed, self.state) for embed in embeds] + + if edited is not None: + self.edited_at = parse_timestamp(edited) + + @property + def mentions(self) -> list[User | Member]: + """The users or members that where mentioned in the message + + Returns: list[Union[:class:`Member`, :class:`User`]] + """ + + mentions: list[User | Member] = [] + + if self.server_id: + for mention in self.raw_mentions: + try: + self.mentions.append(self.server.get_member(mention)) + except LookupError: + pass + + else: + for mention in self.raw_mentions: + try: + self.mentions.append(self.state.get_user(mention)) + except LookupError: + pass + + return mentions + + async def edit(self, *, content: Optional[str] = None, embeds: Optional[list[SendableEmbed]] = None) -> None: + """Edits the message. The bot can only edit its own message + + Parameters + ----------- + content: :class:`str` + The new content of the message + embeds: list[:class:`SendableEmbed`] + The new embeds of the message + """ + + new_embeds = [embed.to_dict() for embed in embeds] if embeds else None + + await self.state.http.edit_message(self.channel.id, self.id, content, new_embeds) + + async def delete(self) -> None: + """Deletes the message. The bot can only delete its own messages and messages it has permission to delete """ + await self.state.http.delete_message(self.channel.id, self.id) + + def reply(self, *args: Any, mention: bool = False, **kwargs: Any) -> Coroutine[Any, Any, Message]: + """Replies to this message, equivilant to: + + .. code-block:: python + + await channel.send(..., replies=[MessageReply(message, mention)]) + + """ + return self.channel.send(*args, **kwargs, replies=[MessageReply(self, mention)]) + + async def add_reaction(self, emoji: str) -> None: + """Adds a reaction to the message + + Parameters + ----------- + emoji: :class:`str` + The emoji to add as a reaction + """ + await self.state.http.add_reaction(self.channel.id, self.id, emoji) + + async def remove_reaction(self, emoji: str, user: Optional[User] = None, remove_all: bool = False) -> None: + """Removes a reaction from the message, this can remove either a specific users, the current users reaction or all of a specific emoji + + Parameters + ----------- + emoji: :class:`str` + The emoji to remove + user: Optional[:class:`User`] + The user to use for removing a reaction from + remove_all: bool + Whether or not to remove all reactions for that specific emoji + """ + await self.state.http.remove_reaction(self.channel.id, self.id, emoji, user.id if user else None, remove_all) + + async def remove_all_reactions(self) -> None: + """Removes all reactions from the message""" + await self.state.http.remove_all_reactions(self.channel.id, self.id) + + @property + def server(self) -> Server: + """:class:`Server` The server this voice channel belongs too + + Raises + ------- + :class:`LookupError` + Raises if the channel is not part of a server + """ + return self.channel.server + +class MessageReply: + """represents a reply to a message. + + Parameters + ----------- + message: :class:`Message` + The message being replied to. + mention: :class:`bool` + Whether the reply should mention the author of the message. Defaults to false. + """ + __slots__ = ("message", "mention") + + def __init__(self, message: Ulid, mention: bool = False): + self.message: Ulid = message + self.mention: bool = mention + + def to_dict(self) -> MessageReplyPayload: + return {"id": self.message.id, "mention": self.mention} + +class Masquerade: + """represents a message's masquerade. + + Parameters + ----------- + name: Optional[:class:`str`] + The name to display for the message + avatar: Optional[:class:`str`] + The avatar's url to display for the message + colour: Optional[:class:`str`] + The colour of the name, similar to role colours + """ + __slots__ = ("name", "avatar", "colour") + + def __init__(self, name: Optional[str] = None, avatar: Optional[str] = None, colour: Optional[str] = None): + self.name: str | None = name + self.avatar: str | None = avatar + self.colour: str | None = colour + + def to_dict(self) -> MasqueradePayload: + output: MasqueradePayload = {} + + if name := self.name: + output["name"] = name + + if avatar := self.avatar: + output["avatar"] = avatar + + if colour := self.colour: + output["colour"] = colour + + return output + +class MessageInteractions: + """Represents a message's interactions, this is for allowing preset reactions and restricting adding reactions to only those. + + Parameters + ----------- + reactions: Optional[list[:class:`str`]] + The preset reactions on the message + restrict_reactions: bool + Whether or not users can only react to the interaction's reactions + """ + __slots__ = ("reactions", "restrict_reactions") + + def __init__(self, *, reactions: Optional[list[str]] = None, restrict_reactions: bool = False): + self.reactions: list[str] | None = reactions + self.restrict_reactions: bool = restrict_reactions + + def to_dict(self) -> InteractionsPayload: + output: InteractionsPayload = {} + + if reactions := self.reactions: + output["reactions"] = reactions + + if restrict_reactions := self.restrict_reactions: + output["restrict_reactions"] = restrict_reactions + + return output diff --git a/next/messageable.py b/next/messageable.py new file mode 100644 index 0000000..c16e10c --- /dev/null +++ b/next/messageable.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from .enums import SortType + +if TYPE_CHECKING: + from .embed import SendableEmbed + from .file import File + from .message import Masquerade, Message, MessageInteractions, MessageReply + from .state import State + from .types.http import MessageWithUserData + + +__all__ = ("Messageable",) + +class Messageable: + """Base class for all channels that you can send messages in + + Attributes + ----------- + id: :class:`str` + The id of the channel + """ + state: State + + __slots__ = () + + async def _get_channel_id(self) -> str: + raise NotImplementedError + + async def send(self, content: Optional[str] = None, *, embeds: Optional[list[SendableEmbed]] = None, embed: Optional[SendableEmbed] = None, attachments: Optional[list[File]] = None, replies: Optional[list[MessageReply]] = None, reply: Optional[MessageReply] = None, masquerade: Optional[Masquerade] = None, interactions: Optional[MessageInteractions] = None) -> Message: + """Sends a message in a channel, you must send at least one of either `content`, `embeds` or `attachments` + + Parameters + ----------- + content: Optional[:class:`str`] + The content of the message, this will not include system message's content + attachments: Optional[list[:class:`File`]] + The attachments of the message + embed: Optional[:class:`SendableEmbed`] + The embed to send with the message + embeds: Optional[list[:class:`SendableEmbed`]] + The embeds to send with the message + replies: Optional[list[:class:`MessageReply`]] + The list of messages to reply to. + masquerade: Optional[:class:`Masquerade`] + The masquerade for the message, this can overwrite the username and avatar shown + interactions: Optional[:class:`MessageInteractions`] + The interactions for the message + + Returns + -------- + :class:`Message` + The message that was just sent + """ + if embed: + embeds = [embed] + + if reply: + replies = [reply] + + embed_payload = [embed.to_dict() for embed in embeds] if embeds else None + reply_payload = [reply.to_dict() for reply in replies] if replies else None + masquerade_payload = masquerade.to_dict() if masquerade else None + interactions_payload = interactions.to_dict() if interactions else None + + message = await self.state.http.send_message(await self._get_channel_id(), content, embed_payload, attachments, reply_payload, masquerade_payload, interactions_payload) + return self.state.add_message(message) + + + async def fetch_message(self, message_id: str) -> Message: + """Fetches a message from the channel + + Parameters + ----------- + message_id: :class:`str` + The id of the message you want to fetch + + Returns + -------- + :class:`Message` + The message with the matching id + """ + from .message import Message + + payload = await self.state.http.fetch_message(await self._get_channel_id(), message_id) + return Message(payload, self.state) + + def _add_missing_users(self, payload: MessageWithUserData): + for user in payload["users"]: + if user["_id"] not in self.state.users: + self.state.add_user(user) + + if members := payload.get("members", []): + server = self.state.get_server(members[0]["_id"]["server"]) + + for member in members: + if member["_id"]["user"] not in server._members: + server._add_member(member) + + async def history(self, *, sort: SortType = SortType.latest, limit: int = 100, before: Optional[str] = None, after: Optional[str] = None, nearby: Optional[str] = None) -> list[Message]: + """Fetches multiple messages from the channel's history + + Parameters + ----------- + sort: :class:`SortType` + The order to sort the messages in + limit: :class:`int` + How many messages to fetch + before: Optional[:class:`str`] + The id of the message which should come *before* all the messages to be fetched + after: Optional[:class:`str`] + The id of the message which should come *after* all the messages to be fetched + nearby: Optional[:class:`str`] + The id of the message which should be nearby all the messages to be fetched + + Returns + -------- + list[:class:`Message`] + The messages found in order of the sort parameter + """ + from .message import Message + + payload = await self.state.http.fetch_messages(await self._get_channel_id(), sort=sort, limit=limit, before=before, after=after, nearby=nearby, include_users=True) + self._add_missing_users(payload) + + return [Message(msg, self.state) for msg in payload["messages"]] + + async def search(self, query: str, *, sort: SortType = SortType.latest, limit: int = 100, before: Optional[str] = None, after: Optional[str] = None) -> list[Message]: + """searches the channel for a query + + Parameters + ----------- + query: :class:`str` + The query to search for in the channel + sort: :class:`SortType` + The order to sort the messages in + limit: :class:`int` + How many messages to fetch + before: Optional[:class:`str`] + The id of the message which should come *before* all the messages to be fetched + after: Optional[:class:`str`] + The id of the message which should come *after* all the messages to be fetched + + Returns + -------- + list[:class:`Message`] + The messages found in order of the sort parameter + """ + from .message import Message + + payload = await self.state.http.search_messages(await self._get_channel_id(), query, sort=sort, limit=limit, before=before, after=after, include_users=True) + self._add_missing_users(payload) + + return [Message(msg, self.state) for msg in payload["messages"]] + + async def delete_messages(self, messages: list[Message]) -> None: + """Bulk deletes messages from the channel + + .. note:: The messages must have been sent in the last 7 days. + + Parameters + ----------- + messages: list[:class:`Message`] + The messages for deletion, this can be up to 100 messages + """ + + await self.state.http.delete_messages(await self._get_channel_id(), [message.id for message in messages]) diff --git a/next/permissions.py b/next/permissions.py new file mode 100644 index 0000000..df4bfb1 --- /dev/null +++ b/next/permissions.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +from typing_extensions import Self + +from .flags import Flag, Flags +from .types.permissions import Overwrite + +__all__ = ("Permissions", "PermissionsOverwrite", "UserPermissions") + +class UserPermissions(Flags): + """Permissions for users""" + + @Flag + def access() -> int: + return 1 << 0 + + @Flag + def view_profile() -> int: + return 1 << 1 + + @Flag + def send_message() -> int: + return 1 << 2 + + @Flag + def invite() -> int: + return 1 << 3 + + @classmethod + def all(cls) -> Self: + return cls(access=True, view_profile=True, send_message=True, invite=True) + +class Permissions(Flags): + """Server permissions for members and roles""" + + @Flag + def manage_channel() -> int: + return 1 << 0 + + @Flag + def manage_server() -> int: + return 1 << 1 + + @Flag + def manage_permissions() -> int: + return 1 << 2 + + @Flag + def manage_role() -> int: + return 1 << 3 + + @Flag + def kick_members() -> int: + return 1 << 6 + + @Flag + def ban_members() -> int: + return 1 << 7 + + @Flag + def timeout_members() -> int: + return 1 << 8 + + @Flag + def asign_roles() -> int: + return 1 << 9 + + @Flag + def change_nickname() -> int: + return 1 << 10 + + @Flag + def manage_nicknames() -> int: + return 1 << 11 + + @Flag + def change_avatars() -> int: + return 1 << 12 + + @Flag + def remove_avatars() -> int: + return 1 << 13 + + @Flag + def view_channel() -> int: + return 1 << 20 + + @Flag + def read_message_history() -> int: + return 1 << 21 + + @Flag + def send_messages() -> int: + return 1 << 22 + + @Flag + def manage_messages() -> int: + return 1 << 23 + + @Flag + def manage_webhooks() -> int: + return 1 << 24 + + @Flag + def invite_others() -> int: + return 1 << 25 + + @Flag + def send_embeds() -> int: + return 1 << 26 + + @Flag + def upload_files() -> int: + return 1 << 27 + + @Flag + def masquerade() -> int: + return 1 << 28 + + @Flag + def connect() -> int: + return 1 << 30 + + @Flag + def speak() -> int: + return 1 << 31 + + @Flag + def video() -> int: + return 1 << 32 + + @Flag + def mute_members() -> int: + return 1 << 33 + + @Flag + def deafen_members() -> int: + return 1 << 34 + + @Flag + def move_members() -> int: + return 1 << 35 + + @classmethod + def all(cls) -> Self: + return cls(0x000F_FFFF_FFFF_FFFF) + + @classmethod + def default_view_only(cls) -> Self: + return cls(view_channel=True, read_message_history=True) + + @classmethod + def default(cls) -> Self: + return cls.default_view_only() | cls(send_messages=True, invite_others=True, send_embeds=True, upload_files=True, connect=True, speak=True) + + @classmethod + def default_direct_message(cls) -> Self: + return cls.default_view_only() | cls(react=True, manage_channel=True) + +class PermissionsOverwrite: + """A permissions overwrite in a channel""" + + def __init__(self, allow: Permissions, deny: Permissions): + self._allow = allow + self._deny = deny + + for perm in Permissions.FLAG_NAMES: + if getattr(allow, perm): + value = True + elif getattr(deny, perm): + value = False + else: + value = None + + super().__setattr__(perm, value) + + def __setattr__(self, key: str, value: Any) -> None: + if key in Permissions.FLAG_NAMES: + if key is True: + setattr(self._allow, key, True) + super().__setattr__(key, True) + + elif key is False: + setattr(self._deny, key, True) + super().__setattr__(key, False) + + else: + setattr(self._allow, key, False) + setattr(self._deny, key, False) + super().__setattr__(key, None) + else: + super().__setattr__(key, value) + + if TYPE_CHECKING: + manage_channel: Optional[bool] + manage_server: Optional[bool] + manage_permissions: Optional[bool] + manage_role: Optional[bool] + kick_members: Optional[bool] + ban_members: Optional[bool] + timeout_members: Optional[bool] + asign_roles: Optional[bool] + change_nickname: Optional[bool] + manage_nicknames: Optional[bool] + change_avatars: Optional[bool] + remove_avatars: Optional[bool] + view_channel: Optional[bool] + read_message_history: Optional[bool] + send_messages: Optional[bool] + manage_messages: Optional[bool] + manage_webhooks: Optional[bool] + invite_others: Optional[bool] + send_embeds: Optional[bool] + upload_files: Optional[bool] + masquerade: Optional[bool] + connect: Optional[bool] + speak: Optional[bool] + video: Optional[bool] + mute_members: Optional[bool] + deafen_members: Optional[bool] + move_members: Optional[bool] + + def to_pair(self) -> tuple[Permissions, Permissions]: + return self._allow, self._deny + + @classmethod + def _from_overwrite(cls, overwrite: Overwrite) -> Self: + allow = Permissions(overwrite["a"]) + deny = Permissions(overwrite["d"]) + + return cls(allow, deny) diff --git a/next/permissions_calculator.py b/next/permissions_calculator.py new file mode 100644 index 0000000..71a959a --- /dev/null +++ b/next/permissions_calculator.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, cast + +from next.enums import ChannelType + +from .permissions import Permissions + +if TYPE_CHECKING: + from .channel import Channel, DMChannel, GroupDMChannel, ServerChannel + from .member import Member + from .server import Server + + +def calculate_permissions(member: Member, target: Server | Channel) -> Permissions: + if member.privileged: + return Permissions.all() + + from .server import Server + + if isinstance(target, Server): + if target.owner_id == member.id: + return Permissions.all() + + permissions = target.default_permissions + + for role in member.roles: + permissions = (permissions | role.permissions._allow) & (~role.permissions._deny) + + if member.current_timeout and member.current_timeout > datetime.now(): + permissions = permissions & Permissions.default_view_only() + + return permissions + + else: + channel_type = target.channel_type + + if channel_type is ChannelType.saved_messages: + return Permissions.all() + + elif channel_type is ChannelType.direct_message: + target = cast("DMChannel", target) + + user_permissions = target.recipient.get_permissions() + + if user_permissions.send_message: + return Permissions.default_direct_message() + + else: + return Permissions.default_view_only() + + elif channel_type is ChannelType.group: + target = cast("GroupDMChannel", target) + + if target.owner.id != member.id: + return Permissions.default_direct_message() + else: + if target.permissions.value == 0: + return Permissions.default_direct_message() + else: + return target.permissions + + else: + target = cast("ServerChannel", target) + server = target.server + + if server.owner_id == member.id: + return Permissions.all() + + else: + perms = calculate_permissions(member, server) + perms = (perms | target.default_permissions._allow) & (~target.default_permissions._deny) + + for role in server.roles[::-1]: + if overwrite :=target.permissions.get(role.id): + perms = (perms | overwrite._allow) & (~overwrite._deny) + + if member.current_timeout and member.current_timeout > datetime.now(): + perms = perms & Permissions(view_channel=True, read_message_history=True) + + return perms diff --git a/next/py.typed b/next/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/next/role.py b/next/role.py new file mode 100644 index 0000000..63e52b2 --- /dev/null +++ b/next/role.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +from .permissions import Overwrite, PermissionsOverwrite +from .utils import Missing, Ulid + +if TYPE_CHECKING: + from .server import Server + from .state import State + from .types import Role as RolePayload + + +__all__ = ("Role",) + +class Role(Ulid): + """Represents a role + + Attributes + ----------- + id: :class:`str` + The id of the role + name: :class:`str` + The name of the role + colour: Optional[:class:`str`] + The colour of the role + hoist: :class:`bool` + Whether members with the role will display seperate from everyone else + rank: :class:`int` + The position of the role in the role heirarchy + server: :class:`Server` + The server the role belongs to + server_permissions: :class:`ServerPermissions` + The server permissions for the role + channel_permissions: :class:`ChannelPermissions` + The channel permissions for the role + """ + __slots__: tuple[str, ...] = ("id", "name", "colour", "hoist", "rank", "state", "server", "permissions") + + def __init__(self, data: RolePayload, role_id: str, server: Server, state: State): + self.state: State = state + self.id: str = role_id + self.name: str = data["name"] + self.colour: str | None = data.get("colour", None) + self.hoist: bool = data.get("hoist", False) + self.rank: int = data["rank"] + self.server: Server = server + self.permissions: PermissionsOverwrite = PermissionsOverwrite._from_overwrite(data.get("permissions", {"a": 0, "d": 0})) + + @property + def color(self) -> str | None: + return self.colour + + async def set_permissions_overwrite(self, *, permissions: PermissionsOverwrite) -> None: + """Sets the permissions for a role in a server. + Parameters + ----------- + server_permissions: Optional[:class:`ServerPermissions`] + The new server permissions for the role + channel_permissions: Optional[:class:`ChannelPermissions`] + The new channel permissions for the role + """ + allow, deny = permissions.to_pair() + await self.state.http.set_server_role_permissions(self.server.id, self.id, allow.value, deny.value) + + def _update(self, *, name: Optional[str] = None, colour: Optional[str] = None, hoist: Optional[bool] = None, rank: Optional[int] = None, permissions: Optional[Overwrite] = None) -> None: + if name is not None: + self.name = name + + if colour is not None: + self.colour = colour + + if hoist is not None: + self.hoist = hoist + + if rank is not None: + self.rank = rank + + if permissions is not None: + self.permissions = PermissionsOverwrite._from_overwrite(permissions) + + async def delete(self) -> None: + """Deletes the role""" + await self.state.http.delete_role(self.server.id, self.id) + + async def edit(self, **kwargs: Any) -> None: + """Edits the role + + Parameters + ----------- + name: str + The name of the role + colour: str + The colour of the role + hoist: bool + Whether the role should make the member display seperately in the member list + rank: int + The position of the role + """ + if kwargs.get("colour", Missing) is None: + remove = ["Colour"] + else: + remove = None + + await self.state.http.edit_role(self.server.id, self.id, remove, kwargs) diff --git a/next/server.py b/next/server.py new file mode 100644 index 0000000..bd4d218 --- /dev/null +++ b/next/server.py @@ -0,0 +1,472 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, cast + +from .asset import Asset +from .category import Category +from .invite import Invite +from .permissions import Permissions +from .role import Role +from .utils import Ulid +from .channel import Channel, TextChannel, VoiceChannel +from .member import Member + +if TYPE_CHECKING: + from .emoji import Emoji + from .file import File + from .state import State + from .types import Ban + from .types import Category as CategoryPayload + from .types import File as FilePayload + from .types import Server as ServerPayload + from .types import SystemMessagesConfig + from .types import Member as MemberPayload + +__all__ = ("Server", "SystemMessages", "ServerBan") + +class SystemMessages: + """Holds all the configuration for the server's system message channels""" + + def __init__(self, data: SystemMessagesConfig, state: State): + self.state: State = state + self.user_joined_id: str | None = data.get("user_joined") + self.user_left_id: str | None = data.get("user_left") + self.user_kicked_id: str | None = data.get("user_kicked") + self.user_banned_id: str | None = data.get("user_banned") + + @property + def user_joined(self) -> Optional[TextChannel]: + """The channel which user join messages get sent in + + Returns + -------- + Optional[:class:`TextChannel`] + The channel + """ + if not self.user_joined_id: + return + + channel = self.state.get_channel(self.user_joined_id) + assert isinstance(channel, TextChannel) + return channel + + @property + def user_left(self) -> Optional[TextChannel]: + """The channel which user leave messages get sent in + + Returns + -------- + Optional[:class:`TextChannel`] + The channel + """ + if not self.user_left_id: + return + + channel = self.state.get_channel(self.user_left_id) + assert isinstance(channel, TextChannel) + return channel + + @property + def user_kicked(self) -> Optional[TextChannel]: + """The channel which user kick messages get sent in + + Returns + -------- + Optional[:class:`TextChannel`] + The channel + """ + if not self.user_kicked_id: + return + + channel = self.state.get_channel(self.user_kicked_id) + assert isinstance(channel, TextChannel) + return channel + + @property + def user_banned(self) -> Optional[TextChannel]: + """The channel which user ban messages get sent in + + Returns + -------- + Optional[:class:`TextChannel`] + The channel + """ + if not self.user_banned_id: + return + + channel = self.state.get_channel(self.user_banned_id) + assert isinstance(channel, TextChannel) + return channel + +class Server(Ulid): + """Represents a server + + Attributes + ----------- + id: :class:`str` + The id of the server + name: :class:`str` + The name of the server + owner_id: :class:`str` + The owner's id of the server + description: Optional[:class:`str`] + The servers description + nsfw: :class:`bool` + Whether the server is nsfw or not + system_messages: :class:`SystemMessages` + The system message config for the server + icon: Optional[:class:`Asset`] + The servers icon + banner: Optional[:class:`Asset`] + The servers banner + default_permissions: :class:`Permissions` + The permissions for the default role + """ + __slots__ = ("state", "id", "name", "owner_id", "default_permissions", "_members", "_roles", "_channels", "description", "icon", "banner", "nsfw", "system_messages", "_categories", "_emojis") + + def __init__(self, data: ServerPayload, state: State): + self.state: State = state + self.id: str = data["_id"] + self.name: str = data["name"] + self.owner_id: str = data["owner"] + self.description: str | None = data.get("description") or None + self.nsfw: bool = data.get("nsfw", False) + self.system_messages: SystemMessages = SystemMessages(data.get("system_messages", cast("SystemMessagesConfig", {})), state) + self._categories: dict[str, Category] = {data["id"]: Category(data, state) for data in data.get("categories", [])} + self.default_permissions: Permissions = Permissions(data["default_permissions"]) + + self.icon: Asset | None + + if icon := data.get("icon"): + self.icon = Asset(icon, state) + else: + self.icon = None + + self.banner: Asset | None + + if banner := data.get("banner"): + self.banner = Asset(banner, state) + else: + self.banner = None + + self._members: dict[str, Member] = {} + self._roles: dict[str, Role] = {role_id: Role(role, role_id, self, state) for role_id, role in data.get("roles", {}).items()} + + self._channels: dict[str, Channel] = {} + + # The api doesnt send us all the channels but sends us all the ids, this is because channels we dont have permissions to see are not sent + # this causes get_channel to error so we have to first check ourself if its in the cache. + + for channel_id in data["channels"]: + if channel := state.channels.get(channel_id): + self._channels[channel_id] = channel + + self._emojis: dict[str, Emoji] = {} + + def _update(self, *, owner: Optional[str] = None, name: Optional[str] = None, description: Optional[str] = None, icon: Optional[FilePayload] = None, banner: Optional[FilePayload] = None, default_permissions: Optional[int] = None, nsfw: Optional[bool] = None, system_messages: Optional[SystemMessagesConfig] = None, categories: Optional[list[CategoryPayload]] = None, channels: Optional[list[str]] = None): + if owner is not None: + self.owner_id = owner + if name is not None: + self.name = name + if description is not None: + self.description = description or None + if icon is not None: + self.icon = Asset(icon, self.state) + if banner is not None: + self.banner = Asset(banner, self.state) + if default_permissions is not None: + self.default_permissions = Permissions(default_permissions) + if nsfw is not None: + self.nsfw = nsfw + if system_messages is not None: + self.system_messages = SystemMessages(system_messages, self.state) + if categories is not None: + self._categories = {data["id"]: Category(data, self.state) for data in categories} + if channels is not None: + self._channels = {channel_id: self.state.get_channel(channel_id) for channel_id in channels} + + def _add_member(self, payload: MemberPayload) -> Member: + member = Member(payload, self, self.state) + self._members[member.id] = member + + return member + + @property + def roles(self) -> list[Role]: + """list[:class:`Role`] Gets all roles in the server in decending order""" + return list(self._roles.values()) + + @property + def members(self) -> list[Member]: + """list[:class:`Member`] Gets all members in the server""" + return list(self._members.values()) + + @property + def channels(self) -> list[Channel]: + """list[:class:`Member`] Gets all channels in the server""" + return list(self._channels.values()) + + @property + def categories(self) -> list[Category]: + """list[:class:`Category`] Gets all categories in the server""" + return list(self._categories.values()) + + @property + def emojis(self) -> list[Emoji]: + """list[:class:`Emoji`] Gets all emojis in the server""" + return list(self._emojis.values()) + + def get_role(self, role_id: str) -> Role: + """Gets a role from the cache + + Parameters + ----------- + id: :class:`str` + The id of the role + + Returns + -------- + :class:`Role` + The role + """ + return self._roles[role_id] + + def get_member(self, member_id: str) -> Member: + """Gets a member from the cache + + Parameters + ----------- + id: :class:`str` + The id of the member + + Returns + -------- + :class:`Member` + The member + """ + try: + return self._members[member_id] + except KeyError: + raise LookupError from None + + def get_channel(self, channel_id: str) -> Channel: + """Gets a channel from the cache + + Parameters + ----------- + id: :class:`str` + The id of the channel + + Returns + -------- + :class:`Channel` + The channel + """ + try: + return self._channels[channel_id] + except KeyError: + raise LookupError from None + + def get_category(self, category_id: str) -> Category: + """Gets a category from the cache + + Parameters + ----------- + id: :class:`str` + The id of the category + + Returns + -------- + :class:`Category` + The category + """ + try: + return self._categories[category_id] + except KeyError: + raise LookupError from None + + def get_emoji(self, emoji_id: str) -> Emoji: + """Gets a emoji from the cache + + Parameters + ----------- + id: :class:`str` + The id of the emoji + + Returns + -------- + :class:`Emoji` + The emoji + """ + try: + return self._emojis[emoji_id] + except KeyError as e: + raise LookupError from e + + @property + def owner(self) -> Member: + """:class:`Member` The owner of the server""" + return self.get_member(self.owner_id) + + async def set_default_permissions(self, permissions: Permissions) -> None: + """Sets the default server permissions. + Parameters + ----------- + server_permissions: Optional[:class:`ServerPermissions`] + The new default server permissions + channel_permissions: Optional[:class:`ChannelPermissions`] + the new default channel permissions + """ + + await self.state.http.set_server_default_permissions(self.id, permissions.value) + + async def leave_server(self) -> None: + """Leaves or deletes the server""" + await self.state.http.delete_leave_server(self.id) + + async def delete_server(self) -> None: + """Leaves or deletes a server, alias to :meth`Server.leave_server`""" + await self.leave_server() + + async def create_text_channel(self, *, name: str, description: Optional[str] = None) -> TextChannel: + """Creates a text channel in the server + + Parameters + ----------- + name: :class:`str` + The name of the channel + description: Optional[:class:`str`] + The channel's description + + Returns + -------- + :class:`TextChannel` + The text channel that was just created + """ + payload = await self.state.http.create_channel(self.id, "Text", name, description) + + channel = TextChannel(payload, self.state) + self._channels[channel.id] = channel + + return channel + + async def create_voice_channel(self, *, name: str, description: Optional[str] = None) -> VoiceChannel: + """Creates a voice channel in the server + + Parameters + ----------- + name: :class:`str` + The name of the channel + description: Optional[:class:`str`] + The channel's description + + Returns + -------- + :class:`VoiceChannel` + The voice channel that was just created + """ + payload = await self.state.http.create_channel(self.id, "Voice", name, description) + + channel = self.state.add_channel(payload) + self._channels[channel.id] = channel + + return cast(VoiceChannel, channel) + + async def fetch_invites(self) -> list[Invite]: + """Fetches all invites in the server + + Returns + -------- + list[:class:`Invite`] + """ + invite_payloads = await self.state.http.fetch_server_invites(self.id) + + return [Invite._from_partial(payload["_id"], payload["server"], payload["creator"], payload["channel"], self.state) for payload in invite_payloads] + + async def fetch_member(self, member_id: str) -> Member: + """Fetches a member from this server + + Parameters + ----------- + member_id: :class:`str` + The id of the member you are fetching + + Returns + -------- + :class:`Member` + The member with the matching id + """ + payload = await self.state.http.fetch_member(self.id, member_id) + + return Member(payload, self, self.state) + + async def fetch_bans(self) -> list[ServerBan]: + """Fetches all bans in the server + + Returns + -------- + list[:class:`ServerBan`] + """ + payload = await self.state.http.fetch_bans(self.id) + + return [ServerBan(ban, self.state) for ban in payload["bans"]] + + async def create_role(self, name: str) -> Role: + """Creates a role in the server + + Parameters + ----------- + name: :class:`str` + The name of the role + + + Returns + -------- + :class:`Role` + The role that was just created + """ + payload = await self.state.http.create_role(self.id, name) + + return Role(payload["role"], payload["id"], self, self.state) + + async def create_emoji(self, name: str, file: File, *, nsfw: bool = False) -> Emoji: + """Creates an emoji + + Parameters + ----------- + name: :class:`str` + The name for the emoji + file: :class:`File` + The image for the emoji + nsfw: :class:`bool` + Whether or not the emoji is nsfw + """ + payload = await self.state.http.create_emoji(name, file, nsfw, {"type": "Server", "id": self.id}) + + return self.state.add_emoji(payload) + + +class ServerBan: + """Represents a server ban + + Attributes + ----------- + reason: Optional[:class:`str`] + The reason the user was banned + server: :class:`Server` + The server the user was banned in + user_id: :class:`str` + The id of the user who was banned + """ + + __slots__ = ("reason", "server", "user_id", "state") + + def __init__(self, ban: Ban, state: State): + self.reason: str | None = ban.get("reason") + self.server: Server = state.get_server(ban["_id"]["server"]) + self.user_id: str = ban["_id"]["user"] + self.state: State = state + + async def unban(self) -> None: + """Unbans the user""" + await self.state.http.unban_member(self.server.id, self.user_id) diff --git a/next/state.py b/next/state.py new file mode 100644 index 0000000..fc29a16 --- /dev/null +++ b/next/state.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from collections import deque +from typing import TYPE_CHECKING + +from .channel import Channel, channel_factory +from .emoji import Emoji +from .member import Member +from .message import Message +from .server import Server +from .user import User + +if TYPE_CHECKING: + from .http import HttpClient + from .types import ApiInfo + from .types import Channel as ChannelPayload + from .types import Emoji as EmojiPayload + from .types import Member as MemberPayload + from .types import Message as MessagePayload + from .types import Server as ServerPayload + from .types import User as UserPayload + +__all__ = ("State",) + +class State: + __slots__ = ("http", "api_info", "max_messages", "users", "channels", "servers", "messages", "global_emojis", "user_id", "me") + + def __init__(self, http: HttpClient, api_info: ApiInfo, max_messages: int): + self.http: HttpClient = http + self.api_info: ApiInfo = api_info + self.max_messages: int = max_messages + + self.me: User + + self.users: dict[str, User] = {} + self.channels: dict[str, Channel] = {} + self.servers: dict[str, Server] = {} + self.messages: deque[Message] = deque() + self.global_emojis: list[Emoji] = [] + + def get_user(self, id: str) -> User: + try: + return self.users[id] + except KeyError: + raise LookupError from None + + def get_member(self, server_id: str, member_id: str) -> Member: + server = self.servers[server_id] + return server.get_member(member_id) + + def get_channel(self, id: str) -> Channel: + try: + return self.channels[id] + except KeyError: + raise LookupError from None + + def get_server(self, id: str) -> Server: + try: + return self.servers[id] + except KeyError: + raise LookupError from None + + def add_user(self, payload: UserPayload) -> User: + + + user = User(payload, self) + + if payload.get("relationship") == "User": + self.me = user + + self.users[user.id] = user + return user + + def add_member(self, server_id: str, payload: MemberPayload) -> Member: + server = self.get_server(server_id) + + return server._add_member(payload) + + def add_channel(self, payload: ChannelPayload) -> Channel: + channel = channel_factory(payload, self) + self.channels[channel.id] = channel + return channel + + def add_server(self, payload: ServerPayload) -> Server: + server = Server(payload, self) + self.servers[server.id] = server + return server + + def add_message(self, payload: MessagePayload) -> Message: + message = Message(payload, self) + if len(self.messages) >= self.max_messages: + self.messages.pop() + + self.messages.appendleft(message) + return message + + def add_emoji(self, payload: EmojiPayload) -> Emoji: + emoji = Emoji(payload, self) + + if server_id := emoji.server_id: + server = self.get_server(server_id) + server._emojis[emoji.id] = emoji + else: + self.global_emojis.append(emoji) + + return emoji + + def get_message(self, message_id: str) -> Message: + for msg in self.messages: + if msg.id == message_id: + return msg + + raise LookupError + + async def fetch_server_members(self, server_id: str) -> None: + data = await self.http.fetch_members(server_id) + + for user in data["users"]: + self.add_user(user) + + for member in data["members"]: + self.add_member(server_id, member) + + async def fetch_all_server_members(self) -> None: + for server_id in self.servers: + await self.fetch_server_members(server_id) diff --git a/next/types/__init__.py b/next/types/__init__.py new file mode 100644 index 0000000..b8ac014 --- /dev/null +++ b/next/types/__init__.py @@ -0,0 +1,14 @@ +from .category import * +from .channel import * +from .embed import * +from .emoji import * +from .file import * +from .gateway import * +from .http import * +from .invite import * +from .member import * +from .message import * +from .permissions import * +from .role import * +from .server import * +from .user import * diff --git a/next/types/category.py b/next/types/category.py new file mode 100644 index 0000000..be8cf60 --- /dev/null +++ b/next/types/category.py @@ -0,0 +1,8 @@ +from typing import TypedDict + +__all__ = ("Category",) + +class Category(TypedDict): + id: str + title: str + channels: list[str] diff --git a/next/types/channel.py b/next/types/channel.py new file mode 100644 index 0000000..1b330bc --- /dev/null +++ b/next/types/channel.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, TypedDict, Union + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from .file import File + from .permissions import Overwrite + +__all__ = ( + "SavedMessages", + "DMChannel", + "GroupDMChannel", + "TextChannel", + "VoiceChannel", + "ServerChannel", + "Channel", +) + +class BaseChannel(TypedDict): + _id: str + nonce: str + +class SavedMessages(BaseChannel): + user: str + channel_type: Literal["SavedMessages"] + +class DMChannel(BaseChannel): + active: bool + recipients: list[str] + last_message_id: NotRequired[str] + channel_type: Literal["DirectMessage"] + +class GroupDMChannel(BaseChannel): + recipients: list[str] + name: str + owner: str + channel_type: Literal["Group"] + icon: NotRequired[File] + permissions: NotRequired[int] + description: NotRequired[str] + nsfw: NotRequired[bool] + last_message_id: NotRequired[str] + +class TextChannel(BaseChannel): + server: str + name: str + description: str + channel_type: Literal["TextChannel"] + icon: NotRequired[File] + default_permissions: NotRequired[Overwrite] + role_permissions: NotRequired[dict[str, Overwrite]] + nsfw: NotRequired[bool] + last_message_id: NotRequired[str] + +class VoiceChannel(BaseChannel): + server: str + name: str + description: str + channel_type: Literal["VoiceChannel"] + icon: NotRequired[File] + default_permissions: NotRequired[Overwrite] + role_permissions: NotRequired[dict[str, Overwrite]] + nsfw: NotRequired[bool] + +ServerChannel = Union[TextChannel, VoiceChannel] +Channel = Union[SavedMessages, DMChannel, GroupDMChannel, TextChannel, VoiceChannel] diff --git a/next/types/embed.py b/next/types/embed.py new file mode 100644 index 0000000..9bb3a9e --- /dev/null +++ b/next/types/embed.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, TypedDict, Union + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from .file import File + +__all__ = ("Embed", "SendableEmbed", "WebsiteEmbed", "ImageEmbed", "TextEmbed", "NoneEmbed", "YoutubeSpecial", "TwitchSpecial", "SpotifySpecial", "SoundcloudSpecial", "BandcampSpecial", "WebsiteSpecial", "JanuaryImage", "JanuaryVideo") + +class YoutubeSpecial(TypedDict): + type: Literal["Youtube"] + id: str + timestamp: NotRequired[str] + +class TwitchSpecial(TypedDict): + type: Literal["Twitch"] + content_type: Literal["Channel", "Video", "Clip"] + id: str + +class SpotifySpecial(TypedDict): + type: Literal["Spotify"] + content_type: str + id: str + +class SoundcloudSpecial(TypedDict): + type: Literal["Soundcloud"] + +class BandcampSpecial(TypedDict): + type: Literal["Bandcamp"] + content_type: Literal["Album", "Track"] + id: str + +WebsiteSpecial = Union[YoutubeSpecial, TwitchSpecial, SpotifySpecial, SoundcloudSpecial, BandcampSpecial] + +class JanuaryImage(TypedDict): + url: str + width: int + height: int + size: Literal["Large", "Preview"] + +class JanuaryVideo(TypedDict): + url: str + width: int + height: int + +class WebsiteEmbed(TypedDict): + type: Literal["Website"] + url: NotRequired[str] + special: NotRequired[WebsiteSpecial] + title: NotRequired[str] + description: NotRequired[str] + image: NotRequired[JanuaryImage] + video: NotRequired[JanuaryVideo] + site_name: NotRequired[str] + icon_url: NotRequired[str] + colour: NotRequired[str] + +class ImageEmbed(JanuaryImage): + type: Literal["Image"] + +class TextEmbed(TypedDict): + type: Literal["Text"] + icon_url: NotRequired[str] + url: NotRequired[str] + title: NotRequired[str] + description: NotRequired[str] + media: NotRequired[File] + colour: NotRequired[str] + +class NoneEmbed(TypedDict): + type: Literal["None"] + +Embed = Union[WebsiteEmbed, ImageEmbed, TextEmbed, NoneEmbed] + +class SendableEmbed(TypedDict): + type: Literal["Text"] + icon_url: NotRequired[str] + url: NotRequired[str] + title: NotRequired[str] + description: NotRequired[str] + media: NotRequired[str] + colour: NotRequired[str] + diff --git a/next/types/emoji.py b/next/types/emoji.py new file mode 100644 index 0000000..9b95df3 --- /dev/null +++ b/next/types/emoji.py @@ -0,0 +1,21 @@ +from typing import Literal, TypedDict, Union + +from typing_extensions import NotRequired + + +class EmojiParentServer(TypedDict): + type: Literal["Server"] + id: str + +class EmojiParentDetached(TypedDict): + type: Literal["Detached"] + +EmojiParent = Union[EmojiParentServer, EmojiParentDetached] + +class Emoji(TypedDict): + _id: str + parent: EmojiParent + creator_id: str + name: str + animated: NotRequired[bool] + nsfw: NotRequired[bool] diff --git a/next/types/file.py b/next/types/file.py new file mode 100644 index 0000000..1466851 --- /dev/null +++ b/next/types/file.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import Literal, TypedDict, Union + +__all__ = ("File",) + +class SizedMetadata(TypedDict): + type: Literal["Image", "Video"] + height: int + width: int + +class SimpleMetadata(TypedDict): + type: Literal["File", "Text", "Audio"] + +FileMetadata = Union[SizedMetadata, SimpleMetadata] + +class File(TypedDict): + _id: str + tag: str + size: int + filename: str + metadata: FileMetadata + content_type: str diff --git a/next/types/gateway.py b/next/types/gateway.py new file mode 100644 index 0000000..ff622ad --- /dev/null +++ b/next/types/gateway.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, TypedDict, Union + +from typing_extensions import NotRequired + +from .channel import Channel, DMChannel, GroupDMChannel, SavedMessages, TextChannel, VoiceChannel +from .message import Message +from .permissions import Overwrite + +if TYPE_CHECKING: + from .category import Category + from .embed import Embed + from .emoji import Emoji + from .file import File + from .member import Member, MemberID + from .server import Server, SystemMessagesConfig + from .user import Status, User, UserProfile, UserRelation + +__all__ = ( + "BasePayload", + "AuthenticatePayload", + "ReadyEventPayload", + "MessageEventPayload", + "MessageUpdateData", + "MessageUpdateEventPayload", + "MessageDeleteEventPayload", + "ChannelCreateEventPayload", + "ChannelUpdateEventPayload", + "ChannelDeleteEventPayload", + "ChannelStartTypingEventPayload", + "ChannelDeleteTypingEventPayload", + "ServerUpdateEventPayload", + "ServerDeleteEventPayload", + "ServerMemberUpdateEventPayload", + "ServerMemberJoinEventPayload", + "ServerMemberLeaveEventPayload", + "ServerRoleUpdateEventPayload", + "ServerRoleDeleteEventPayload", + "UserUpdateEventPayload", + "UserRelationshipEventPayload", + "ServerCreateEventPayload", + "MessageReactEventPayload", + "MessageUnreactEventPayload", + "MessageRemoveReactionEventPayload", + "BulkMessageDeleteEventPayload" +) + +class BasePayload(TypedDict): + type: str + +class AuthenticatePayload(BasePayload): + token: str + +class ReadyEventPayload(BasePayload): + users: list[User] + servers: list[Server] + channels: list[Channel] + members: list[Member] + emojis: list[Emoji] + +class MessageEventPayload(BasePayload, Message): + pass + +class MessageUpdateData(TypedDict): + content: str + embeds: list[Embed] + edited: Union[str, int] + +class MessageUpdateEventPayload(BasePayload): + channel: str + data: MessageUpdateData + id: str + +class MessageDeleteEventPayload(BasePayload): + channel: str + id: str + +class ChannelCreateEventPayload_SavedMessages(BasePayload, SavedMessages): + pass + +class ChannelCreateEventPayload_Group(BasePayload, GroupDMChannel): + pass + +class ChannelCreateEventPayload_TextChannel(BasePayload, TextChannel): + pass + +class ChannelCreateEventPayload_VoiceChannel(BasePayload, VoiceChannel): + pass + +class ChannelCreateEventPayload_DMChannel(BasePayload, DMChannel): + pass + +ChannelCreateEventPayload = Union[ChannelCreateEventPayload_Group, ChannelCreateEventPayload_Group, ChannelCreateEventPayload_TextChannel, ChannelCreateEventPayload_VoiceChannel, ChannelCreateEventPayload_DMChannel] + +class ChannelUpdateEventPayloadData(TypedDict, total=False): + name: str + description: str + icon: File + nsfw: bool + active: bool + role_permissions: dict[str, Overwrite] + default_permissions: Overwrite + +class ChannelUpdateEventPayload(BasePayload): + id: str + data: ChannelUpdateEventPayloadData + clear: Literal["Icon", "Description"] + +class ChannelDeleteEventPayload(BasePayload): + id: str + +class ChannelStartTypingEventPayload(BasePayload): + id: str + user: str + +ChannelDeleteTypingEventPayload = ChannelStartTypingEventPayload + +class ServerUpdateEventPayloadData(TypedDict, total=False): + owner: str + name: str + description: str + icon: File + banner: File + default_permissions: int + nsfw: bool + system_messages: SystemMessagesConfig + categories: list[Category] + +class ServerUpdateEventPayload(BasePayload): + id: str + data: ServerUpdateEventPayloadData + clear: Literal["Icon", "Banner", "Description"] + +class ServerDeleteEventPayload(BasePayload): + id: str + +class ServerCreateEventPayload(BasePayload): + id: str + server: Server + channels: list[Channel] + +class ServerMemberUpdateEventPayloadData(TypedDict, total=False): + nickname: str + avatar: File + roles: list[str] + timeout: str | int + +class ServerMemberUpdateEventPayload(BasePayload): + id: MemberID + data: ServerMemberUpdateEventPayloadData + clear: Literal["Nickname", "Avatar"] + +class ServerMemberJoinEventPayload(BasePayload): + id: str + user: str + +ServerMemberLeaveEventPayload = ServerMemberJoinEventPayload + +class ServerRoleUpdateEventPayloadData(TypedDict, total=False): + name: str + colour: str + hoist: bool + rank: int + +class ServerRoleUpdateEventPayload(BasePayload): + id: str + role_id: str + data: ServerRoleUpdateEventPayloadData + clear: Literal["Colour"] + +class ServerRoleDeleteEventPayload(BasePayload): + id: str + role_id: str + +class UserUpdateEventPayloadData(TypedDict): + status: NotRequired[Status] + avatar: NotRequired[File] + online: NotRequired[bool] + profile: NotRequired[UserProfile] + username: NotRequired[str] + display_name: NotRequired[str] + relations: NotRequired[list[UserRelation]] + badges: NotRequired[int] + online: NotRequired[bool] + flags: NotRequired[int] + discriminator: NotRequired[str] + privileged: NotRequired[bool] + +class UserUpdateEventPayload(BasePayload): + id: str + data: UserUpdateEventPayloadData + clear: Literal["ProfileContent", "ProfileBackground", "StatusText", "Avatar"] + +class UserRelationshipEventPayload(BasePayload): + id: str + user: str + status: Status + +class MessageReactEventPayload(BasePayload): + id: str + channel_id: str + user_id: str + emoji_id: str + +MessageUnreactEventPayload = MessageReactEventPayload + +class MessageRemoveReactionEventPayload(BasePayload): + id: str + channel_id: str + emoji_id: str + +class BulkMessageDeleteEventPayload(BasePayload): + channel: str + ids: list[str] diff --git a/next/types/http.py b/next/types/http.py new file mode 100644 index 0000000..0cdb5de --- /dev/null +++ b/next/types/http.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypedDict +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from .member import Member + from .message import Message + from .user import User + from .role import Role + + +__all__ = ( + "VosoFeature", + "ApiInfo", + "Autumn", + "GetServerMembers", + "MessageWithUserData", + "CreateRole", +) + + +class ApiFeature(TypedDict): + enabled: bool + url: str + +class VosoFeature(ApiFeature): + ws: str + +class Features(TypedDict): + email: bool + invite_only: bool + captcha: ApiFeature + autumn: ApiFeature + january: ApiFeature + voso: VosoFeature + +class ApiInfo(TypedDict): + revolt: str + features: Features + ws: str + app: str + vapid: str + +class Autumn(TypedDict): + id: str + +class GetServerMembers(TypedDict): + members: list[Member] + users: list[User] + +class MessageWithUserData(TypedDict): + messages: list[Message] + members: NotRequired[list[Member]] + users: list[User] + +class CreateRole(TypedDict): + id: str + role: Role \ No newline at end of file diff --git a/next/types/invite.py b/next/types/invite.py new file mode 100644 index 0000000..d78e76e --- /dev/null +++ b/next/types/invite.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, TypedDict + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from .file import File + +__all__ = ("Invite", "PartialInvite") + + +class Invite(TypedDict): + type: Literal["Server"] + server_id: str + server_name: str + server_icon: NotRequired[str] + server_banner: NotRequired[str] + channel_id: str + channel_name: str + channel_description: NotRequired[str] + user_name: str + user_avatar: NotRequired[File] + member_count: int + +class PartialInvite(TypedDict): + _id: str + server: str + channel: str + creator: str diff --git a/next/types/member.py b/next/types/member.py new file mode 100644 index 0000000..f36f486 --- /dev/null +++ b/next/types/member.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypedDict + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from .file import File + + +__all__ = ("Member", "MemberID") + +class MemberID(TypedDict): + server: str + user: str + +class Member(TypedDict): + _id: MemberID + nickname: NotRequired[str] + avatar: NotRequired[File] + roles: NotRequired[list[str]] + joined_at: int | str + timeout: NotRequired[str | int] diff --git a/next/types/message.py b/next/types/message.py new file mode 100644 index 0000000..90b3f61 --- /dev/null +++ b/next/types/message.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypedDict, Union + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from .embed import Embed + from .file import File + + +__all__ = ( + "UserAddContent", + "UserRemoveContent", + "UserJoinedContent", + "UserLeftContent", + "UserKickedContent", + "UserBannedContent", + "ChannelRenameContent", + "ChannelDescriptionChangeContent", + "ChannelIconChangeContent", + "Masquerade", + "Interactions", + "Message", + "MessageReplyPayload", + "SystemMessageContent", + ) + +class UserAddContent(TypedDict): + id: str + by: str + +class UserRemoveContent(TypedDict): + id: str + by: str + +class UserJoinedContent(TypedDict): + id: str + by: str + +class UserLeftContent(TypedDict): + id: str + +class UserKickedContent(TypedDict): + id: str + +class UserBannedContent(TypedDict): + id: str + +class ChannelRenameContent(TypedDict): + name: str + by: str + +class ChannelDescriptionChangeContent(TypedDict): + by: str + +class ChannelIconChangeContent(TypedDict): + by: str + +class Masquerade(TypedDict, total=False): + name: str + avatar: str + colour: str + +class Interactions(TypedDict): + reactions: NotRequired[list[str]] + restrict_reactions: NotRequired[bool] + +SystemMessageContent = Union[UserAddContent, UserRemoveContent, UserJoinedContent, UserLeftContent, UserKickedContent, UserBannedContent, ChannelRenameContent, ChannelDescriptionChangeContent, ChannelIconChangeContent] + +class Message(TypedDict): + _id: str + channel: str + author: str + content: str + system: NotRequired[SystemMessageContent] + attachments: NotRequired[list[File]] + embeds: NotRequired[list[Embed]] + mentions: NotRequired[list[str]] + replies: NotRequired[list[str]] + edited: NotRequired[str | int] + masquerade: NotRequired[Masquerade] + interactions: NotRequired[Interactions] + reactions: dict[str, list[str]] + +class MessageReplyPayload(TypedDict): + id: str + mention: bool diff --git a/next/types/permissions.py b/next/types/permissions.py new file mode 100644 index 0000000..ae32692 --- /dev/null +++ b/next/types/permissions.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from typing import TypedDict + + +class Overwrite(TypedDict): + a: int + d: int diff --git a/next/types/role.py b/next/types/role.py new file mode 100644 index 0000000..b4aa9a1 --- /dev/null +++ b/next/types/role.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypedDict + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from .permissions import Overwrite + +__all__ = ( + "Role", +) + +class Role(TypedDict): + name: str + permissions: Overwrite + colour: NotRequired[str] + hoist: NotRequired[bool] + rank: int diff --git a/next/types/server.py b/next/types/server.py new file mode 100644 index 0000000..f965e27 --- /dev/null +++ b/next/types/server.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypedDict + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from .category import Category + from .file import File + from .role import Role + +__all__ = ( + "Server", + "BannedUser", + "Ban", + "ServerBans", + "SystemMessagesConfig" +) + +class SystemMessagesConfig(TypedDict, total=False): + user_joined: str + user_left: str + user_kicked: str + user_banned: str + + +class Server(TypedDict): + _id: str + owner: str + name: str + channels: list[str] + default_permissions: int + nonce: NotRequired[str] + description: NotRequired[str] + categories: NotRequired[list[Category]] + system_messages: NotRequired[SystemMessagesConfig] + roles: NotRequired[dict[str, Role]] + icon: NotRequired[File] + banner: NotRequired[File] + nsfw: NotRequired[bool] + +class BannedUser(TypedDict): + _id: str + username: str + avatar: NotRequired[File] + +class BanId(TypedDict): + server: str + user: str + +class Ban(TypedDict): + _id: BanId + reason: NotRequired[str] + +class ServerBans(TypedDict): + users: list[BannedUser] + bans: list[Ban] diff --git a/next/types/user.py b/next/types/user.py new file mode 100644 index 0000000..10f94ec --- /dev/null +++ b/next/types/user.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal, TypedDict + +from typing_extensions import NotRequired + +if TYPE_CHECKING: + from .file import File + +__all__ = ( + "UserRelation", + "Relation", + "UserBot", + "Status", + "User", + "UserProfile", +) + +Relation = Literal["Blocked", "BlockedOther", "Friend", "Incoming", "None", "Outgoing", "User"] + +class UserBot(TypedDict): + owner: str + +class Status(TypedDict, total=False): + text: str + presence: Literal["Busy", "Idle", "Invisible", "Online"] + +class UserRelation(TypedDict): + status: Relation + _id: str + +class User(TypedDict): + _id: str + username: str + discriminator: str + display_name: NotRequired[str] + avatar: NotRequired[File] + relations: NotRequired[list[UserRelation]] + badges: NotRequired[int] + status: NotRequired[Status] + relationship: NotRequired[Relation] + online: NotRequired[bool] + flags: NotRequired[int] + bot: NotRequired[UserBot] + privileged: NotRequired[bool] + +class UserProfile(TypedDict, total=False): + content: str + background: File diff --git a/next/user.py b/next/user.py new file mode 100644 index 0000000..73a83b2 --- /dev/null +++ b/next/user.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, NamedTuple, Optional, Union +from weakref import WeakValueDictionary + +from next.types.user import UserRelation + +from .asset import Asset, PartialAsset +from .channel import DMChannel, GroupDMChannel, SavedMessageChannel +from .enums import PresenceType, RelationshipType +from .flags import UserBadges +from .messageable import Messageable +from .permissions import UserPermissions +from .utils import Ulid + +if TYPE_CHECKING: + from .member import Member + from .state import State + from .types import File + from .types import Status as StatusPayload + from .types import User as UserPayload + from .types import UserProfile as UserProfileData + from .server import Server + +__all__ = ("User", "Status", "Relation", "UserProfile") + +class Relation(NamedTuple): + """A namedtuple representing a relation between the bot and a user""" + type: RelationshipType + user: User + +class Status(NamedTuple): + """A namedtuple representing a users status""" + text: Optional[str] + presence: Optional[PresenceType] + +class UserProfile(NamedTuple): + """A namedtuple representing a users profile""" + content: Optional[str] + background: Optional[Asset] + +class User(Messageable, Ulid): + """Represents a user + + Attributes + ----------- + id: :class:`str` + The user's id + discriminator: :class:`str` + The user's discriminator + display_name: Optional[:class:`str`] + The user's display name if they have one + bot: :class:`bool` + Whether or not the user is a bot + owner_id: Optional[:class:`str`] + The bot's owner id if the user is a bot + badges: :class:`UserBadges` + The users badges + online: :class:`bool` + Whether or not the user is online + flags: :class:`int` + The user flags + relations: list[:class:`Relation`] + A list of the users relations + relationship: Optional[:class:`RelationshipType`] + The relationship between the user and the bot + status: Optional[:class:`Status`] + The users status + dm_channel: Optional[:class:`DMChannel`] + The dm channel between the client and the user, this will only be set if the client has dm'ed the user or :meth:`User.open_dm` was run + privileged: :class:`bool` + Whether the user is privileged + """ + __flattern_attributes__: tuple[str, ...] = ("id", "discriminator", "display_name", "bot", "owner_id", "badges", "online", "flags", "relations", "relationship", "status", "masquerade_avatar", "masquerade_name", "original_name", "original_avatar", "profile", "dm_channel", "privileged") + __slots__: tuple[str, ...] = (*__flattern_attributes__, "state", "_members") + + def __init__(self, data: UserPayload, state: State): + self.state = state + self._members: WeakValueDictionary[str, Member] = WeakValueDictionary() # we store all member versions of this user to avoid having to check every guild when needing to update. + self.id: str = data["_id"] + self.discriminator: str = data["discriminator"] + self.display_name: str | None = data.get("display_name") + self.original_name: str = data["username"] + self.dm_channel: DMChannel | SavedMessageChannel | None = None + + bot = data.get("bot") + + self.bot: bool + self.owner_id: str | None + + if bot: + self.bot = True + self.owner_id = bot["owner"] + else: + self.bot = False + self.owner_id = None + + self.badges: UserBadges = UserBadges._from_value(data.get("badges", 0)) + self.online: bool = data.get("online", False) + self.flags: int = data.get("flags", 0) + self.privileged: bool = data.get("privileged", False) + + avatar = data.get("avatar") + self.original_avatar: Asset | None = Asset(avatar, state) if avatar else None + + relations: list[Relation] = [] + + for relation in data.get("relations", []): + user = state.get_user(relation["_id"]) + if user: + relations.append(Relation(RelationshipType(relation["status"]), user)) + + self.relations: list[Relation] = relations + + relationship = data.get("relationship") + self.relationship: RelationshipType | None = RelationshipType(relationship) if relationship else None + + status = data.get("status") + self.status: Status | None + + if status: + presence = status.get("presence") + self.status = Status(status.get("text"), PresenceType(presence) if presence else None) if status else None + else: + self.status = None + + self.profile: Optional[UserProfile] = None + + self.masquerade_avatar: Optional[PartialAsset] = None + self.masquerade_name: Optional[str] = None + + def get_permissions(self) -> UserPermissions: + """Gets the permissions for the user + + Returns + -------- + :class:`UserPermissions` + The users permissions + """ + permissions = UserPermissions() + + if self.relationship in [RelationshipType.friend, RelationshipType.user]: + return UserPermissions.all() + + elif self.relationship in [RelationshipType.blocked, RelationshipType.blocked_other]: + return UserPermissions(access=True) + + elif self.relationship in [RelationshipType.incoming_friend_request, RelationshipType.outgoing_friend_request]: + permissions.access = True + + for channel in self.state.channels.values(): + if (isinstance(channel, (GroupDMChannel, DMChannel)) and self.id in channel.recipient_ids) or any(self.id in (m.id for m in server.members) for server in self.state.servers.values()): + if self.state.me.bot or self.bot: + permissions.send_message = True + + permissions.access = True + permissions.view_profile = True + + return permissions + + def has_permissions(self, **permissions: bool) -> bool: + """Computes if the user has the specified permissions + + Parameters + ----------- + permissions: :class:`bool` + The permissions to check, this also accepted `False` if you need to check if the user does not have the permission + + Returns + -------- + :class:`bool` + Whether or not they have the permissions + """ + perms = self.get_permissions() + + return all([getattr(perms, key, False) == value for key, value in permissions.items()]) + + async def _get_channel_id(self): + if not self.dm_channel: + payload = await self.state.http.open_dm(self.id) + + if payload["channel_type"] == "SavedMessages": + self.dm_channel = SavedMessageChannel(payload, self.state) + else: + self.dm_channel = DMChannel(payload, self.state) + + return self.dm_channel.id + + @property + def owner(self) -> User: + """:class:`User` the owner of the bot account""" + + if not self.owner_id: + raise LookupError + + return self.state.get_user(self.owner_id) + + @property + def name(self) -> str: + """:class:`str` The name the user is displaying, this includes (in order) their masqueraded name, display name and orginal name""" + return self.display_name or self.masquerade_name or self.original_name + + @property + def avatar(self) -> Union[Asset, PartialAsset, None]: + """Optional[:class:`Asset`] The avatar the member is displaying, this includes there orginal avatar and masqueraded avatar""" + return self.masquerade_avatar or self.original_avatar + + @property + def mention(self) -> str: + """:class:`str`: Returns a string that allows you to mention the given user.""" + return f"<@{self.id}>" + + def _update( + self, + *, + status: Optional[StatusPayload] = None, + profile: Optional[UserProfileData] = None, + avatar: Optional[File] = None, + online: Optional[bool] = None, + display_name: Optional[str] = None, + relations: Optional[list[UserRelation]] = None, + badges: Optional[int] = None, + flags: Optional[int] = None, + discriminator: Optional[str] = None, + privileged: Optional[bool] = None, + username: Optional[str] = None + ) -> None: + if status is not None: + presence = status.get("presence") + self.status = Status(status.get("text"), PresenceType(presence) if presence else None) + + if profile is not None: + if background_file := profile.get("background"): + background = Asset(background_file, self.state) + else: + background = None + + self.profile = UserProfile(profile.get("content"), background) + + if avatar is not None: + self.original_avatar = Asset(avatar, self.state) + + if online is not None: + self.online = online + + if display_name is not None: + self.display_name = display_name + + if relations is not None: + new_relations: list[Relation] = [] + + for relation in relations: + user = self.state.get_user(relation["_id"]) + if user: + new_relations.append(Relation(RelationshipType(relation["status"]), user)) + + self.relations = new_relations + + if badges is not None: + self.badges = UserBadges(badges) + + if flags is not None: + self.flags = flags + + if discriminator is not None: + self.discriminator = discriminator + + if privileged is not None: + self.privileged = privileged + + if username is not None: + self.original_name = username + + # update user infomation for all members + + if self.__class__ is User: + for member in self._members.values(): + User._update( + member, + status=status, + profile=profile, + avatar=avatar, + online=online, + display_name=display_name, + relations=relations, + badges=badges, + flags=flags, + discriminator=discriminator, + privileged=privileged, + username=username + ) + + async def default_avatar(self) -> bytes: + """Returns the default avatar for this user + + Returns + -------- + :class:`bytes` + The bytes of the image + """ + return await self.state.http.fetch_default_avatar(self.id) + + async def fetch_profile(self) -> UserProfile: + """Fetches the user's profile + + Returns + -------- + :class:`UserProfile` + The user's profile + """ + if profile := self.profile: + return profile + + payload = await self.state.http.fetch_profile(self.id) + + if file := payload.get("background"): + background = Asset(file, self.state) + else: + background = None + + self.profile = UserProfile(payload.get("content"), background) + return self.profile + + def to_member(self, server: Server) -> Member: + """Gets the member instance for this user for a specific server. + + Roughly equivelent to: + + .. code-block:: python + + member = server.get_member(user.id) + + + Parameters + ----------- + server: :class:`Server` + The server to get the member for + + Returns + -------- + :class:`Member` + The member + + Raises + ------- + :class:`LookupError` + + """ + try: + return self._members[server.id] + except IndexError: + raise LookupError from None + + async def open_dm(self) -> DMChannel | SavedMessageChannel: + """Opens a dm with the user, if this user is the current user this will return :class:`SavedMessageChannel` + + .. note:: using this function is discouraged as :meth:`User.send` does this implicitally. + + Returns + -------- + Union[:class:`DMChannel`, :class:`SavedMessageChannel`] + """ + + await self._get_channel_id() + + assert self.dm_channel + return self.dm_channel diff --git a/next/utils.py b/next/utils.py new file mode 100644 index 0000000..8a5f50c --- /dev/null +++ b/next/utils.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import datetime +import inspect +from contextlib import asynccontextmanager +from operator import attrgetter +from typing import Any, Callable, Coroutine, Iterable, Literal, TypeVar, Union + +import ulid +from aiohttp import ClientSession +from typing_extensions import ParamSpec + +__all__ = ("_Missing", "Missing", "copy_doc", "maybe_coroutine", "get", "client_session", "parse_timestamp") + +class _Missing: + def __repr__(self) -> str: + return "" + + def __bool__(self) -> Literal[False]: + return False + +Missing: _Missing = _Missing() + +T = TypeVar("T") + +def copy_doc(from_t: T) -> Callable[[T], T]: + def inner(to_t: T) -> T: + to_t.__doc__ = from_t.__doc__ + return to_t + + return inner + +R_T = TypeVar("R_T") +P = ParamSpec("P") + +# it is impossible to type this function correctly as typeguard does not narrow for the negative case, +# so `value` would stay being a union even after the if statement (PEP 647 - "The type is not narrowed in the negative case") +# see typing#926, typing#930, typing#996 + +async def maybe_coroutine(func: Callable[P, Union[R_T, Coroutine[Any, Any, R_T]]], *args: P.args, **kwargs: P.kwargs) -> R_T: + value = func(*args, **kwargs) + + if inspect.isawaitable(value): + value = await value + + return value # type: ignore + + +class Ulid: + id: str + + @property + def created_at(self) -> datetime.datetime: + return ulid.from_str(self.id).timestamp().datetime + +class Object(Ulid): + """Class to mock objects with an id""" + def __init__(self, id: str): + self.id = id + +def get(iterable: Iterable[T], **attrs: Any) -> T: + """A convenience function to help get a value from an iterable with a specific attribute + + Examples + --------- + + .. code-block:: python + :emphasize-lines: 3 + + from next import utils + + channel = utils.get(server.channels, name="General") + await channel.send("Hello general chat.") + + Parameters + ----------- + iterable: Iterable + The values to search though + **attrs: Any + The attributes to check + + Returns + -------- + Any + The value from the iterable with the met attributes + + Raises + ------- + LookupError + Raises when none of the values in the iterable matches the attributes + + """ + converted = [(attrgetter(attr.replace('__', '.')), value) for attr, value in attrs.items()] + + for elem in iterable: + if all(pred(elem) == value for pred, value in converted): + return elem + + raise LookupError + + +@asynccontextmanager +async def client_session(): + """A context manager that creates a new aiohttp.ClientSession() and closes it when exiting the context. + + Examples + --------- + + .. code-block:: python + :emphasize-lines: 3 + + async def main(): + async with client_session() as session: + client = next.Client(session, "TOKEN") + await client.start() + + asyncio.run(main()) + """ + session = ClientSession() + + try: + yield session + finally: + await session.close() + +def parse_timestamp(timestamp: int | str) -> datetime.datetime: + if isinstance(timestamp, int): + return datetime.datetime.fromtimestamp(timestamp / 1000, tz=datetime.timezone.utc) + else: + return datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%f%z") diff --git a/next/websocket.py b/next/websocket.py new file mode 100644 index 0000000..6ec1604 --- /dev/null +++ b/next/websocket.py @@ -0,0 +1,496 @@ +from __future__ import annotations + +import asyncio +import logging +import time +from copy import copy +from typing import TYPE_CHECKING, Callable, NamedTuple, cast + +from .errors import NextError +from . import utils +from .channel import GroupDMChannel, TextChannel, VoiceChannel +from .enums import RelationshipType +from .role import Role +from .types import (BulkMessageDeleteEventPayload, ChannelCreateEventPayload, + ChannelDeleteEventPayload, ChannelDeleteTypingEventPayload, + ChannelStartTypingEventPayload, ChannelUpdateEventPayload) +from .types import Member as MemberPayload +from .types import MemberID as MemberIDPayload +from .types import Message as MessagePayload +from .types import (MessageDeleteEventPayload, MessageReactEventPayload, + MessageRemoveReactionEventPayload, + MessageUnreactEventPayload, MessageUpdateEventPayload) +from .types import Role as RolePayload +from .types import (ServerCreateEventPayload, ServerDeleteEventPayload, + ServerMemberJoinEventPayload, + ServerMemberLeaveEventPayload, + ServerMemberUpdateEventPayload, + ServerRoleDeleteEventPayload, ServerRoleUpdateEventPayload, + ServerUpdateEventPayload, UserRelationshipEventPayload, + UserUpdateEventPayload) +from .user import Status, User, UserProfile + +import aiohttp + +try: + import ujson as json +except ImportError: + import json + +use_msgpack: bool + +try: + import msgpack + use_msgpack = True +except ImportError: + use_msgpack = False + +if TYPE_CHECKING: + import aiohttp + + from .state import State + from .types import (AuthenticatePayload, BasePayload, MessageEventPayload, + ReadyEventPayload) + from .message import Message + +class WSMessage(NamedTuple): + type: aiohttp.WSMsgType + data: str | bytes | aiohttp.WSCloseCode + +__all__: tuple[str, ...] = ("WebsocketHandler",) + +logger: logging.Logger = logging.getLogger("next") + +class WebsocketHandler: + __slots__ = ("session", "token", "ws_url", "dispatch", "state", "websocket", "loop", "user", "ready", "server_events") + + def __init__(self, session: aiohttp.ClientSession, token: str, ws_url: str, dispatch: Callable[..., None], state: State): + self.session: aiohttp.ClientSession = session + self.token: str = token + self.ws_url: str = ws_url + self.dispatch: Callable[..., None] = dispatch + self.state: State = state + self.websocket: aiohttp.ClientWebSocketResponse + self.loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + self.user: User | None = None + self.ready: asyncio.Event = asyncio.Event() + self.server_events: dict[str, asyncio.Event] = {} + + async def _wait_for_server_ready(self, server_id: str) -> None: + if event := self.server_events.get(server_id): + await event.wait() + + async def send_payload(self, payload: BasePayload) -> None: + if use_msgpack: + await self.websocket.send_bytes(msgpack.packb(payload)) # type: ignore + else: + await self.websocket.send_str(json.dumps(payload)) + + async def heartbeat(self) -> None: + while not self.websocket.closed: + logger.info("Sending hearbeat") + await self.websocket.ping() + await asyncio.sleep(15) + + async def send_authenticate(self) -> None: + payload: AuthenticatePayload = { + "type": "Authenticate", + "token": self.token + } + + await self.send_payload(payload) + + async def handle_event(self, payload: BasePayload) -> None: + event_type = payload["type"].lower() + logger.debug("Recieved event %s %s", event_type, payload) + + try: + if event_type not in ["ready", "notfound"]: + await self.ready.wait() + + func = getattr(self, f"handle_{event_type}") + except AttributeError: + return logger.debug("Unknown event '%s'", event_type) + + await func(payload) + + async def handle_authenticated(self, _: BasePayload) -> None: + logger.info("Successfully authenticated") + + async def handle_notfound(self, _: BasePayload) -> None: + raise NextError("Invalid token") + + async def handle_ready(self, payload: ReadyEventPayload) -> None: + # Сначала добавляем пользователей + for user_payload in payload["users"]: + user = self.state.add_user(user_payload) + + if user.relationship == RelationshipType.user: + self.user = user + + for server in payload["servers"]: + self.state.add_server(server) + + for channel in payload["channels"]: + self.state.add_channel(channel) + + for member in payload["members"]: + self.state.add_member(member["_id"]["server"], member) + + for emoji in payload["emojis"]: + emoji = self.state.add_emoji(emoji) + + await self.state.fetch_all_server_members() + + self.ready.set() + self.dispatch("ready") + + async def handle_message(self, payload: MessageEventPayload) -> None: + if server := self.state.get_channel(payload["channel"]).server_id: + await self._wait_for_server_ready(server) + + message = self.state.add_message(cast(MessagePayload, payload)) + + + self.dispatch("message", message) + + async def handle_messageupdate(self, payload: MessageUpdateEventPayload) -> None: + self.dispatch("raw_message_update", payload) + + try: + message = self.state.get_message(payload["id"]) + except LookupError: + return + + if server_id := message.channel.server_id: + await self._wait_for_server_ready(server_id) + + before = copy(message) + message._update(**payload["data"]) + + self.dispatch("message_update", before, message) + + async def handle_messagedelete(self, payload: MessageDeleteEventPayload) -> None: + self.dispatch("raw_message_delete", payload) + + try: + message = self.state.get_message(payload["id"]) + except LookupError: + return + + if server_id := message.channel.server_id: + await self._wait_for_server_ready(server_id) + + self.state.messages.remove(message) + + + self.dispatch("message_delete", message) + + async def handle_channelcreate(self, payload: ChannelCreateEventPayload) -> None: + channel = self.state.add_channel(payload) + + if server_id := channel.server_id: + await self._wait_for_server_ready(server_id) + + self.dispatch("channel_create", channel) + + async def handle_channelupdate(self, payload: ChannelUpdateEventPayload) -> None: + # Next sends channel updates for channels we dont have permissions to see, a bug, but still can cause issues as its not in the cache + + if not (channel := self.state.channels.get(payload["id"], None)): + return + + if server_id := channel.server_id: + await self._wait_for_server_ready(server_id) + + old_channel = copy(channel) + + channel._update(**payload["data"]) + + if clear := payload.get("clear"): + if clear == "Icon": + if isinstance(channel, (TextChannel, VoiceChannel, GroupDMChannel)): + channel.icon = None + + elif clear == "Description": + if isinstance(channel, (TextChannel, VoiceChannel, GroupDMChannel)): + channel.description = None + + + self.dispatch("channel_update", old_channel, channel) + + async def handle_channeldelete(self, payload: ChannelDeleteEventPayload) -> None: + channel = self.state.channels.pop(payload["id"]) + + if server_id := channel.server_id: + await self._wait_for_server_ready(server_id) + + self.dispatch("channel_delete", channel) + + async def handle_channelstarttyping(self, payload: ChannelStartTypingEventPayload) -> None: + channel = self.state.get_channel(payload["id"]) + + if server_id := channel.server_id: + await self._wait_for_server_ready(server_id) + + user = self.state.get_user(payload["user"]) + + self.dispatch("typing_start", channel, user) + + async def handle_channelstoptyping(self, payload: ChannelDeleteTypingEventPayload) -> None: + channel = self.state.get_channel(payload["id"]) + + if server_id := channel.server_id: + await self._wait_for_server_ready(server_id) + + user = self.state.get_user(payload["user"]) + + self.dispatch("typing_stop", channel, user) + + async def handle_serverupdate(self, payload: ServerUpdateEventPayload) -> None: + await self._wait_for_server_ready(payload["id"]) + + server = self.state.get_server(payload["id"]) + + old_server = copy(server) + + server._update(**payload["data"]) + + if clear := payload.get("clear"): + if clear == "Icon": + server.icon = None + + elif clear == "Banner": + server.banner = None + + elif clear == "Description": + server.description = None + + + self.dispatch("server_update", old_server, server) + + async def handle_serverdelete(self, payload: ServerDeleteEventPayload) -> None: + server = self.state.servers.pop(payload["id"]) + + for channel in server.channels: + del self.state.channels[channel.id] + + await self._wait_for_server_ready(server.id) + + self.dispatch("server_delete", server) + + async def handle_servercreate(self, payload: ServerCreateEventPayload) -> None: + for channel in payload["channels"]: + self.state.add_channel(channel) + + server = self.state.add_server(payload["server"]) + + # lock all server events until we fetch all the members, otherwise the cache will be incomplete + self.server_events[server.id] = asyncio.Event() + await self.state.fetch_server_members(server.id) + self.server_events.pop(server.id).set() + + self.dispatch("server_join", server) + + async def handle_servermemberupdate(self, payload: ServerMemberUpdateEventPayload) -> None: + await self._wait_for_server_ready(payload["id"]["server"]) + + member = self.state.get_member(payload["id"]["server"], payload["id"]["user"]) + old_member = copy(member) + + if clear := payload.get("clear"): + if clear == "Nickname": + member.nickname = None + elif clear == "Avatar": + member.guild_avatar = None + + member._update(**payload["data"]) + + self.dispatch("member_update", old_member, member) + + async def handle_servermemberjoin(self, payload: ServerMemberJoinEventPayload) -> None: + # avoid an api request if possible + if payload["user"] not in self.state.users: + user = await self.state.http.fetch_user(payload["user"]) + self.state.add_user(user) + + member = self.state.add_member(payload["id"], MemberPayload(_id=MemberIDPayload(server=payload["id"], user=payload["user"]), joined_at=int(time.time()))) # next doesnt give us the joined at time + + self.dispatch("member_join", member) + + async def handle_memberleave(self, payload: ServerMemberLeaveEventPayload) -> None: + await self._wait_for_server_ready(payload["id"]) + + server = self.state.get_server(payload["id"]) + member = server._members.pop(payload["user"]) + + # remove the member from the user + + user = self.state.get_user(payload["user"]) + user._members.pop(server.id) + + self.dispatch("member_leave", member) + + async def handle_serverroleupdate(self, payload: ServerRoleUpdateEventPayload) -> None: + server = self.state.get_server(payload["id"]) + await self._wait_for_server_ready(server.id) + + try: + role = server.get_role(payload["role_id"]) + except LookupError: + # the role wasnt found meaning it was just created + + role = Role(cast(RolePayload, payload["data"]), payload["role_id"], server, self.state) + server._roles[role.id] = role + self.dispatch("role_create", role) + else: + old_role = copy(role) + + if clear := payload.get("clear"): + if clear == "Colour": + role.colour = None + + role._update(**payload["data"]) + + self.dispatch("role_update", old_role, role) + + async def handle_serverroledelete(self, payload: ServerRoleDeleteEventPayload) -> None: + server = self.state.get_server(payload["id"]) + role = server._roles.pop(payload["role_id"]) + + await self._wait_for_server_ready(server.id) + + self.dispatch("role_delete", role) + + async def handle_userupdate(self, payload: UserUpdateEventPayload) -> None: + user = self.state.get_user(payload["id"]) + old_user = copy(user) + + if clear := payload.get("clear"): + if clear == "ProfileContent": + if profile := user.profile: + user.profile = UserProfile(None, profile.background) + + elif clear == "ProfileBackground": + if profile := user.profile: + user.profile = UserProfile(profile.content, None) + + elif clear == "StatusText": + user.status = Status(None, user.status.presence if user.status else None) + + elif clear == "Avatar": + user.original_avatar = None + + user._update(**payload["data"]) + + self.dispatch("user_update", old_user, user) + + async def handle_userrelationship(self, payload: UserRelationshipEventPayload) -> None: + user = self.state.get_user(payload["user"]) + old_relationship = user.relationship + user.relationship = RelationshipType(payload["status"]) + + self.dispatch("user_relationship_update", user, old_relationship, user.relationship) + + async def handle_messagereact(self, payload: MessageReactEventPayload) -> None: + if server := self.state.get_channel(payload["channel_id"]).server_id: + await self._wait_for_server_ready(server) + + self.dispatch("raw_reaction_add", payload) + + try: + message = utils.get(self.state.messages, id=payload["id"]) + except LookupError: + return + + user = self.state.get_user(payload["user_id"]) + message.reactions.setdefault(payload["emoji_id"], []).append(user) + emoji_id = payload["emoji_id"] + + self.dispatch("reaction_add", message, user, emoji_id) + + async def handle_messageunreact(self, payload: MessageUnreactEventPayload) -> None: + if server := self.state.get_channel(payload["channel_id"]).server_id: + await self._wait_for_server_ready(server) + + self.dispatch("raw_reaction_remove", payload) + + try: + message = utils.get(self.state.messages, id=payload["id"]) + except LookupError: + return + + user = self.state.get_user(payload["user_id"]) + message.reactions[payload["emoji_id"]].remove(user) + + self.dispatch("reaction_remove", message, user, payload["emoji_id"]) + + async def handle_messageremovereaction(self, payload: MessageRemoveReactionEventPayload) -> None: + if server := self.state.get_channel(payload["channel_id"]).server_id: + await self._wait_for_server_ready(server) + + self.dispatch("raw_reaction_clear", payload) + + try: + message = utils.get(self.state.messages, id=payload["id"]) + except LookupError: + return + + users = message.reactions.pop(payload["emoji_id"]) + + self.dispatch("reaction_clear", message, users, payload["emoji_id"]) + + async def handle_bulkmessagedelete(self, payload: BulkMessageDeleteEventPayload) -> None: + channel = self.state.get_channel(payload["channel"]) + + self.dispatch("raw_bulk_message_delete", payload) + + messages: list[Message] = [] + + for message_id in payload["ids"]: + if server_id := channel.server_id: + await self._wait_for_server_ready(server_id) + + self.dispatch("raw_message_delete", MessageDeleteEventPayload(type="messagedelete", channel=payload["channel"], id=message_id)) + + try: + message = self.state.get_message(message_id) + except LookupError: + pass + else: + self.state.messages.remove(message) + self.dispatch("message_delete", message) + + messages.append(message) + + self.dispatch("bulk_message_delete", messages) + + async def start(self, reconnect: bool) -> None: + if use_msgpack: + url = f"{self.ws_url}?format=msgpack" + else: + url = f"{self.ws_url}?format=json" + + while True: + self.websocket = await self.session.ws_connect(url) # type: ignore + await self.send_authenticate() + hb = asyncio.create_task(self.heartbeat()) + + async for msg in self.websocket: + msg = cast(WSMessage, msg) # aiohttp doesnt use NamedTuple so the type info is missing + + if use_msgpack: + data = cast(bytes, msg.data) + + payload = msgpack.unpackb(data) # type: ignore + else: + data = cast(str, msg.data) + + payload = json.loads(data) + + self.loop.create_task(self.handle_event(payload)) + + hb.cancel() + + if not reconnect: + return diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..4e32227 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,69 @@ +[project] +name = "next-api-py" +dynamic = ["version"] +description = "Python wrapper for the next.avanpost20.ru API" +requires-python = ">=3.9" +license = "MIT" +readme = "README.md" +keywords = ["wrapper", "async", "api", "websockets", "http"] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3 :: Only", +] +dependencies = [ + "aiohttp==3.9.*", + "ulid-py==1.1.*", + "aenum==3.1.*", + "typing_extensions>=4.4.0" +] + +[project.optional-dependencies] +speedups = [ + "ujson==5.1.*", + "msgpack==1.0.*" +] +docs = [ + "Sphinx==5.2.*", + "sphinx-nameko-theme==0.0.*", + "sphinx-toolbox==3.2.*", + "setuptools==65.4.*" +] + +[project.urls] +Homepage = "https://git.avanpost20.ru/next/next.py" +Documentation = "https://nextpy.avanpost20.ru/" +"Source Code" = "https://git.avanpost20.ru/next/next.py" +"Bug Tracker" = "https://git.avanpost20.ru/next/next.py/issues" + +[[project.authors]] +name = "Avanpost" +email = "me@avanpost20.ru" + +[tool.hatch.version] +path = "next/__init__.py" + +[tool.hatch.build] +only-packages = true +include = ["next/**/*"] + + +[tool.pyright] +reportPrivateUsage = false +reportImportCycles = false +reportIncompatibleMethodOverride = false +typeCheckingMode = "strict" + +[tool.hatch.build.targets.sdist] +strict-naming = false + +[tool.hatch.build.targets.wheel] +strict-naming = false + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" diff --git a/typings/msgpack/__init__.pyi b/typings/msgpack/__init__.pyi new file mode 100644 index 0000000..189bc5d --- /dev/null +++ b/typings/msgpack/__init__.pyi @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import Any, Callable, Dict, List, Optional, Tuple + +from typing_extensions import Protocol + +class _FileLike(Protocol): + def read(self, n: int) -> bytes: ... + +def unpackb( + packed: bytes, + file_like: Optional[_FileLike] = ..., + read_size: int = ..., + use_list: bool = ..., + raw: bool = ..., + timestamp: int = ..., + strict_map_key: bool = ..., + object_hook: Optional[Callable[[Dict[Any, Any]], Any]] = ..., + object_pairs_hook: Optional[Callable[[List[Tuple[Any, Any]]], Any]] = ..., + list_hook: Optional[Callable[[List[Any]], Any]] = ..., + unicode_errors: Optional[str] = ..., + max_buffer_size: int = ..., + ext_hook: Callable[[int, bytes], Any] = ..., + max_str_len: int = ..., + max_bin_len: int = ..., + max_array_len: int = ..., + max_map_len: int = ..., + max_ext_len: int = ..., +) -> Any: ... + +def packb( + o: Any, + default: Optional[Callable[[Any], Any]] = ..., + use_single_float: bool = ..., + autoreset: bool = ..., + use_bin_type: bool = ..., + strict_types: bool = ..., + datetime: bool = ..., + unicode_errors: Optional[str] = ..., +) -> bytes: ... diff --git a/typings/sphinx_nameko_theme/__init__.pyi b/typings/sphinx_nameko_theme/__init__.pyi new file mode 100644 index 0000000..4fbb77d --- /dev/null +++ b/typings/sphinx_nameko_theme/__init__.pyi @@ -0,0 +1 @@ +def get_html_theme_path() -> str: ...