diff --git a/.gitignore b/.gitignore index f6283d1..66833da 100644 --- a/.gitignore +++ b/.gitignore @@ -131,4 +131,6 @@ dmypy.json #pycharm and vscode folders .vscode -.idea \ No newline at end of file +.idea + +.DS_Store diff --git a/README.md b/README.md index 27d5a18..da06e5a 100644 --- a/README.md +++ b/README.md @@ -2,4 +2,6 @@ [![Discord Support Server](https://img.shields.io/discord/885214547391180860?label=Disthon%20-%20Support%20Server&color=5865f2&labelColor=5865f2&&logo=discord&logoColor=ffffff&style=flat-square)](https://discord.gg/PtcfyJHKKp) +![Logo](./logo.png?raw=true) + Discord API wrapper for Python built from scratch diff --git a/discord/__init__.py b/discord/__init__.py index d609e58..2287a23 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -2,6 +2,7 @@ A work in progress discord wrapper built from scratch """ -from .api.intents import Intents as Intents -from .client import Client as Client -from .embeds import Embed as Embed +from .api.intents import Intents +from .client import Client +from .embeds import Embed +from .message import Message diff --git a/discord/abc/abstractuser.py b/discord/abc/abstractuser.py index 15eb8a6..d23987f 100644 --- a/discord/abc/abstractuser.py +++ b/discord/abc/abstractuser.py @@ -2,24 +2,18 @@ from typing import Optional -from discordobject import DiscordObject - from ..message import Message from ..types.avatar import Avatar +from .discordobject import DiscordObject class AbstractUser(DiscordObject): - avatar: Optional[Avatar] - bot: bool - username: str - discriminator: str - id: int @property def tag(self): return f"{self.username}#{self.discriminator}" - + @property def discriminator(self): return self.discriminator @@ -27,11 +21,11 @@ def discriminator(self): @property def mention(self): return f"<@!{self.id}>" - - @propery + + @property def name(self): return self.username - + @property def id(self): return self.id diff --git a/discord/abc/discordobject.py b/discord/abc/discordobject.py index d271fd8..421a7f9 100644 --- a/discord/abc/discordobject.py +++ b/discord/abc/discordobject.py @@ -1,19 +1,29 @@ from __future__ import annotations - -from datetime import datetime +from typing import TYPE_CHECKING from pydantic import BaseModel from ..types.snowflake import Snowflake -class DiscordObject(BaseModel): +if TYPE_CHECKING: + from discord import Client + +class DiscordObject(BaseModel): id: Snowflake - created_at: datetime + _client: Client + + class Config: + arbitrary_types_allowed = True + + def __init__(self, client, **payload): + super().__init__(_client=client, **payload) + object.__setattr__(self, "_client", client) # For some reason pydantic doesn't set the client attribute + # So we'll set it manually def __ne__(self, other): return not self.__eq__(other) def __hash__(self): - return self.id.id >> 22 + return self.id >> 22 diff --git a/discord/abc/messageable.py b/discord/abc/messageable.py new file mode 100644 index 0000000..cffb069 --- /dev/null +++ b/discord/abc/messageable.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import abc +import typing + +import discord + +if typing.TYPE_CHECKING: + from discord import Client, Embed + from discord.interactions import View + + +class Messageable(abc.ABC): + id: int + _client: Client + + def _get_channel(self): + raise NotImplementedError + + async def send(self, + content: typing.Optional[str] = None, + *, + embeds: typing.Union[Embed, typing.List[Embed]] = None, + views: typing.Union[View, typing.List[View]] = None + ): + content = str(content) if content is not None else None + + channel = self._get_channel() + data = await self._client.httphandler.send_message(channel.id, content=content, embeds=embeds, views=views) + + return discord.message.Message(self._client, **data) diff --git a/discord/activity/presenseassets.py b/discord/activity/presenseassets.py index 078f356..c51c276 100644 --- a/discord/activity/presenseassets.py +++ b/discord/activity/presenseassets.py @@ -1,8 +1,12 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from ..types.snowflake import Snowflake -from .activity import Activity -from .rawactivityassets import RawActivityAssets + +if TYPE_CHECKING: + from .activity import Activity + from .rawactivityassets import RawActivityAssets class PresenceAssets(dict[Snowflake, str]): diff --git a/discord/api/dataConverters.py b/discord/api/dataConverters.py new file mode 100644 index 0000000..d40c96e --- /dev/null +++ b/discord/api/dataConverters.py @@ -0,0 +1,81 @@ +import inspect +import typing + +from discord.types.snowflake import Snowflake +from ..channels.guildchannel import TextChannel + +from ..guild import Guild +from ..message import Message +from ..user.user import User +from ..user.member import Member + + +class DataConverter: + def __init__(self, client): + self.client = client + self.converters = {} + for name, func in inspect.getmembers(self): + if name.startswith("convert_"): + self.converters[name[8:]] = func + + def _get_channel(self, id): + return None # TODO: get channel from cache + + def convert_event_error(self, data): + return [data] + + def convert_message_create(self, data): + return [Message(self.client, **data)] + + def convert_ready(self, data): + return [] + + def convert_guild_create(self, data): + members = data["members"] + guild = Guild(self.client, **data) + self.client.ws.guild_cache[Snowflake(data["id"])] = guild + + for member_data in members: + user_data = member_data["user"] + member_data.pop("user", None) + member_data["guild"] = guild + member_data["guild_avatar"] = member_data.get("avatar") + member_data.pop("avatar", None) + + self.client.ws.member_cache[Snowflake(user_data["id"])] = Member(self.client, **member_data, **user_data) + self.client.ws.user_cache[Snowflake(user_data["id"])] = User(self.client, **user_data) + + for channel_data in data["channels"]: + self.client.ws.channel_cache[Snowflake(channel_data["id"])] = TextChannel(self.client, **channel_data) + + return [guild] + + def convert_presence_update(self, data): + return [data] + + def convert_typing_start(self, data): + return [data] + + def convert_guild_member_update(self, data): + return [data] + + def convert_interaction_create(self, payload): + message = payload.get("message") + + if message: + message = Message(self.client, **payload) + + if payload["type"] == 3: + component = self.client.httphandler.component_cache.get(payload["data"]["custom_id"]) + + # When the bot restarts the previously cached components are gone + if component: # so check if the component is a newly created + self.client._loop.create_task(component.run_callback(message, payload["data"])) + + return [payload] + + def convert(self, event, data): + func: typing.Callable = self.converters.get(event) + if not func: + return [data] + return func(data) diff --git a/discord/api/handler.py b/discord/api/httphandler.py similarity index 86% rename from discord/api/handler.py rename to discord/api/httphandler.py index dfeaa07..3f9c9ff 100644 --- a/discord/api/handler.py +++ b/discord/api/httphandler.py @@ -14,11 +14,13 @@ DiscordNotFound, DiscordServerError) -class Handler: +class HTTPHandler: def __init__(self): self.base_url: str = "https://discord.com/api/v9/" self.user_agent: str = "Disthon Discord API wrapper V0.0.1b" + self.component_cache = {} + async def request( self, method: str, @@ -79,7 +81,19 @@ async def connect(self, url: str) -> aiohttp.ClientWebSocketResponse: return await self._session.ws_connect(url, **kwargs) async def close(self) -> None: - await self._session.close() + if self._session: + await self._session.close() + + async def get_from_cdn(self, url: str) -> bytes: + async with self._session.get(url) as response: + if response.status == 200: + return await response.read() + elif response.status == 404: + raise DiscordNotFound("asset not found") + elif response.status == 403: + raise DiscordForbidden("cannot retrieve asset") + else: + raise DiscordHTTPException("failed to get asset", response.status) async def send_message( self, @@ -97,20 +111,30 @@ async def send_message( if content: payload["content"] = content + if embeds: - payload["embeds"] = [embed._to_dict() for embed in embeds] + payload["embeds"] = [embed.dict() for embed in embeds] + if views: - payload["components"] = [view._to_dict() for view in views] + def _cache_view_components(view: View): + for component in view.components: + self.component_cache[str(component.custom_id)] = component + return view._to_dict() + + payload["components"] = [_cache_view_components(view) for view in views] data = await self.request( "POST", f"channels/{channel_id}/messages", data=payload ) try: if isinstance(data, dict): - if data["code"] == 50008: - raise DiscordChannelNotFound - elif data["code"] == 10003: - raise DiscordChannelForbidden + code = data["code"] + if code == 50008: + raise DiscordChannelNotFound() + elif code == 10003: + raise DiscordChannelForbidden() + else: + raise DiscordHTTPException(data.get("message"), code) except KeyError: return data diff --git a/discord/api/websocket.py b/discord/api/websocket.py index ca875c7..7602e2f 100644 --- a/discord/api/websocket.py +++ b/discord/api/websocket.py @@ -12,6 +12,11 @@ import aiohttp from aiohttp.http_websocket import WSMessage, WSMsgType +from ..cache import LFUCache +from ..channels.basechannel import BaseChannel +from ..guild import Guild +from ..types.snowflake import Snowflake + if typing.TYPE_CHECKING: from ..client import Client @@ -39,6 +44,12 @@ def __init__(self, client, token: str) -> None: self.token = token self.session_id = None self.heartbeat_acked = True + self.closed: bool = False + + self.guild_cache = LFUCache[Snowflake, Guild](1000) + self.channel_cache = LFUCache[Snowflake, BaseChannel](5000) + self.member_cache = LFUCache[Snowflake, dict](5000) + self.user_cache = LFUCache[Snowflake, dict](5000) async def start( self, @@ -47,37 +58,43 @@ async def start( reconnect: typing.Optional[bool] = False ): if not url: - url = self.client.handler.gateway() - self.socket = await self.client.handler.connect(url) + url = self.client.httphandler.gateway() + self.socket = await self.client.httphandler.connect(url) await self.receive_events() await self.identify() if reconnect: await self.resume() else: - t = threading.Thread(target=self.keep_alive, daemon=True) - t.start() + self.hb_t: threading.Thread = threading.Thread(target=self.keep_alive, daemon=True) + self.hb_stop: threading.Event = threading.Event() + self.hb_t.start() return self + async def close(self) -> None: + """Closes the websocket""" + self.closed = True + await self.socket.close() + self.hb_stop.set() + def keep_alive(self) -> None: - while True: - time.sleep(self.hb_int) + while not self.hb_stop.wait(self.hb_int): if not self.heartbeat_acked: # We have a zombified connection - self.socket.close() + self.socket.close(code=1000) asyncio.run(self.start(reconnect=True)) else: asyncio.run(self.heartbeat()) - def on_websocket_message(self, msg: WSMessage) -> dict: + def on_websocket_message(self, msg: WSMessage) -> dict: if type(msg) is bytes: # always push the message data to your cache self.buffer.extend(msg) # check if last 4 bytes are ZLIB_SUFFIX if len(msg) < 4 or msg[-4:] != b"\x00\x00\xff\xff": - return + return msg - msg = self.decompress.decompress(self.buffer) + msg: bytes = self.decompress.decompress(self.buffer) msg = msg.decode("utf-8") self.buffer = bytearray() @@ -94,8 +111,11 @@ async def receive_events(self) -> None: aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSED, ): - await self.socket.close() - raise ConnectionResetError(msg.extra) + if not self.closed: + await self.socket.close() + raise ConnectionResetError(msg.extra) + else: + return msg = json.loads(msg) diff --git a/discord/asset.py b/discord/asset.py new file mode 100644 index 0000000..4f1de9b --- /dev/null +++ b/discord/asset.py @@ -0,0 +1,24 @@ +import io +from typing import Union, TYPE_CHECKING + +from pydantic import BaseModel + + +if TYPE_CHECKING: + from discord import Client + + +class Asset(BaseModel): + _client: "Client" + url: str + + async def read(self): + return await self._client.httphandler.get_from_cdn(self.url) + + async def save(self, fp: Union[str, io.BufferedIOBase]): + data = await self.read() + if isinstance(fp, str): + with open(fp, "wb+") as file: + return file.write(data) + + return fp.write(data) \ No newline at end of file diff --git a/discord/cache.py b/discord/cache.py index 9045488..b20fee6 100644 --- a/discord/cache.py +++ b/discord/cache.py @@ -1,122 +1,46 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, OrderedDict +from typing import TYPE_CHECKING, Any, Dict, OrderedDict + +from .types.snowflake import Snowflake + if TYPE_CHECKING: - from .api.handler import Handler + from .api.httphandler import HTTPHandler from .guild import Guild from .message import Message from .role import Role - from .types.snowflake import Snowflake from .user.member import Member from .user.user import User -class LFUCache: - - capacity: int - _cache: OrderedDict[Snowflake, Any] - _frequency: dict[Snowflake, int] - handler: Handler - +class LFUCache(OrderedDict): def __init__(self, capacity: int) -> None: - self.capacity = capacity - self._frequency = {} - self.length = 0 + self.capacity: int = capacity + self._frequency: Dict[Snowflake, int] = {} + self.length: int = 0 + super().__init__() - @classmethod - def _from_lfu(cls, lfu: LFUCache): - self = cls.__new__(cls) - self.capacity = lfu.capacity - self._cache = lfu._cache - self._frequency = lfu._frequency - return self + def __setitem__(self, key: Snowflake, value: Any) -> None: + frequency = self._frequency - def __eq__(self, other) -> bool: - return ( - isinstance(other, LFUCache) - and other._cache == self._cache - and self.capacity == other.capacity - ) + if key not in self: + self.length += 1 - def __ne__(self, other) -> bool: - return not self.__eq__(other) + super().__setitem__(key, value) + frequency[key] = 0 - def __setitem__(self, key: Snowflake, value: Any) -> None: - if key not in self._cache: - self.length += 1 - self._cache[key] = value - if self._frequency[key]: - self._frequency[key] += 1 - else: - self._frequency[key] = 0 if self.length > self.capacity: - snowflake: Snowflake - min_freq = float("inf") - for k in self._frequency.keys(): - if self._frequency[k] < min_freq: - min_freq = self._frequency[k] - snowflake = k - del self._cache[snowflake] + inverted = dict(zip(frequency.values(), frequency.keys())) + least_used = min(self._frequency.values()) + del self[inverted[least_used]] def __getitem__(self, key: Snowflake): - if self._cache[key]: - self._frequency[key] += 1 - return self._cache[key] - raise KeyError + self._frequency[key] += 1 + return self[key] def __delitem__(self, key: Snowflake): - del self._cache[key] + super().__delitem__(key) del self._frequency[key] self.length -= 1 - - -class UserCache(LFUCache): - _cache: dict[Snowflake, User] - - def __init__(self) -> None: - super().__init__(100000) - - def __setitem__(self, key: Snowflake, value: User) -> None: - return super().__setitem__(key, value) - - -class MemberCache(LFUCache): - _cache: dict[Snowflake, Member] - - def __init__(self) -> None: - super().__init__(100000) - - def __setitem__(self, key: Snowflake, value: Member) -> None: - return super().__setitem__(key, value) - - -class MessageCache(LFUCache): - _cache: dict[Snowflake, Message] - - def __init__(self) -> None: - super().__init__(2000) - - def __setitem__(self, key: Snowflake, value: Message) -> None: - return super().__setitem__(key, value) - - -class RoleCache(LFUCache): - _cache: dict[Snowflake, Role] - - def __init__(self) -> None: - super().__init__(250) - - def __setitem__(self, key: Snowflake, value: Role) -> None: - return super().__setitem__(key, value) - - -class GuildCache(LFUCache): - _cache: dict[Snowflake, Guild] - - def __init__(self) -> None: - super().__init__(20000) - - def __setitem__(self, key: Snowflake, value: Guild) -> None: - return super().__setitem__(key, value) diff --git a/discord/channels/___init__.py b/discord/channels/__init__.py similarity index 100% rename from discord/channels/___init__.py rename to discord/channels/__init__.py diff --git a/discord/channels/basechannel.py b/discord/channels/basechannel.py index da6f910..41600c0 100644 --- a/discord/channels/basechannel.py +++ b/discord/channels/basechannel.py @@ -1,26 +1,19 @@ from __future__ import annotations from ..abc.discordobject import DiscordObject +from ..abc.messageable import Messageable from ..types.snowflake import Snowflake -class BaseChannel(DiscordObject): - __slots__ = ("_id", "_name", "_mention") +class BaseChannel(DiscordObject, Messageable): + name: str - _id: Snowflake - _name: str - - @property - def id(self) -> Snowflake: - return self._id - - @property - def name(self): - return self._name + def _get_channel(self): + return self @property def mention(self): - return f"<#{self._id}>" + return f"<#{self.id}>" @property def created_at(self): diff --git a/discord/channels/dmchannel.py b/discord/channels/dmchannel.py index c19a535..2b59b43 100644 --- a/discord/channels/dmchannel.py +++ b/discord/channels/dmchannel.py @@ -1,5 +1,20 @@ from __future__ import annotations +from typing import List -class DMChannel: - pass + +from discord.abc.discordobject import DiscordObject +from discord.types.image import Image +from discord.user.user import User + + +class DMChannel(DiscordObject): + type: int + recipients: List[User] + last_message_id: int + + +class GroupDMChannel(DMChannel): + name: str + icon: Image + owner_id: int diff --git a/discord/channels/guildchannel.py b/discord/channels/guildchannel.py index 8326c96..3fc2c91 100644 --- a/discord/channels/guildchannel.py +++ b/discord/channels/guildchannel.py @@ -1,17 +1,22 @@ from __future__ import annotations -from ..guild.guild import GuildChannel -import ..abc +from typing import TYPE_CHECKING, List, Optional, Union +from ..message import Message from .basechannel import BaseChannel -class TextChannel(GuildChannel): - __slots__ = ( - "name", - "id", - "guild", - "nsfw", - "category_id", - "position" - ) - \ No newline at end of file +if TYPE_CHECKING: + from ..embeds import Embed + from ..interactions.components import View + + +class TextChannel(BaseChannel): + ... + + +class ThreadChannel(BaseChannel): + ... + + +class VoiceChannel(BaseChannel): + ... diff --git a/discord/client.py b/discord/client.py index f963cd4..2546313 100644 --- a/discord/client.py +++ b/discord/client.py @@ -7,49 +7,82 @@ import typing from copy import deepcopy -from .api.handler import Handler +from .api.dataConverters import DataConverter +from .api.httphandler import HTTPHandler from .api.intents import Intents from .api.websocket import WebSocket +from .commands.core import Command +from .commands.parser import CommandParser +from .commands.help_command import HelpCommand, DefaultHelpCommand + + +if typing.TYPE_CHECKING: + from . import Message + class Client: + + async def handle_event_error(self, error): + print(f"Ignoring exception in event {error.event.__name__}", file=sys.stderr) + traceback.print_exception( + type(error), error, error.__traceback__, file=sys.stderr + ) + + async def handle_commands(self, message: Message): + await self.process_commands(message) + def __init__( self, + command_prefix: str, *, intents: typing.Optional[Intents] = Intents.default(), + help_command: HelpCommand = DefaultHelpCommand(), respond_self: typing.Optional[bool] = False, + case_sensitive: bool=True, loop: typing.Optional[asyncio.AbstractEventLoop] = None, ) -> None: - self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop() + self._loop: asyncio.AbstractEventLoop = None # create the event loop when we run our client self.intents = intents self.respond_self = respond_self self.stay_alive = True - self.handler = Handler() + self.httphandler = HTTPHandler() self.lock = asyncio.Lock() self.closed = False - self.events = {} + self.events = {"message_create": [self.handle_commands], "event_error": [self.handle_event_error]} + self.once_events = {} + + self.command_prefix = command_prefix + self.commands: typing.Dict[str, Command] = {} + + self.converter = DataConverter(self) + self.command_parser = CommandParser(self.command_prefix, self.commands, case_sensitive) + + if help_command: + self.add_command(help_command) async def login(self, token: str) -> None: self.token = token async with self.lock: - self.info = await self.handler.login(token) + self.info = await self.httphandler.login(token) async def connect(self) -> None: while not self.closed: socket = WebSocket(self, self.token) async with self.lock: - g_url = await self.handler.gateway() + g_url = await self.httphandler.gateway() if not isinstance(self.intents, Intents): raise TypeError( f"Intents must be of type Intents, got {self.intents.__class__}" ) self.ws = await asyncio.wait_for(socket.start(g_url), timeout=30) - while True: + while not self.closed: await self.ws.receive_events() async def alive_loop(self, token: str) -> None: + self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() await self.login(token) try: await self.connect() @@ -57,29 +90,46 @@ async def alive_loop(self, token: str) -> None: await self.close() async def close(self) -> None: - await self.handler.close() + self.closed = True + await self.ws.close() + await self.httphandler.close() def run(self, token: str): - def stop_loop_on_completion(_): - self._loop.stop() - - future = asyncio.ensure_future(self.alive_loop(token), loop=self._loop) - future.add_done_callback(stop_loop_on_completion) + if not self._loop: + asyncio.run(self.alive_loop(token)) + else: + self._loop.run_forever(self.alive_loop(token)) - self._loop.run_forever() + def on(self, event: str = None, *, overwrite: bool = False): + def wrapper(func): + self.add_listener(func, event, overwrite=overwrite, once=False) + return func - if not future.cancelled(): - return future.result() + return wrapper - def event(self, event: str = None): + def once(self, event: str = None, *, overwrite: bool = False): def wrapper(func): - self.add_listener(func, event) + self.add_listener(func, event, overwrite=overwrite, once=True) return func return wrapper + def command(self, name=None, **kwargs): + """The decorator used to register functions as commands""" + def inner(func) -> Command: + command = Command(func, name, **kwargs) + self.add_command(command) + return command + + return inner + def add_listener( - self, func: typing.Callable, event: typing.Optional[str] = None + self, + func: typing.Callable, + event: typing.Optional[str] = None, + *, + overwrite: bool = False, + once: bool = False, ) -> None: event = event or func.__name__ if not inspect.iscoroutinefunction(func): @@ -87,25 +137,69 @@ def add_listener( "The callback is not a valid coroutine function. Did you forget to add async before def?" ) - if event in self.events: - self.events[event].append(func) - else: - self.events[event] = [func] + if once: # if it's a once event + if event in self.once_events and not overwrite: + self.once_events[event].append(func) + else: + self.once_events[event] = [func] + else: # if it's a regular event + if event in self.events and not overwrite: + self.events[event].append(func) + else: + self.events[event] = [func] async def handle_event(self, msg): - event: str = "on_" + msg["t"].lower() + event: str = msg["t"].lower() - # create a global on_message event for either guild or dm messages - if event in ("on_message_create", "on_dm_message_create"): - global_message = deepcopy(msg) - global_message["t"] = "MESSAGE" - await self.handle_event(global_message) + args = self.converter.convert(event, msg["d"]) for coro in self.events.get(event, []): try: - await coro(msg) + self._loop.create_task(coro(*args)) + except Exception as error: + error.event = coro + await self.handle_event({"d": error, "t": "event_error"}) + + for coro in self.once_events.pop(event, []): + try: + self._loop.create_task(coro(*args)) except Exception as error: - print(f"Ignoring exception in event {coro.__name__}", file=sys.stderr) - traceback.print_exception( - type(error), error, error.__traceback__, file=sys.stderr - ) + error.event = coro + await self.handle_event({"d": error, "t": "event_error"}) + + def add_command(self, command: Command): + if command.name in self.commands: + raise ValueError("Duplicate command name") + self.commands[command.name] = command + return command + + def remove_command(self, command: Command): + return self.commands.pop(command.name) + + def get_command_named(self, name: str) -> typing.Optional[Command]: + for command_name, command in self.commands.items(): + if command.is_regex_command: + if command.regex_match_func(command_name, name, command.regex_flags): + return command + + elif command_name == name: + return command + + async def process_commands(self, message: Message): + """Command handling""" + from .commands.context import Context + + if message.author.bot: + return + + command, args, kwargs, extra_kwargs = self.command_parser.parse_message(message) + context = Context(client=self, message=message, command=command) + + if command: + await command.execute(context, *args, **kwargs, **extra_kwargs) + + def get_guild(self, id: int): + return self.ws.guild_cache.get(id) + + def get_user(self, id: int): + return self.ws.user_cache.get(id) diff --git a/discord/color.py b/discord/color.py index 47d1db8..7734367 100644 --- a/discord/color.py +++ b/discord/color.py @@ -4,6 +4,7 @@ import random import re from typing import Optional, Union + from pydantic import BaseModel from .exceptions import InvalidColor diff --git a/discord/commands/__init__.py b/discord/commands/__init__.py new file mode 100644 index 0000000..634d40e --- /dev/null +++ b/discord/commands/__init__.py @@ -0,0 +1 @@ +from .core import Command, check diff --git a/discord/commands/context.py b/discord/commands/context.py new file mode 100644 index 0000000..6406d41 --- /dev/null +++ b/discord/commands/context.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import Optional, Any + +from pydantic import BaseModel + +from discord.abc.messageable import Messageable + +from discord.commands import Command + + +class Context(BaseModel, Messageable): + """The context model which will be used in commands""" + class Config: + arbitrary_types_allowed = True + + client: Any + message: Any # To avoid circular import errors + command: Optional[Command] = None + + @property + def _client(self): + return self.client # Make an alias for client because Messageable uses it + + @property + def guild(self): + """The guild the command was used in""" + return self.message.guild + + @property + def channel(self): + """The channel the command was used in""" + return self.message.channel + + @property + def author(self): + return self.message.author + + def _get_channel(self): + return self.channel + + +Context.update_forward_refs() diff --git a/discord/commands/core.py b/discord/commands/core.py new file mode 100644 index 0000000..5d70987 --- /dev/null +++ b/discord/commands/core.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +import asyncio +import re +from typing import Optional, Any, Callable, Iterable, TYPE_CHECKING + +from discord.message import Message +from .errors import CheckFailure + +if TYPE_CHECKING: + from .context import Context + + +NOT_ASYNC_FUNCTION_MESSAGE = ( + "{0} must be coroutine.\nMaybe you forgot to add the 'async' keyword?." +) + + +class Command: + def __init__( + self, + callback: Callable[[Any], Any], + name: str = None, + *, + qualified_name: str = None, + aliases: Iterable[str] = None, + regex_command: bool = False, + regex_match_func=re.match, + regex_flags=0, + **kwargs + ): + if not asyncio.iscoroutinefunction(callback): + raise ValueError(NOT_ASYNC_FUNCTION_MESSAGE.format("Command callback")) + + self._callback = callback + self.name = name or self._callback.__name__ + self.is_regex_command = regex_command + self.regex_match_func = regex_match_func + self.regex_flags = regex_flags + self.qualified_name = qualified_name + + if self.qualified_name is None and regex_command is True: + raise TypeError("You need to supply the qualified_name for regex commands") + + elif self.qualified_name is None and regex_command is False: + self.qualified_name = name + + if aliases is None: + self.aliases = [] + else: + self.aliases = aliases + + self.checks = [] + self.description = kwargs.get("description") or self._callback.__doc__ + self.on_error: Optional[Callable[[Any], Any]] = None + + @property + def callback(self): + return self._callback + + def add_check(self, function: Callable[[Message], bool]): + self.checks.append(function) + + def error(self, function): + if not asyncio.iscoroutinefunction(function): + raise ValueError(NOT_ASYNC_FUNCTION_MESSAGE.format("Command error handler")) + + self.on_error = function + return function + + async def run_checks(self, context): + for check in self.checks: + if asyncio.iscoroutinefunction(check): + result = await check(context) + else: + result = check(context) + + if result is not True: + raise CheckFailure(self) + + + async def execute(self, context: Context, *args, **kwargs): + """Runs the checks and execute the command""" + await self.run_checks(context) + + try: + await self.callback(context, *args, **kwargs) + except Exception as error: + if self.on_error: + await self.on_error(context, error) + else: + raise error + + async def __call__(self, context: Context, *args, **kwargs): + """Execute the command when the instance is called + NOTE: This method does not validate checks""" + await self.callback(context, *args, **kwargs) + + +def check(function: Callable[[Message], bool]): + + def inner(command: Command): + command.add_check(function) + return command + + return inner diff --git a/discord/commands/errors.py b/discord/commands/errors.py new file mode 100644 index 0000000..7ff00eb --- /dev/null +++ b/discord/commands/errors.py @@ -0,0 +1,15 @@ +class CommandError(Exception): + pass + + +class CheckFailure(CommandError): + def __init__(self, command): + self.command = command + super().__init__(f"Check failed for command {command.name}") + + +class CommandNotFound(CommandError): + def __init__(self, command_name): + self.command_name = command_name + + super().__init__(f"Command with the name {self.command_name} does not exist") diff --git a/discord/commands/help_command.py b/discord/commands/help_command.py new file mode 100644 index 0000000..e3a4bdb --- /dev/null +++ b/discord/commands/help_command.py @@ -0,0 +1,69 @@ +import abc +import sys +import traceback +import typing + +from discord.embeds import Embed + +from .context import Context +from .core import Command +from .errors import CommandNotFound + + +async def help_cmd_callback(context: Context, *args): pass + + +class HelpCommand(abc.ABC, Command): + def __init__(self, + name: str = "help", + *, + description: str = None, + aliases: typing.Iterable[str] = None, + regex_command: bool = False, + regex_flags=0 + ): + help_cmd_callback.__doc__ = description + super().__init__(callback=help_cmd_callback, name=name, aliases=aliases, regex_command=regex_command, + regex_flags=regex_flags) + self.on_error = self.on_help_error + + async def execute(self, context: Context, *args, **kwargs): + await self.run_checks(context) + + client = context.client + + target_command_name = args[0] if args else None + + if target_command_name is None: + await self.send_bot_help(context) + return + + command = client.get_command_named(target_command_name) + + if command is None: + error = CommandNotFound(target_command_name) + await self.on_error(context, error) + return + + await self.send_command_help(context, command) + + @abc.abstractmethod + async def send_bot_help(self, context: Context): + pass + + @abc.abstractmethod + async def send_command_help(self, context: Context, command: Command): + pass + + async def on_help_error(self, context: Context, error): + traceback.print_exception(type(error), error, error.__traceback__, file=sys.stderr) + + +class DefaultHelpCommand(HelpCommand): + async def send_bot_help(self, context: Context): + embed = Embed(title="Help", description=f"All commands: {', '.join([str(cmd.qualified_name) for cmd in context.client.commands.values()])}") + await context.send(embeds=embed) + + async def send_command_help(self, context: Context, command: Command): + embed = Embed(title=command.qualified_name, description=f"Description: {command.description}") + await context.send(embeds=embed) diff --git a/discord/commands/parser.py b/discord/commands/parser.py new file mode 100644 index 0000000..0a40155 --- /dev/null +++ b/discord/commands/parser.py @@ -0,0 +1,93 @@ +import inspect +import re +from typing import List, Dict, Tuple, Optional, Union, Any + +from discord.message import Message +from . import Command +from .core import Command + + +class CommandParser: + def __init__( + self, + command_prefix: str, + commands: Dict[str, Command], + case_sensitive: bool = True, + ): + self.command_prefix = command_prefix + self.case_sensitive = case_sensitive + self.commands = commands + + def remove_prefix(self, content: str): + command_prefix = self.command_prefix + if isinstance(command_prefix, (tuple, list, set)): + for prefix in self.command_prefix: + if content.startswith(prefix): + return content[len(prefix):] + + elif isinstance(command_prefix, str): + return content[len(command_prefix):] + + def get_args(self, command: Command, content: str, prefix=None): + prefix = prefix or command.name + content = content[len(prefix):].strip() + signature = inspect.signature(command.callback) + parameters = signature.parameters.copy() + parameters.pop(tuple(parameters.keys())[0]) # Remove the context parameter + + args = content.split() + positional_arguments = [] + kwargs = {} + extra_kwargs = {} + + for index, (name, parameter) in enumerate(parameters.items()): + if parameter.kind is parameter.KEYWORD_ONLY: + kwargs[name] = (" ".join(args[index:])).strip() + + elif parameter.kind is parameter.VAR_KEYWORD: + extra_kwargs = dict(re.findall(r"(\D+)=(\w+)", content)) + + elif parameter.kind is parameter.VAR_POSITIONAL: + positional_arguments.extend(args) + + else: + positional_arguments.append(args[index].strip()) + + return positional_arguments, kwargs, extra_kwargs + + def parse_message(self, message: Message): + empty = None, [], {}, {} + + if not self.commands: + return empty + + if not message.content.startswith(self.command_prefix): + return empty + + no_prefix = self.remove_prefix(message.content) # The content of the message but without the command_prefix + + if not self.case_sensitive: + no_prefix = no_prefix.lower() + + command = self.commands.get(no_prefix.split()[0]) + if command: + return command, *self.get_args(command, no_prefix) + + for regex_command in filter( + lambda cmd: cmd.is_regex_command, self.commands.values() + ): + regex = regex_command.name + regex_match_func = regex_command.regex_match_func + regex_flags = regex_command.regex_flags + + if match := regex_match_func(regex, no_prefix, regex_flags): + try: + prefix = match.group(1) + except IndexError: + raise ValueError( + "First match group of command regex does not exist" + ) + + return regex_command, *self.get_args(regex_command, no_prefix, prefix=prefix) + + return empty diff --git a/discord/embeds.py b/discord/embeds.py index a6b48de..3338630 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -106,12 +106,12 @@ def _validate_url(cls, url): class Embed(BaseModel): - color: Union[Color, int] = Color.random() + color: Union[Color, int] = None title: Optional[str] _type: Final[EmbedType] = EmbedType.rich author: Optional[EmbedAuthor] url: Optional[str] - description = Optional[str] + description: Optional[str] timestamp: Optional[Arrow] thumbnail: Optional[EmbedMedia] = None image: Optional[EmbedMedia] = None diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py deleted file mode 100644 index bbca144..0000000 --- a/discord/ext/commands/context.py +++ /dev/null @@ -1,6 +0,0 @@ -from typing import Optional, List, Any, Dict -from discord.user.user import User -from ...abc.discordobject import DiscordObject - -class Context: - pass diff --git a/discord/guild.py b/discord/guild.py index f302a06..fadeb1e 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -1,16 +1,18 @@ from __future__ import annotations -from typing import NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple, Optional, List from .abc.discordobject import DiscordObject -from .channels.guildchannel import GuildChannel -from .role import Role -from .types.guildpayload import GuildPayload +from .channels.guildchannel import TextChannel, VoiceChannel from .types.snowflake import Snowflake from .user.member import Member from .user.user import User +if TYPE_CHECKING: + from .role import Role + + class BanEntry(NamedTuple): user: User reason: Optional[str] @@ -27,39 +29,9 @@ class GuildLimit(NamedTuple): class Guild(DiscordObject): - __slots__ = ( - "region" - "owner_id" - "mfa.level" - "name" - "id" - "_members" - "_channels" - "_vanity" - "_banner" - ) - - _roles: set[Role] - me: Member + owner: bool = False owner_id: Snowflake - - def __init__(self, data: GuildPayload): - self._members: dict[Snowflake, Member] = {} - self._channels: dict[Snowflake, GuildChannel] = {} - self._roles = set() - - def _add_channel(self, channel: GuildChannel, /) -> None: - self._channels[channel.id] = channel - - def _delete_channel(self, channel: DiscordObject) -> None: - self._channels.pop(channel.id, None) - - def add_member(self, member: Member) -> None: - self._members[member.id] = member - - def add_roles(self, role: Role) -> None: - for p in self._roles.values: - p.postion += not p.is_default() - # checks if role is @everyone or not - - self._roles[role.id] = role + members: List[dict] + roles: List[dict] + emojis: List[dict] + stickers: List[dict] diff --git a/discord/guild/guild.py b/discord/guild/guild.py deleted file mode 100644 index b66e101..0000000 --- a/discord/guild/guild.py +++ /dev/null @@ -1,132 +0,0 @@ -from typing import ( - NamedTuple, - Optional, - List, - MISSING -) - -from discord.abc.discordobject import DiscordObject -from discord.channels.guildchannel import GuildChannel -from discord.member.member import Member -from discord.role.role import Role -from discord.types.guildpayload import GuildPayload -from discord.types.snowflake import Snowflake -from discord.user.user import User - -from ..abc.discordobject import DiscordObject -from ..channels.guildchannel import CategoryChannel, GuildChannel -from ..role import Role -from ..user.member import Member - - -class BanEntry(NamedTuple): - user: User - reason: Optional[str] - - -class GuildLimit(NamedTuple): - filesize: int - emoji: int - channels: int - roles: int - categories: int - bitrate: int - stickers: int - - -class Guild(DiscordObject): - __slots__ = ( - "region" - "owner_id" - "mfa.level" - "name" - "id" - "_members" - "_channels" - "_vanity" - "_banner" - ) - - _roles: set[Role] - me: Member - - - def __init__(self, data: GuildPayload): - self._members: dict[Snowflake, Member] = {} - self._channels: dict[Snowflake, GuildChannel] = {} - self._roles = set() - - def _add_channel(self, channel: GuildChannel, /) -> None: - self._channels[channel.id] = channel - - def _delete_channel(self, channel: DiscordObject) -> None: - self._channels.pop(channel.id, None) - - def add_member(self, member: Member) -> None: - self._members[member.id] = member - - def add_roles(self, role: Role) -> None: - for p in self._roles.values: - p.postion += not p.is_default() - # checks if role is @everyone or not - - self._roles[role.id] = role - def remove_roles(self, role:Role) -> None: - - def remove_roles(self, role: Role) -> None: - role = self._roles.pop(role.id) - - for p in self._roles.values: - p.position -= p.position > role.position - - return role - - @property - async def channels(self) -> List[GuildChannel]: - return list(self._channels.values()) - - @property - async def roles(self) -> List[Role]: - return sorted(self._roles.values()) - - @property - async def owner(self) -> Optional[Member]: - return self.get_member(self.owner.id) - - @property - async def members(self) -> List[Member]: - return list(self._members.values()) - - def get_member(self, member_id: int) -> Optional[Member]: - return self._members.get(member_id) - - def get_channel(self, channel_id: int) -> Optional[GuildChannel]: - return self._channels(channel_id) - - async def create_channel( - self, - *, - name: str, - type: str = None, - reason: Optional[str] = None, - category: Optional[CategoryChannel] = None, - position: int = None, - slowmode_delay: int = None, - ): - return - - async def delete_channel( - self, *, channel: GuildChannel, reason: Optional[str] = None - ): - pass - - async def edit_channel( - self, - *, - name: Optional[str] = None, - position: Optional[int] = None, - slowmode_delay: Optional[int] = None, - category: Optional[CategoryChannel] = None, - ): - pass - \ No newline at end of file diff --git a/discord/interactions/__init__.py b/discord/interactions/__init__.py index b4b6730..10f8bc6 100644 --- a/discord/interactions/__init__.py +++ b/discord/interactions/__init__.py @@ -1 +1 @@ -from .components import Component, View +from .components import Component, Button, View diff --git a/discord/interactions/components.py b/discord/interactions/components.py index 2cf55eb..da06402 100644 --- a/discord/interactions/components.py +++ b/discord/interactions/components.py @@ -1,5 +1,7 @@ import os -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Callable, Any + +from discord.utils import maybe_await OptInt = Optional[int] OptStr = Optional[str] @@ -7,18 +9,18 @@ class Component: def __init__( - self, - type: int, - disabled: bool = None, - style: OptInt = None, - label: OptStr = None, - emoji: OptStr = None, - url: OptStr = None, - options: list = None, - placeholder: OptStr = None, - min_values: OptInt = None, - max_values: OptInt = None, - custom_id: OptStr = None, + self, + type: int, + disabled: bool = None, + style: OptInt = None, + label: OptStr = None, + emoji: OptStr = None, + url: OptStr = None, + options: list = None, + placeholder: OptStr = None, + min_values: OptInt = None, + max_values: OptInt = None, + custom_id: OptStr = None, ): self.type: int = type self.disabled: bool = disabled @@ -35,14 +37,52 @@ def __init__( if self.custom_id is None and self.url is None: self.custom_id = os.urandom(16).hex() + async def run_callback(self, *args, **kwargs): + raise NotImplementedError + def _to_dict(self): - return {k: v for k, v in self.__dict__.items() if v is not None} + exclude = ("_callback", "callback", "run_callback") + return {k: v for k, v in self.__dict__.items() if v is not None and k not in exclude} + + +class Button(Component): + def __init__(self, style: int, *, label: str = None, emoji: dict = None, url: str = None, disabled: bool = None, + callback: Callable[[Any, Any], Any] = None): + super().__init__(type=2, style=style, label=label, emoji=emoji, url=url, disabled=disabled) + self._callback = callback + + async def run_callback(self, *args, **kwargs): + """Runs the callback function with the given arguments""" + if hasattr(self, "callback") and self._callback: + raise ValueError("Callback is specified twice") + + if hasattr(self, "callback"): + return await maybe_await(self.callback, *args, **kwargs) + + elif self._callback: + return await maybe_await(self._callback, *args, **kwargs) + + else: + raise ValueError("Callback not specified") + + def __init_subclass__(cls, **kwargs): + # Check if subclasses implement the callback method + super().__init_subclass__(**kwargs) + + if not hasattr(cls, "callback"): + raise TypeError("Subclasses of Button must implement callback method") class View: def __init__(self, *components: Component): self.components: Tuple[Component] = components + async def run_component_callback(self, custom_id: str, *args, **kwargs): + """Runs the callback of the component with the given custom_id""" + for component in self.components: + if component.custom_id == custom_id: + await component.run_callback(*args, **kwargs) + def _to_dict(self): return { "type": 1, diff --git a/discord/interactions/interaction.py b/discord/interactions/interaction.py new file mode 100644 index 0000000..e69de29 diff --git a/discord/message.py b/discord/message.py index 92a7f06..43f49d2 100644 --- a/discord/message.py +++ b/discord/message.py @@ -1,5 +1,33 @@ from __future__ import annotations +from typing import Optional -class Message: - pass +from .abc.discordobject import DiscordObject +from .types.snowflake import Snowflake +from .user.user import User + + +class Message(DiscordObject): + channel_id: Snowflake + guild_id: Optional[Snowflake] = None + content: Optional[str] + author: Optional[User] = None + + def __init__(self, client, **data): + if data.get("author"): + data["author"] = User(client, **data["author"]) + super().__init__(client, **data) + + def __str__(self): + return self.content + + def __repr__(self): + return f"" + + @property + def guild(self): + return self._client.ws.guild_cache.get(self.guild_id) + + @property + def channel(self): + return self._client.ws.channel_cache.get(self.channel_id) diff --git a/discord/role.py b/discord/role.py index bb914b9..88ce682 100644 --- a/discord/role.py +++ b/discord/role.py @@ -1,20 +1,22 @@ from __future__ import annotations -from typing import Optional, TypeVar +from typing import TYPE_CHECKING, Optional, Any from .abc.discordobject import DiscordObject +from .asset import Asset from .color import Color -from .guild import Guild + +if TYPE_CHECKING: + from .guild import Guild __all__ = ("RoleTags", "Role") -from .cache import RoleCache +from .cache import LFUCache from .types.rolepayload import RolePayload, RoleTagsPayload from .types.snowflake import Snowflake class RoleTags: - __slots__ = ("_bot_id", "_integration_id", "_premium_subscriber") _bot_id: Snowflake _integration_id: Snowflake _premium_subscriber: bool @@ -38,50 +40,24 @@ def __repr__(self): class Role(DiscordObject): - __slots__ = ( - "_guild", - "_cache", - "_id", - "_name", - "_color", - "_hoist", - "_position", - "_permissions", - "_managed", - "_mentionable", - "_tags", - ) - - _guild: Guild - _cache: RoleCache - _id: Snowflake - _name: str - _color: Color - _hoist: bool - _position: int - _permissions: str - _managed: bool - _mentionable: bool - _tags: Optional[RoleTags] - - def __init__(self, guild: Guild, cache: RoleCache, payload: RolePayload): - self._guild = guild - self._cache = cache - self._id = payload["id"] - self._name = payload["name"] - self._color = payload["color"] - self._hoist = payload["hoist"] - self._position = payload["position"] - self._permissions = payload["permissions"] - self._managed = payload["managed"] - self._mentionable = payload["mentionable"] - self._tags = RoleTags(payload["tags"]) + id: Snowflake + guild: Any + name: str + color: Color + hoist: bool + icon: Optional[str] = None + unicode_emoji: Optional[str] = None + position: int + permissions: str + managed: bool + mentionable: bool + tags: Optional[RoleTags] = None def __str__(self): - return self._name + return self.name def __repr__(self): - return f"Role {self._name} with id {self._id}" + return f"Role {self.name} with id {self.id}" def __eq__(self, other): return ( @@ -132,43 +108,3 @@ def is_assignable(self): and not self.managed and (me.top_role > self or me.id == self.guild.owner_id) ) - - @property - def id(self) -> Snowflake: - return self._id - - @property - def guild(self) -> Guild: - return self._guild - - @property - def name(self) -> str: - return self._name - - @property - def color(self) -> Color: - return self._color - - @property - def hoist(self) -> bool: - return self._hoist - - @property - def position(self) -> int: - return self._position - - @property - def permissions(self) -> str: - return self._permissions - - @property - def managed(self) -> bool: - return self._managed - - @property - def mentionable(self) -> bool: - return self._mentionable - - @property - def tags(self) -> RoleTags: - return self._tags diff --git a/discord/types/avatar.py b/discord/types/avatar.py index 0bd9c85..431f4b2 100644 --- a/discord/types/avatar.py +++ b/discord/types/avatar.py @@ -3,12 +3,12 @@ from os.path import splitext from typing import Optional -from enums.validavatarformat import ValidAvatarFormat, ValidStaticAvatarFormat -from image import Image from yarl import URL from ..cache import LFUCache from ..exceptions import DiscordInvalidArgument +from .enums.validavatarformat import ValidAvatarFormat, ValidStaticAvatarFormat +from .image import Image class Avatar(Image): diff --git a/discord/types/enums/nsfwlevel.py b/discord/types/enums/nsfwlevel.py index a6a87b7..c24a108 100644 --- a/discord/types/enums/nsfwlevel.py +++ b/discord/types/enums/nsfwlevel.py @@ -1,7 +1,7 @@ from enum import IntEnum -class NSFWLevel(IntEnum, comparable=True): +class NSFWLevel(IntEnum): default = 0 explicit = 1 safe = 2 diff --git a/discord/types/enums/verificationlevel.py b/discord/types/enums/verificationlevel.py index e3b048f..5bdcdb5 100644 --- a/discord/types/enums/verificationlevel.py +++ b/discord/types/enums/verificationlevel.py @@ -1,7 +1,7 @@ from enum import IntEnum -class VerificationLevel(IntEnum, comparable=True): +class VerificationLevel(IntEnum): none = 0 low = 1 medium = 2 diff --git a/discord/types/guildpayload.py b/discord/types/guildpayload.py index fa07d5f..c6e09d0 100644 --- a/discord/types/guildpayload.py +++ b/discord/types/guildpayload.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import Optional +from typing import Optional, Union from pydantic import BaseModel -from ..channels.guildchannel import GuildChannel +from ..channels.guildchannel import TextChannel, ThreadChannel, VoiceChannel from ..role import Role from ..user.member import Member from ..user.user import User @@ -35,9 +35,9 @@ class GuildPayload(BaseModel): member_count: int voice_states: list[GuildVoiceState] members: list[Member] - channels: list[GuildChannel] + channels: list[Union[TextChannel, VoiceChannel]] presences: list[PartialPresenceUpdate] - threads: list[Thread] + threads: list[ThreadChannel] max_presences: Optional[int] max_members: int premium_subscription_count: int diff --git a/discord/types/image.py b/discord/types/image.py index d585d7a..f1cd77d 100644 --- a/discord/types/image.py +++ b/discord/types/image.py @@ -4,14 +4,16 @@ from os import PathLike from typing import ClassVar, Optional, Union -from enums.imagetype import ImageType from pydantic import BaseModel from ..cache import LFUCache from ..exceptions import DiscordException, DiscordNotFound +from .enums.imagetype import ImageType class Image(BaseModel): + class Config: + arbitrary_types_allowed = True url: str format: ImageType diff --git a/discord/types/rolepayload.py b/discord/types/rolepayload.py index 8f6bacc..ec664b5 100644 --- a/discord/types/rolepayload.py +++ b/discord/types/rolepayload.py @@ -2,9 +2,8 @@ from typing import TypedDict -from snowflake import Snowflake - from ..color import Color +from .snowflake import Snowflake class RoleTagsPayload(TypedDict, total=False): diff --git a/discord/types/snowflake.py b/discord/types/snowflake.py index b09e264..bcacf45 100644 --- a/discord/types/snowflake.py +++ b/discord/types/snowflake.py @@ -1,44 +1,43 @@ -from __future__ import annotations - -from pydantic import BaseModel +import arrow __all__ = ("Snowflake",) -class Snowflake(BaseModel): - id: int - - def __eq__(self, other) -> bool: - if isinstance(other, (int, str)): - return self.id == other - return isinstance(other, Snowflake) and self.id == other.id - - def __ne__(self, other: object) -> bool: - return not self.__eq__(other) - - def __lt__(self, other: object): - if isinstance(other, int): - return self.id < other - if isinstance(other, str) and other.isdigit(): - return self.id < int(other) - if isinstance(other, Snowflake): - return self.id < other.id - raise NotImplementedError - - def __le__(self, other: object): - return self.__eq__(other) or self.__lt__(other) - - def __gt__(self, other: object): - return not self.__le__(other) - - def __ge__(self, other: object): - return not self.__lt__(other) - - def __str__(self) -> str: - return str(self.id) - - def __int__(self) -> int: - return int(self.id) - - def __repr__(self) -> str: - return f"Snowflake with id {self.id}" +class Snowflake(int): + @property + def timestamp(self): + return ((self >> 22) + 1420070400000) / 1000 + + @property + def worker_id(self): + return (self & 0x3E0000) >> 17 + + @property + def process_id(self): + return (self & 0x1F000) >> 12 + + @property + def increment(self): + return self & 0xFFF + + @property + def created_at(self): + return arrow.get(self.timestamp) + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value): + if isinstance(value, int): + return cls(value) + elif isinstance(value, str): + if value.isdigit(): + return cls(value) + else: + raise ValueError("Invalid Snowflake") + elif isinstance(value, Snowflake): + return value + else: + return ValueError("Invalid Snowflake") diff --git a/discord/types/userpayload.py b/discord/types/userpayload.py index 5f4fe01..e9e1f53 100644 --- a/discord/types/userpayload.py +++ b/discord/types/userpayload.py @@ -2,13 +2,13 @@ from typing import Optional, TypedDict -from enums.locale import Locale -from enums.userflags import UserFlags from pydantic.main import BaseModel -from snowflake import Snowflake from ..types.avatar import Avatar from ..types.banner import Banner +from .enums.locale import Locale +from .enums.userflags import UserFlags +from .snowflake import Snowflake class UserPayload(BaseModel): diff --git a/discord/user/baseuser.py b/discord/user/baseuser.py index 4155eb9..ddfd568 100644 --- a/discord/user/baseuser.py +++ b/discord/user/baseuser.py @@ -1,113 +1,36 @@ from __future__ import annotations -from abc.abstractuser import AbstractUser -from datetime import datetime -from types.avatar import Avatar -from types.banner import Banner -from types.enums.defaultavatar import DefaultAvatar -from types.enums.userflags import UserFlags -from types.userpayload import UserPayload from typing import Optional -from ..cache import GuildCache, UserCache +from ..abc.discordobject import DiscordObject from ..color import Color -from ..message import Message - - -class BaseUser(AbstractUser): - banner: Banner - system: bool - display_avatar: Avatar - display_name: str - public_flags: UserFlags - _cache: UserCache - guilds: GuildCache - - def __init__(self, cache: UserCache, guilds: GuildCache, payload: UserPayload): - self._cache = cache - self.guilds = guilds - self._id = payload.id - self._created_at = datetime.utcnow() - self.avatar = payload.avatar - self.username = payload.username - self.discriminator = payload.discriminator - self.banner = payload.banner - self.public_flags = payload.flags - self.system = payload.system or False - self.bot = payload.bot or False - - @classmethod - def _from_user(cls, user: BaseUser) -> BaseUser: - self = cls.__new__(cls) - self.avatar = user.avatar or user.default_avatar - self.banner = user.banner - self._cache = user._cache - self._created_at = user.created_at - self.discriminator = user.discriminator - self.display_name = user.display_name - self._id = user.id - self.public_flags = user.public_flags - self.system = user.system - self.username = user.username - return self - - async def create_dm(self): - pass - - async def fetch_message(self): - pass - - async def send( - self, - content: str = None, - *, - tts=None, - embeds: list[Message] = None, - files=None, - stickers=None, - delete_after=None, - nonce=None, - allowed_mentions: bool = None, - reference=None, - mention_author: bool = None, - view=None, - components=None - ): - pass - - async def edit(self, *, username: str = None, avatar: bytes = None): - pass - - async def typing(self): - pass - - @property - def default_avatar(self): - return self.avatar._from_default_avatar( - self._cache, int(self.discriminator) % len(DefaultAvatar) - ) +from ..types.banner import Banner +from ..types.enums.premiumtype import PremiumType +from ..types.enums.userflags import UserFlags + + +class BaseUser(DiscordObject): + username: str + discriminator: str + # avatar: Optional[Asset] TODO: Fix 'value is not a valid dict (type=type_error.dict)' error + bot: Optional[bool] = False + system: Optional[bool] = False + mfa_enabled: Optional[bool] + banner: Optional[Banner] + accent_color: Optional[Color] + flags: Optional[UserFlags] + premium_type: Optional[PremiumType] + public_flags: Optional[UserFlags] @property def color(self): return Color.default() + colour = color + @property - def colour(self): - return self.color + def mention(self): + return f"<@!{self.id}>" - def to_json(self): - return { - "username": self.username, - "discriminator": self.discriminator, - "tag": self.tag, - "id": self.id, - "created_at": self.created_at, - "avatar": self.avatar, - "default_avatar": self.default_avatar, - "display_avatar": self.display_avatar, - "bot": self.bot, - "system": self.system, - "public_flags": self.public_flags, - "display_name": self.display_name, - "banner": self.banner, - } + def __str__(self): + return f"{self.username}#{self.discriminator}" diff --git a/discord/user/clientuser.py b/discord/user/clientuser.py index 09cdc83..47ea98a 100644 --- a/discord/user/clientuser.py +++ b/discord/user/clientuser.py @@ -2,7 +2,6 @@ from typing import Optional -from ..cache import UserCache from ..types.enums.locale import Locale from ..types.enums.userflags import UserFlags from ..types.userpayload import UserPayload @@ -10,56 +9,8 @@ class ClientUser(BaseUser): - __slots__ = ( - "_id", - "_created_at", - "_avatar", - "_bot", - "_username", - "_discriminator", - "_mention", - "_cache", - "_banner", - "_default_avatar", - "_display_name", - "_public_flags", - "_cache", - "_system", - "_verified", - "_locale", - "_two_factor_enabled", - "_flags", - ) - _verified: bool - _locale: Optional[Locale] - _two_factor_enabled: bool - _flags: UserFlags - - def __init__(self, cache: UserCache, payload: UserPayload): - self._verified = payload["verified"] - self._two_factor_enabled = payload["two_factor_enabled"] - self._locale = payload["locale"] - self._flags = payload["flags"] - super().__init__(cache, payload) - async def edit(self, *, username: str = None, avatar: bytes = None): pass def __repr__(self) -> str: return f"{self.username}#{self.discriminator} ({self.id})" - - @property - def verified(self): - return self._verified - - @property - def locale(self): - return self._locale - - @property - def two_factor_enabled(self): - return self._two_factor_enabled - - @property - def flags(self): - return self._flags diff --git a/discord/user/member.py b/discord/user/member.py index ec5ef78..1c89fd8 100644 --- a/discord/user/member.py +++ b/discord/user/member.py @@ -1,24 +1,44 @@ from __future__ import annotations -from user import User +from datetime import datetime +from typing import List, Optional, Any -from ..guild import Guild -from ..role import Role +from pydantic import validator +import arrow + +from .user import User +from ..asset import Asset +from ..types.snowflake import Snowflake + + +def validate_dt(val): + if val is None: + return val + + return arrow.get(str(val)) class Member(User): - _top_role: Role - _roles: set[Role] - _guild: Guild + nick: Optional[str] = None + guild_avatar: Optional[Asset] = None + roles: List[Snowflake] + guild: Any + joined_at: datetime + premium_since: Optional[datetime] = None + deaf: Optional[bool] + mute: Optional[bool] + pending: Optional[bool] + permissions: Optional[str] = None + communication_disabled_until: Optional[datetime] = None - @property - def top_role(self): - return self._top_role + @validator("joined_at") + def validate_joined_at(cls, val): + return validate_dt(val) - @property - def roles(self): - return self._roles + @validator("premium_since") + def validate_premium_since(cls, val): + return validate_dt(val) - @property - def guild(self): - return self._guild + @validator("communication_disabled_until") + def validate_communication_disabled_until(cls, val): + return validate_dt(val) diff --git a/discord/user/user.py b/discord/user/user.py index f34e585..a87b014 100644 --- a/discord/user/user.py +++ b/discord/user/user.py @@ -10,36 +10,21 @@ class User(BaseUser): - __slots__ = ("_stored",) + _stored: bool = False - _stored: bool - - def __init__(self, cache, payload: UserPayload): - super().__init__(cache, payload) - self._stored = False - - @classmethod - def _from_user(cls, user: BU) -> BU: - self = super()._from_user(user) - self._stored = False - return self - - def __repr__(self) -> str: - return f"{self.tag} ({self.id})" + @property + def mutual_guilds(self): + return - def __str__(self) -> str: - return self.tag + async def create_dm(self): + pass - def __del__(self): - if self._stored: - del self._cache[self.id] - else: - raise KeyError + async def fetch_message(self): + pass - @property - def dm_channel(self): - return self._cache.get_user_dms(self) + async def edit(self, *, username: str = None, avatar: bytes = None): + pass - @property - def mutual_guilds(self): + async def typing(self): pass + diff --git a/discord/utils/__init__.py b/discord/utils/__init__.py index e69de29..eafdc68 100644 --- a/discord/utils/__init__.py +++ b/discord/utils/__init__.py @@ -0,0 +1,8 @@ +import asyncio + + +async def maybe_await(function, *args, **kwargs): + if asyncio.iscoroutinefunction(function): + return await function(*args, **kwargs) + else: + return function(*args, **kwargs) diff --git a/logo.png b/logo.png new file mode 100644 index 0000000..39ae495 Binary files /dev/null and b/logo.png differ diff --git a/setup.py b/setup.py index d478fdb..6c15733 100644 --- a/setup.py +++ b/setup.py @@ -1,11 +1,11 @@ -from setuptools import setup +from setuptools import find_packages, setup with open("README.md", "r") as file: long_des = file.read() setup( name="disthon", - packages=[], + packages=find_packages(), install_requires=["aiohttp", "yarl", "pydantic", "arrow"], description="An API wrapper for the discord API written in python", version="0.0.1", @@ -16,11 +16,11 @@ author_email="arshia.aghaei@gmail.com", url="https://github.com/AA1999/Disthon", keywords=["API", "discord"], - classifiers = [ + classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", "Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", - ] + ], )