import asyncio
import concurrent.futures
import logging
import os
import threading
from pprint import pformat
from typing import Optional, Union, Dict, Set, Any

import discord
LOGGER = logging.getLogger(__name__)
__all__ = ["DiscordClient"]
# ----------------------------------------------------------------------------
class MyClient(discord.Client):
    """Discord client whose event handlers only write to the module logger."""

    async def on_ready(self):
        """Log (info level) the user this client is logged on as."""
        ready_text = "Logged on as {0}!".format(self.user)
        LOGGER.info(ready_text)

    async def on_message(self, message):
        """Log (debug level) author and content of the given message."""
        message_text = "Message from {0.author}: {0.content}".format(message)
        LOGGER.debug(message_text)
# ----------------------------------------------------------------------------
class DiscordClient:
    """A blocking wrapper around the asyncio Discord.py client.

    The Discord client runs on its own asyncio event loop inside a daemon
    background thread (started by :meth:`init`). The public methods schedule
    coroutines on that loop from the calling thread.
    """

    def __init__(
        self, token: Optional[str] = None, channel: Optional[Union[str, int]] = None
    ):
        """Store the configuration; nothing connects until :meth:`init`.

        Parameters
        ----------
        token : Optional[str], optional
            Discord bot token, if missing it is read from the
            ``DISCORD_TOKEN`` environment variable on :meth:`init`
        channel : Optional[Union[str, int]], optional
            channel name or channel id, if missing it is read from the
            ``DISCORD_CHANNEL`` environment variable on :meth:`init`
        """
        self._discord_token: Optional[str] = token
        self._discord_channel: Optional[Union[str, int]] = channel

        self.all_message_ids: Set[int] = set()

        self._initialized: bool = False
        self.client_thread: Optional[threading.Thread] = None
        self.client: Optional[discord.Client] = None
        # task running the Discord client on the background loop
        self._client_runner: Optional[asyncio.Future] = None

        self.loop: Optional[asyncio.AbstractEventLoop] = asyncio.new_event_loop()
        # NOTE: the loop has to be set in both the creating and the background
        # thread, else it will raise errors in Lock/Event classes ...
        asyncio.set_event_loop(self.loop)

    # --------------------------------

    def _load_credentials(self) -> None:
        """Try to load missing Discord configs (token, channel) from
        environment variables.

        Raises
        ------
        RuntimeError
            raised if no token was configured and ``DISCORD_TOKEN`` is not set
        """
        LOGGER.debug("Load credentials from env vars ...")

        if not self._discord_token:
            token = os.environ.get("DISCORD_TOKEN")
            if not token:
                raise RuntimeError("No DISCORD_TOKEN environment variable set!")
            self._discord_token = token

        if not self._discord_channel:
            channel: Optional[Union[str, int]] = os.environ.get("DISCORD_CHANNEL")
            if channel:
                # TODO: try to strip leading '#'?
                try:
                    channel = int(channel)
                except ValueError:
                    # not numeric, keep it as a channel name
                    pass
                self._discord_channel = channel

    def _find_default_channel(
        self, name: Optional[str] = None, default_name: str = "default"
    ) -> int:
        """Try to find a writable text channel.

        Follow the following algorithm:

        1. if ``name`` is being provided, search for this channel first
        2. if not found, search for ``self._discord_channel``, the
           channel that can be configured on instance creation or by
           loading environment variables. Check first for a channel with
           the given name as string, then fall back to an integer
           channel id.
        3. if still not found, search for a channel with a given default
           name, like "default" or "Allgemein". As this seems to depend
           on the language, it might not find one.

        If after all this still no channel has been found, either because no
        channel with the given names/id exists, or because the Discord token
        gives no access to guilds/channels, we throw a ``RuntimeError``.
        We now can't use this callback handler.

        Parameters
        ----------
        name : Optional[str], optional
            channel name to search for first, by default None
        default_name : str, optional
            alternative default Discord channel name, by default "default"

        Returns
        -------
        int
            channel id

        Raises
        ------
        RuntimeError
            raised if no `guild` Discord server found (i.e. Discord bot
            has no permissions / was not yet invited to a Discord server)
        RuntimeError
            raised if channel could not be found
        """
        LOGGER.debug("Search for text channel to write to in Discord ...")

        guilds = self.client.guilds
        if not guilds:
            raise RuntimeError("No guilds found!")

        def search_for_channel_by_name(
            wanted: str,
        ) -> Optional[discord.channel.TextChannel]:
            # all text channels with matching name where we can send messages
            candidates = [
                channel
                for guild in guilds
                for channel in guild.channels
                if channel.type == discord.ChannelType.text
                and channel.name == wanted
                and channel.permissions_for(guild.me).send_messages
            ]
            # lowest position/id first (created first)
            return min(candidates, key=lambda c: (c.position, c.id), default=None)

        channel = None

        # search by name if provided
        if name:
            channel = search_for_channel_by_name(name)

        # search by configured/envvar channel name if possible
        if not channel and isinstance(self._discord_channel, str):
            channel = search_for_channel_by_name(self._discord_channel)

        # search by configured/envvar channel id if possible
        if not channel and isinstance(self._discord_channel, int):
            try:
                channel = self.client.get_channel(self._discord_channel)
            except discord.errors.NotFound:
                channel = None

        # fall back to default channel names
        if not channel:
            channel = search_for_channel_by_name(default_name)

        # fail
        if not channel:
            raise RuntimeError("No Text channel found!")

        return channel.id

    # --------------------------------

    @staticmethod
    def _cancel_pending_tasks(
        loop: asyncio.AbstractEventLoop, timeout: Union[int, float] = 5
    ) -> None:
        """Cancel the unfinished tasks of a stopped loop and let them unwind.

        Has to be called from the thread that owns ``loop`` while the loop
        is not running, as it drives the loop itself.

        Parameters
        ----------
        loop : asyncio.AbstractEventLoop
            stopped (but not yet closed) event loop to clean up
        timeout : Union[int, float], optional
            seconds to wait for the cancelled tasks to finish, by default 5
        """
        # Code adapted from discord.client to work with threads
        task_retriever = getattr(asyncio, "all_tasks", None)
        if task_retriever is None:
            # Python < 3.7
            task_retriever = asyncio.Task.all_tasks

        tasks = {t for t in task_retriever(loop=loop) if not t.done()}
        if not tasks:
            return

        LOGGER.info("Cleaning up after %d tasks.", len(tasks))
        for task in tasks:
            task.cancel()

        # bounded wait, a task ignoring its cancellation must not block shutdown
        loop.run_until_complete(asyncio.wait(tasks, timeout=timeout))
        LOGGER.info("All tasks finished cancelling.")

        for task in tasks:
            if not task.done() or task.cancelled():
                continue
            if task.exception() is not None:
                loop.call_exception_handler(
                    {
                        "message": "Unhandled exception during Client.run shutdown.",
                        "exception": task.exception(),
                        "task": task,
                    }
                )

    @staticmethod
    def _log_async_failure(future) -> None:
        """Done-callback that logs the error of a fire-and-forget operation."""
        if future.cancelled():
            return
        error = future.exception()
        if error is not None:
            LOGGER.warning("Discord background operation failed: %s", error)

    def init(self):
        """Initialize Discord bot for accessing Discord/writing messages.

        It loads the credentials, starts the asyncio Discord bot in a
        separate thread and after connecting searches for our target channel.

        If this method raises after the background thread was started, call
        :meth:`quit` before trying again.

        Raises
        ------
        RuntimeError
            raised on error while initializing the Discord bot, like invalid
            token or channel not found, etc.
        concurrent.futures.TimeoutError
            raised if the client did not become ready within 30 seconds
        """
        if self._initialized:
            LOGGER.debug("Already initialized, do nothing.")
            return

        self._load_credentials()

        if self.client_thread is not None and self.client_thread.is_alive():
            raise RuntimeError(
                "Discord client thread is still running, call quit() first!"
            )

        # quit() drops the loop and a failed login closes it, so create a
        # fresh one to allow initializing again
        if self.loop is None or self.loop.is_closed():
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)

        # Bind loop/client/token to locals: quit() resets the attributes to
        # None while the background thread may still be shutting down.
        loop = self.loop
        client = MyClient(loop=loop)
        token = self._discord_token
        self.client = client

        async def client_runner():
            try:
                await client.start(token)
            except discord.errors.LoginFailure as ex:
                LOGGER.warning("Login error! %s", ex)
                await client.close()
                loop.stop()
            except asyncio.CancelledError as ex:
                LOGGER.exception("cancelled? %s", ex)
            except Exception as ex:
                LOGGER.exception("%s", ex)
                LOGGER.debug("client_runner: Error? %s", ex)
                if loop.is_running():
                    loop.stop()
            finally:
                LOGGER.debug(
                    "client_runner: close() - is_ready: %s, is_closed: %s",
                    client.is_ready(),
                    client.is_closed(),
                )
                if client.is_ready() and not client.is_closed():
                    await client.close()
                LOGGER.debug("client_runner: done.")

        def stop_loop_on_completion(_future):
            if not loop.is_closed():
                LOGGER.debug("Closing loop")
                loop.stop()

        # The loop is not running yet, so the task can be created from here;
        # it starts executing once the background thread runs the loop.
        runner = asyncio.ensure_future(client_runner(), loop=loop)
        runner.add_done_callback(stop_loop_on_completion)
        self._client_runner = runner

        def client_thread_func():
            LOGGER.info(
                "Running Discord AsyncIO Loop in Thread: %s",
                threading.current_thread(),
            )
            asyncio.set_event_loop(loop)
            try:
                loop.run_forever()
            finally:
                LOGGER.debug("Try closing Discord AsyncIO Loop ...")
                # remove first, else it would stop the loop again while the
                # remaining tasks are being cleaned up
                runner.remove_done_callback(stop_loop_on_completion)
                if not loop.is_closed():
                    try:
                        self._cancel_pending_tasks(loop)
                    except Exception as ex:
                        LOGGER.exception(
                            "Error while cancelling remaining tasks: %s", ex
                        )
                    finally:
                        loop.close()
                LOGGER.debug("Discord AsyncIO Loop closed.")

        self.client_thread = threading.Thread(
            target=client_thread_func, name="discord-asyncio", daemon=True
        )
        self.client_thread.start()

        # The loop itself may not be spinning yet (the thread just started),
        # so check the thread instead of loop.is_running() to avoid a race.
        if not self.client_thread.is_alive():
            raise RuntimeError("Loop not running!")

        future = asyncio.run_coroutine_threadsafe(client.wait_until_ready(), loop)
        try:
            future.result(timeout=30)
        except concurrent.futures.CancelledError as ex:
            # the background loop shut down (e.g. login failure) and
            # cancelled our pending wait
            raise RuntimeError("Discord client stopped before being ready!") from ex

        LOGGER.debug("Search for text channel ...")
        try:
            self._discord_channel = self._find_default_channel()
            LOGGER.info("Found channel: %s", self._discord_channel)
        except RuntimeError:
            LOGGER.warning("Found no default channel!")
            raise

        self._initialized = True
        LOGGER.debug("Discord handler initialized.")

    def _quit_client(self):
        """Internal. Try to properly quit the Discord client if necessary,
        and request the asyncio loop to shut down.

        The background thread cancels the remaining tasks and closes the
        loop on its own once the loop has stopped.
        """
        if not self.client:
            LOGGER.debug("No Discord client, do nothing.")
            return

        if not self.loop or self.loop.is_closed():
            LOGGER.debug("Asyncio loop already closed, do nothing.")
            return

        LOGGER.debug("Shutdown Discord handler ...")
        loop = self.loop
        client = self.client
        runner = self._client_runner

        # stop client
        if not client.is_closed():
            try:
                future = asyncio.run_coroutine_threadsafe(client.close(), loop)
                future.result(timeout=10)
            except Exception as ex:
                LOGGER.exception("Error while waiting for client to close ... %s", ex)

        # stop loop: finishing the runner triggers the loop stop from within
        # the loop thread; cancelling an already finished runner does nothing
        if runner is not None:
            try:
                loop.call_soon_threadsafe(runner.cancel)
            except RuntimeError:
                # the loop thread already closed the loop in the meantime
                LOGGER.debug("Asyncio loop already closed.")

        # clear state
        self.client = None
        self.loop = None
        self._client_runner = None
        self._initialized = False

    def quit(self):
        """Shutdown the Discord bot.

        Tries to close the Discord bot safely, stops the asyncio loop,
        waits for the background thread to stop (daemonized, so on program
        exit it will quit anyway)."""
        self._quit_client()

        # asyncio background thread should have finished, but wait
        # if still not joined after timeout, just quit
        # (thread will stop on program end)
        if self.client_thread:
            self.client_thread.join(timeout=3)

        # properly reset all attributes
        self.client = None
        self.client_thread = None
        self.loop = None
        self._client_runner = None
        self._initialized = False

        LOGGER.debug("Discord handler shut down.")

    # --------------------------------

    def send_message(
        self, text: str = "", embed: Optional[discord.Embed] = None
    ) -> Optional[int]:
        """Sends a message to our Discord channel. Returns the message id.

        Blocks until the message has been sent.

        Parameters
        ----------
        text : str, optional
            text message to send, by default ""
        embed : Optional[discord.Embed], optional
            embed object to attach to message, by default None

        Returns
        -------
        Optional[int]
            message id of the sent message,
            ``None`` if nothing was sent (not initialized, or both `text`
            and `embed` are empty)
        """
        # if not initialized, return
        # TODO: or raise error?
        if not self._initialized:
            return None

        # if nothing to send, return
        if not text and not embed:
            return None

        async def _send():
            await self.client.wait_until_ready()
            channel: discord.TextChannel = self.client.get_channel(
                self._discord_channel
            )
            msg: discord.Message = await channel.send(text, embed=embed)
            return msg

        future = asyncio.run_coroutine_threadsafe(_send(), self.loop)
        message = future.result()
        self.all_message_ids.add(message.id)
        return message.id

    def get_message_by_id(self, msg_id: int) -> Optional[discord.Message]:
        """Try to retrieve a Discord message by its id.

        Parameters
        ----------
        msg_id : int
            message id of message sent in Discord channel

        Returns
        -------
        Optional[discord.Message]
            ``None`` if message could not be found by `msg_id`,
            else return the message object
        """
        # if not initialized, return
        if not self._initialized:
            return None

        channel: Optional[discord.TextChannel] = self.client.get_channel(
            self._discord_channel
        )
        if channel is None:
            # channel is gone, so the message can not be found either
            return None

        try:
            future = asyncio.run_coroutine_threadsafe(
                channel.fetch_message(msg_id), self.loop
            )
            message: discord.Message = future.result()
            return message
        except discord.errors.NotFound:
            return None

    def update_or_send_message(
        self, msg_id: Optional[int] = None, **fields
    ) -> Optional[int]:
        """Wrapper for :func:`send_message` to update an existing message,
        identified by `msg_id` or simply send a new message if no prior
        message found.

        Updating does not wait for the edit to finish, a failed edit is
        only logged.

        Parameters
        ----------
        msg_id : Optional[int], optional
            message id of prior message sent in channel, if not provided
            then send a new message.
        text : str, optional
            text message, if set to ``None`` it will remove prior message content
        embed : Optional[discord.Embed], optional
            Discord embed, set to ``None`` to delete existing embed

        Returns
        -------
        Optional[int]
            message id of updated or newly sent message,
            ``None`` if nothing was sent
        """
        # if not initialized, return
        if not self._initialized:
            return None

        message = self.get_message_by_id(msg_id) if msg_id else None

        if message is None:
            return self.send_message(
                text=fields.get("text", None), embed=fields.get("embed", None)
            )

        # filter allowed keywords, Discord names the text "content"
        edits = {k: v for k, v in fields.items() if k in ("text", "embed")}
        if "text" in edits:
            edits["content"] = edits.pop("text")

        future = asyncio.run_coroutine_threadsafe(message.edit(**edits), self.loop)
        future.add_done_callback(self._log_async_failure)
        return msg_id

    def delete_later(self, msg_id: int, delay: Union[int, float] = 5) -> bool:
        """Runs a delayed message deletion function.

        Parameters
        ----------
        msg_id : int
            message id of message sent in Discord channel
        delay : Union[int, float], optional
            delay in seconds for then to delete the message, by default 5

        Returns
        -------
        bool
            ``True`` if message deletion is queued,
            ``False`` if message could not be found in channel
        """
        # if not initialized, return
        if not self._initialized:
            return False

        # NOTE: delete_after is an option of send/edit of channel/message
        message = self.get_message_by_id(msg_id)
        if not message:
            return False

        future = asyncio.run_coroutine_threadsafe(
            message.delete(delay=delay), self.loop
        )
        future.add_done_callback(self._log_async_failure)
        return True

    @staticmethod
    def build_embed(
        kvs: Dict[str, Any],
        title: Optional[str] = None,
        footer: Optional[str] = None,
    ) -> discord.Embed:
        """Builds a rich Embed from key-values.

        Parameters
        ----------
        kvs : Dict[str, Any]
            Key-Value dictionary for embed fields, non ``int``/``float``
            values will be formatted with :func:`pprint.pformat`
        title : Optional[str], optional
            title string, by default None
        footer : Optional[str], optional
            footer string, by default None

        Returns
        -------
        discord.Embed
            embed object to send via :meth:`send_message`
        """
        embed = discord.Embed(title=title)

        for key, value in kvs.items():
            if isinstance(value, (int, float)):
                # plain numbers are short, show them side by side
                embed.add_field(name=key, value=value, inline=True)
            else:
                embed.add_field(
                    name=key, value=f"```json\n{pformat(value)}\n```", inline=False
                )

        if footer:
            embed.set_footer(text=footer)

        return embed
# --------------------------------
# ----------------------------------------------------------------------------