Source code for transformer_discord_notifier

__version__ = "0.0.0"

import asyncio
import concurrent.futures
import logging
import os
import time
import threading
from datetime import timedelta
from functools import partial
from pprint import pformat

import discord
from discord.client import _cancel_tasks
from discord.client import _cleanup_loop

from transformers.trainer_callback import TrainerCallback
from transformers.trainer_callback import TrainerControl
from transformers.trainer_callback import TrainerState
from transformers.training_args import TrainingArguments

from typing import Optional, Union, Dict, Set, Any


LOGGER = logging.getLogger(__name__)


# ----------------------------------------------------------------------------


class MyClient(discord.Client):
    """Discord client whose event handlers only emit debug log records."""

    async def on_ready(self):
        """Write the user this client is logged in as to the debug log."""
        login_note = "Logged on as {0}!".format(self.user)
        LOGGER.debug(login_note)

    async def on_message(self, message):
        """Write author and content of ``message`` to the debug log."""
        received_note = "Message from {0.author}: {0.content}".format(message)
        LOGGER.debug(received_note)
# ----------------------------------------------------------------------------
class DiscordProgressCallback(TrainerCallback):
    """Trainer callback that reports training progress to a Discord channel.

    A :class:`MyClient` is run on a private asyncio event loop inside a
    daemon thread. The (synchronous) trainer hooks hand coroutines to that
    loop with :func:`asyncio.run_coroutine_threadsafe`.

    Parameters
    ----------
    token:
        Discord bot token. If empty, the ``DISCORD_TOKEN`` environment
        variable is read when the client is initialized.
    channel:
        Channel name (``str``) or channel id (``int``). If empty, the
        ``DISCORD_CHANNEL`` environment variable is read when the client is
        initialized.
    """

    def __init__(
        self, token: Optional[str] = None, channel: Optional[Union[str, int]] = None
    ):
        super().__init__()
        self._discord_token: Optional[str] = token
        self._discord_channel: Optional[Union[str, int]] = channel

        # ids of every message sent through this callback
        self.all_message_ids: Set[int] = set()
        # progress type ("train", "pred") -> id of the message to edit
        self.last_message_ids: Dict[str, int] = dict()
        # progress type -> (current step, max step)
        self.progressbars: Dict[str, tuple] = dict()
        # timer name -> timestamp from time.time()
        self.timediffs: Dict[str, float] = dict()

        self._initialized: bool = False
        self.client_thread: Optional[threading.Thread] = None
        self.client: Optional[discord.Client] = None

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    # --------------------------------

    def _load_credentials(self):
        """Fill in token/channel from environment variables if not given.

        Raises
        ------
        RuntimeError
            If no token was passed and ``DISCORD_TOKEN`` is unset or empty.
        """
        LOGGER.debug("Load credentials from env vars ...")

        if not self._discord_token:
            token = os.environ.get("DISCORD_TOKEN", None)
            if not token:
                raise RuntimeError("No DISCORD_TOKEN environment variable set!")
            self._discord_token = token

        if not self._discord_channel:
            channel = os.environ.get("DISCORD_CHANNEL", None)
            if channel:
                # TODO: try to strip leading '#'?
                # a purely numeric value is treated as a channel id,
                # anything else stays a channel name
                try:
                    channel = int(channel)
                except ValueError:
                    pass
            self._discord_channel = channel

    def _find_default_channel(
        self, name: Optional[str] = None, default_name: str = "default"
    ) -> int:
        """Return the id of the text channel to write to.

        Lookup order: ``name`` (if given), the configured channel name, the
        configured channel id, and finally ``default_name``.

        Raises
        ------
        RuntimeError
            If the client is in no guild or no channel could be found.
        """
        LOGGER.debug("Search for text channel to write to in Discord ...")

        guilds = self.client.guilds
        if not guilds:
            raise RuntimeError("No guilds found!")

        def search_for_channel_by_name(
            name: str,
        ) -> Optional[discord.channel.TextChannel]:
            # text channels with matching name where we can send messages
            candidates = [
                channel
                for guild in guilds
                for channel in guild.channels
                if channel.type == discord.ChannelType.text
                and channel.permissions_for(guild.me).send_messages
                and channel.name == name
            ]
            if not candidates:
                return None
            # lowest position/id first (created first)
            return min(candidates, key=lambda c: (c.position, c.id))

        channel = None

        # search by name if provided
        if name:
            channel = search_for_channel_by_name(name)

        # search by configured channel name if possible
        if not channel and isinstance(self._discord_channel, str):
            channel = search_for_channel_by_name(self._discord_channel)

        # search by configured channel id if possible
        if not channel and isinstance(self._discord_channel, int):
            try:
                channel = self.client.get_channel(self._discord_channel)
            except discord.errors.NotFound:
                channel = None

        # fall back to default channel name
        if not channel:
            channel = search_for_channel_by_name(default_name)

        # fail
        if not channel:
            raise RuntimeError("No Text channel found!")

        return channel.id

    # --------------------------------

    def _init_discord(self):
        """Start the Discord client thread and resolve the target channel.

        Does nothing if the handler is already initialized. Blocks until
        the client reports that it is ready.
        """
        if self._initialized:
            LOGGER.debug("Already initialized, do nothing.")
            return

        self._load_credentials()

        # a previous shutdown closes the loop, so a fresh one is required
        if self.loop.is_closed():
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)

        self.client = MyClient(loop=self.loop)

        def client_thread_func():
            LOGGER.info(
                f"Running Discord AsyncIO Loop in Thread: {threading.current_thread()}"
            )
            # NOTE: the loop has to be set in both the main and the
            # background thread, else Lock/Event classes raise errors
            asyncio.set_event_loop(self.loop)
            self.loop.create_task(self.client.start(self._discord_token))
            self.loop.run_forever()

        self.client_thread = threading.Thread(target=client_thread_func, daemon=True)
        self.client_thread.start()

        # block the calling thread until the client is ready
        ready_future = asyncio.run_coroutine_threadsafe(
            self.client.wait_until_ready(), self.loop
        )
        ready_future.result()

        LOGGER.debug("Search for text channel ...")
        try:
            self._discord_channel = self._find_default_channel()
        except RuntimeError:
            LOGGER.warning("Found no default channel!")
        else:
            LOGGER.info(f"Found channel: {self._discord_channel}")

        # without this flag a second call would start another client/thread
        self._initialized = True
        LOGGER.debug("Discord handler initialized.")

    def _quit_discord(self):
        """Close the client, stop the loop thread and close the event loop."""
        if not self.client:
            LOGGER.debug("Discord already shutdown, do nothing.")
            return

        LOGGER.debug("Shutdown Discord handler ...")

        if not self.loop.is_closed():
            if self.loop.is_running():
                # wait for the client to close before touching the loop
                close_future = asyncio.run_coroutine_threadsafe(
                    self.client.close(), self.loop
                )
                try:
                    close_future.result(timeout=10)
                except Exception:
                    # shutdown is best effort, report but keep tearing down
                    LOGGER.exception("Error while closing Discord client!")
                # run_forever() only returns after stop(); a running loop
                # can not be closed
                self.loop.call_soon_threadsafe(self.loop.stop)

            if self.client_thread is not None:
                self.client_thread.join(timeout=3)

            if not self.loop.is_running():
                self.loop.close()

        self.client = None
        self.client_thread = None
        self._initialized = False

    # --------------------------------

    def _send_message(self, text: str) -> int:
        """Send ``text`` to the channel, block until sent, return message id."""

        async def send():
            await self.client.wait_until_ready()
            channel = self.client.get_channel(self._discord_channel)
            return await channel.send(text)

        future = asyncio.run_coroutine_threadsafe(send(), self.loop)
        message = future.result()

        self.all_message_ids.add(message.id)
        return message.id

    def _delete_later(self, msg_id: int, delay: Union[int, float] = 5) -> bool:
        """Schedule deletion of message ``msg_id`` with ``delay``.

        Returns ``False`` if the message could not be fetched (not found),
        else ``True``. The deletion itself is not awaited.
        """
        try:
            channel = self.client.get_channel(self._discord_channel)
            coro = channel.fetch_message(msg_id)
            future = asyncio.run_coroutine_threadsafe(coro, self.loop)
            message = future.result()
        except discord.errors.NotFound:
            return False

        # fire and forget
        asyncio.run_coroutine_threadsafe(message.delete(delay=delay), self.loop)
        return True

    def _send_log_results(
        self, logs: Dict[str, Any], state: TrainerState, args: TrainingArguments
    ) -> int:
        """Send ``logs`` as an embed titled "Results", return the message id.

        Numeric values become inline fields, everything else is
        pretty-printed into a code block.
        """
        results_embed = discord.Embed(title="Results")
        for k, v in logs.items():
            if isinstance(v, (int, float)):
                results_embed.add_field(name=k, value=v, inline=True)
            else:
                results_embed.add_field(
                    name=k, value=f"```json\n{pformat(v)}\n```", inline=False
                )
        results_embed.set_footer(
            text=f"Global step: {state.global_step} | Run: {args.run_name}"
        )

        async def send():
            await self.client.wait_until_ready()
            channel = self.client.get_channel(self._discord_channel)
            return await channel.send(embed=results_embed)

        future = asyncio.run_coroutine_threadsafe(send(), self.loop)
        message = future.result()

        self.all_message_ids.add(message.id)
        return message.id

    def _send_progress_msg(self, ptype: str, time_diff: Optional[float] = None) -> None:
        """Create or update the progress message for progress type ``ptype``.

        Does nothing if no progress is tracked for ``ptype``. An existing
        message is edited in place; if it is gone a new one is sent.
        """
        if ptype not in self.progressbars:
            return

        msg_id = self.last_message_ids.get(ptype, None)
        cur_step, max_step = self.progressbars[ptype]

        # guard against an empty progress range
        percent = round(cur_step / max_step * 100) if max_step else 0
        msg_s = f"Progress [{ptype}]: {percent:d}% | {cur_step} / {max_step}"
        if time_diff is not None:
            msg_s += f" | took: {timedelta(seconds=round(time_diff))!s}"

        if msg_id:
            try:
                channel = self.client.get_channel(self._discord_channel)
                coro = channel.fetch_message(msg_id)
                future = asyncio.run_coroutine_threadsafe(coro, self.loop)
                msg = future.result()
            except discord.errors.NotFound:
                msg_id = None
                msg = None

            if msg:
                # fire and forget
                asyncio.run_coroutine_threadsafe(msg.edit(content=msg_s), self.loop)

        if not msg_id:
            # _send_message already records the id in all_message_ids
            msg_id = self._send_message(msg_s)
            self.last_message_ids[ptype] = msg_id

    # --------------------------------

    def start(self):
        """Initialize the Discord handler (see :meth:`_init_discord`)."""
        self._init_discord()

    def end(self):
        """Shut the Discord handler down (see :meth:`_quit_discord`)."""
        self._quit_discord()

    # --------------------------------

    def on_init_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called at the end of the initialization of the :class:`~transformers.Trainer`.
        """
        self._init_discord()

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called at the beginning of training.
        """
        if state.is_local_process_zero:
            self._send_message(f"Begin training on {args.run_name}")
            self.progressbars["train"] = (0, state.max_steps)
            self.timediffs["train"] = time.time()

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called at the end of training.
        """
        if state.is_local_process_zero:
            self._send_progress_msg("train")
            time_diff = time.time() - self.timediffs["train"]
            self._send_message(
                f"Finish training on {args.run_name}, "
                f"took: {timedelta(seconds=round(time_diff))!s}"
            )
            # drop the train progress bar (if any), training is over
            self.progressbars.pop("train", None)

    def on_epoch_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called at the beginning of an epoch.
        """
        if state.is_local_process_zero:
            self.timediffs["epoch"] = time.time()
            self.timediffs["step"] = time.time()
            # forget the old progress message so a new one is sent per epoch
            self.last_message_ids.pop("train", None)
            self._send_progress_msg("train")
            msg_id = self._send_message(f"Begin epoch: {state.epoch:.1f}")
            self._delete_later(msg_id, delay=5)

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called at the end of an epoch.
        """
        if state.is_local_process_zero:
            time_diff = time.time() - self.timediffs["epoch"]
            self._send_message(
                f"Epoch done, took {timedelta(seconds=round(time_diff))!s}"
            )
            self.timediffs["step"] = time.time()

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called at the beginning of a training step. If using gradient accumulation, one training step might take
        several inputs.
        """

    # seems to be called at the end of all gradient accumulation sub-steps
    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called at the end of a training step. If using gradient accumulation, one training step might take
        several inputs.
        """
        if state.is_local_process_zero:
            _, max_step = self.progressbars["train"]
            self.progressbars["train"] = (state.global_step, max_step)
            time_diff = time.time() - self.timediffs["step"]
            self._send_progress_msg("train", time_diff=time_diff)
            self.timediffs["step"] = time.time()

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called after an evaluation phase.
        """
        if state.is_local_process_zero:
            # evaluation is over, drop its progress bar (if any)
            self.progressbars.pop("pred", None)
            msg_id = self._send_message("After eval ...")
            self._delete_later(msg_id, delay=5)
            # TODO: reset step timer?

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called after a checkpoint save.
        """
        # only report from the local main process, like all other hooks
        if state.is_local_process_zero:
            msg_id = self._send_message(f"Saving in epoch {state.epoch:.1f}")
            self._delete_later(msg_id, delay=5)

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called after logging the last logs.
        """
        # is_world_process_zero
        # is_local_process_zero
        if state.is_local_process_zero:
            logs = kwargs.get("logs")
            if logs is not None:
                self._send_log_results(logs, state, args)

    def on_prediction_step(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """
        Event called after a prediction step.
        """
        if state.is_local_process_zero:
            eval_dataloader = kwargs["eval_dataloader"]
            if "pred" not in self.progressbars:
                self.progressbars["pred"] = (0, len(eval_dataloader))
            cur_step, max_step = self.progressbars["pred"]
            self.progressbars["pred"] = (cur_step + 1, max_step)

            now = time.time()
            # the first prediction step has no previous timestamp
            time_diff = now - self.timediffs.get("pred", now)
            self.timediffs["pred"] = now
            self._send_progress_msg("pred", time_diff=time_diff)

    def __del__(self):
        # NOTE: probably never called as there will still be some references
        # left somewhere (thread, callbacks?, self.*?)
        # getattr: __init__ may have failed before the attribute existed
        if getattr(self, "client", None) is None:
            return
        try:
            msg_id = self._send_message("Ending.")
            self._delete_later(msg_id, delay=3)
        except Exception:
            # best effort only, never raise from a finalizer
            LOGGER.debug("Could not send final message.", exc_info=True)
        finally:
            self._quit_discord()
# ----------------------------------------------------------------------------