Source code for transformer_discord_notifier.transformers

import logging
import time
from concurrent.futures import TimeoutError
from datetime import timedelta

from tqdm import tqdm
from transformers.trainer_callback import ProgressCallback
from transformers.trainer_callback import TrainerControl
from transformers.trainer_callback import TrainerState
from transformers.training_args import TrainingArguments

# from transformers.trainer import DataLoader

from typing import Optional, Union, Dict, Any, Tuple

from .discord import DiscordClient


# Module-level logger, named after this module's import path.
LOGGER = logging.getLogger(__name__)


# Names exported by ``from <module> import *``; only the callback is public.
__all__ = ["DiscordProgressCallback"]


# ----------------------------------------------------------------------------


class MessageWrapperTQDMWriter:
    """File-like object that forwards written text to a Discord message.

    Provides the ``write``/``flush``/``close`` methods of an output
    stream. Each non-blank ``write`` is rendered through ``msg_fmt``
    and handed to ``client.update_or_send_message``; the id returned by
    that call is remembered and passed along with the next write.

    Attributes
    ----------
    client : DiscordClient
        client used to send, update and delete the message
    msg_fmt : str
        format string with a ``{text}`` placeholder for the written text
    delete_after : bool
        whether :meth:`close` asks the client to delete the message
    msg_id : Optional[int]
        id returned by the last send/update, ``None`` before the first
        write and after a deleting :meth:`close`
    last_msg : Optional[str]
        last non-blank text that was written (without surrounding
        line breaks)
    """

    def __init__(
        self,
        client: DiscordClient,
        msg_fmt: str,
        delete_after: bool = True,
    ):
        self.client = client
        self.msg_fmt = msg_fmt
        self.delete_after = delete_after

        # Nothing has been written or sent yet.
        self.msg_id: Optional[int] = None
        self.last_msg: Optional[str] = None

    def write(self, text: str):
        """Send ``text`` (wrapped in ``msg_fmt``) unless it is blank."""
        line = text.strip("\r\n")
        if line.strip() == "":
            # only line breaks / whitespace: nothing worth sending
            return

        self.last_msg = line
        self.msg_id = self.client.update_or_send_message(
            msg_id=self.msg_id,
            text=self.msg_fmt.format(text=line),
        )

    def flush(self):
        """No-op; every write is forwarded immediately."""

    def close(self):
        """Request deletion of the message if ``delete_after`` is set."""
        if not self.delete_after or self.msg_id is None or not self.client:
            return

        try:
            self.client.delete_later(self.msg_id, delay=10)
        except AttributeError:
            # best effort only, a failing client must not break cleanup
            pass
        self.msg_id = None

    def __del__(self):
        LOGGER.debug("__del__ of MessageWrapperTQDMWriter")
        self.close()


class DiscordProgressCallback(ProgressCallback):
    """An extended :class:`transformers.trainer_callback.ProgressCallback`
    that logs training and evaluation progress and statistics to a
    Discord channel.

    Attributes
    ----------
    client : DiscordClient
        a blocking Discord client
    disabled : bool
        ``True`` if Discord client could not be initialized
        successfully, all callback methods are disabled silently
    """

    def __init__(
        self, token: Optional[str] = None, channel: Optional[Union[str, int]] = None
    ):
        """
        Parameters
        ----------
        token : Optional[str], optional
            Discord bot token, by default None
        channel : Optional[Union[str, int]], optional
            Discord channel name or numeric id, by default None
        """
        super().__init__()

        # stays disabled until :meth:`start` succeeded
        self.disabled = True
        self.client = DiscordClient(token, channel)

        self.last_embed_id: Optional[int] = None
        self.epoch_start_time: Optional[float] = None

        self.writer_train: Any = None
        self.writer_predict: Any = None

    # --------------------------------

    def start(self) -> None:
        """Start the Discord bot.

        On failure a warning is logged and the callback stays disabled,
        no exception is propagated.
        """
        is_ok, err_msg = True, None
        try:
            self.client.init()
            self.disabled = False
        except (RuntimeError, TimeoutError, TypeError) as ex:
            is_ok = False
            err_msg = str(ex)

        if not is_ok or not self.client or not self.client._initialized:
            LOGGER.warning(
                "Failure to initialize Discord client."
                " Silently disable callback handler."
                + (f" Error: {err_msg}" if err_msg else "")
            )
            self.disabled = True

    def end(self) -> None:
        """Stop the Discord bot. Cleans up resources."""
        # ``end`` is also called from ``__del__``; ``client`` may not
        # exist if ``__init__`` failed before assigning it.
        client = getattr(self, "client", None)
        if client:
            client.quit()
        self.disabled = True

    # --------------------------------

    def on_init_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Start the Discord client when the trainer finished its setup."""
        self.start()

    def __del__(self):
        self.end()

    # --------------------------------

    @staticmethod
    def _escape_braces(text: str) -> str:
        """Escape ``{`` and ``}`` so ``text`` survives ``str.format``."""
        return text.replace("{", "{{").replace("}", "}}")

    def _base_msg_fmt(self, args: TrainingArguments) -> str:
        """Build the initial message format (code block, optional run name)."""
        msg_fmt = "```\n{text}\n```"
        if args.run_name:
            # the run name ends up inside a format string, so braces in
            # it have to be escaped
            run_name = self._escape_braces(str(args.run_name))
            msg_fmt = f"Run: **{run_name}**\n{msg_fmt}"
        return msg_fmt

    @classmethod
    def _pin_last_message(
        cls, writer: MessageWrapperTQDMWriter, extra: str = ""
    ) -> None:
        """Freeze the writer's last output (plus ``extra``) into its format.

        The frozen text is inserted in front of the ``{text}``
        placeholder, so later writes appear below it. The format string
        is edited textually instead of via ``str.format``: formatting
        would un-escape earlier literals and insert the new text
        unescaped, which breaks the next ``write`` as soon as any of
        them contains a brace.
        """
        head, placeholder, tail = writer.msg_fmt.rpartition("{text}")
        if not placeholder:
            # no placeholder left, nothing can be inserted
            return

        pinned = ""
        if writer.last_msg is not None:
            pinned = cls._escape_braces(writer.last_msg) + "\n"
        pinned += cls._escape_braces(extra)
        writer.msg_fmt = head + pinned + placeholder + tail

    @staticmethod
    def _length_or_none(obj: Any) -> Optional[int]:
        """Return ``len(obj)`` or ``None`` if ``obj`` has no length."""
        try:
            return len(obj)
        except TypeError:
            return None

    # --------------------------------

    def _new_tqdm_bar(
        self,
        desc: str,
        msg_fmt: str,
        delete_after: bool = True,
        **kwargs,
    ) -> Tuple[tqdm, MessageWrapperTQDMWriter]:
        """Builds an internal ``tqdm`` wrapper for progress tracking.

        The bar gets a :class:`MessageWrapperTQDMWriter` as its output
        ``file``, which forwards everything written to Discord and
        tries to update the existing message to avoid spamming the
        channel.

        Parameters
        ----------
        desc : str
            description shown in front of the progress bar
        msg_fmt : str
            message format with a ``{text}`` placeholder
        delete_after : bool, optional
            whether the message should be deleted when the writer is
            closed, by default True
        **kwargs
            passed through to ``tqdm`` (e.g. ``total``)

        Returns
        -------
        Tuple[tqdm, MessageWrapperTQDMWriter]
            the progress bar and the writer it prints to
        """
        writer = MessageWrapperTQDMWriter(
            self.client,
            msg_fmt=msg_fmt,
            delete_after=delete_after,
        )
        pgbr = tqdm(
            desc=desc,
            ascii=False,
            leave=False,
            position=0,
            file=writer,
            **kwargs,
        )
        return pgbr, writer

    def on_train_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Create the training progress bar (kept after training)."""
        if self.disabled:
            return

        if state.is_local_process_zero:
            self.training_bar, self.writer_train = self._new_tqdm_bar(
                desc="train",
                msg_fmt=self._base_msg_fmt(args),
                delete_after=False,
                total=state.max_steps,
            )
        # NOTE(review): reset on every process, as remembered from
        # upstream ``ProgressCallback`` (placement not visible in the
        # collapsed source, confirm)
        self.current_step = 0

    def on_prediction_step(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        eval_dataloader=None,
        **kwargs,
    ):
        """Advance the prediction progress bar, creating it on first use."""
        if self.disabled:
            return

        if state.is_local_process_zero:
            if self.prediction_bar is None:
                if self.writer_predict is None:
                    msg_fmt = self._base_msg_fmt(args)
                else:
                    # continue below the results pinned by earlier
                    # evaluation runs
                    msg_fmt = self.writer_predict.msg_fmt

                self.prediction_bar, self.writer_predict = self._new_tqdm_bar(
                    desc="predict",
                    msg_fmt=msg_fmt,
                    delete_after=True,
                    # dataloader may be missing or unsized; tqdm accepts
                    # ``total=None``
                    total=self._length_or_none(eval_dataloader),
                )
            self.prediction_bar.update(1)

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Delegate to the parent implementation unless disabled."""
        if self.disabled:
            return

        super().on_step_end(args, state, control, **kwargs)

    # --------------------------------

    def on_epoch_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Remember the start time of the epoch."""
        if self.disabled:
            return

        super().on_epoch_begin(args, state, control, **kwargs)

        if state.is_local_process_zero:
            self.epoch_start_time = time.time()

    def on_epoch_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Pin the final progress line and duration of the epoch."""
        if self.disabled:
            return

        super().on_epoch_end(args, state, control, **kwargs)

        if state.is_local_process_zero:
            if self.writer_train is None or self.epoch_start_time is None:
                # epoch/training start was not seen (e.g. callback was
                # disabled back then), there is nothing to summarize
                return

            time_diff = time.time() - self.epoch_start_time
            self._pin_last_message(
                self.writer_train,
                extra=(
                    f" Epoch {int(state.epoch)}: "
                    f"{timedelta(seconds=round(time_diff))!s}\n"
                ),
            )

    # --------------------------------

    def on_train_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Close and drop the training writer."""
        if self.disabled:
            return

        super().on_train_end(args, state, control, **kwargs)

        if state.is_local_process_zero:
            if self.writer_train is not None:
                self.writer_train.close()
                self.writer_train = None

    def on_evaluate(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Pin the final prediction progress line and close the bar."""
        if self.disabled:
            return

        super().on_evaluate(args, state, control, **kwargs)

        if state.is_local_process_zero:
            # the writer only exists if a prediction step happened
            if self.writer_predict is not None:
                self._pin_last_message(self.writer_predict)

            if self.prediction_bar is not None:
                self.prediction_bar.close()
            self.prediction_bar = None

            if self.writer_predict is not None:
                # the writer object is kept so the next evaluation can
                # reuse its message format
                self.writer_predict.close()

    # --------------------------------

    def _send_log_results(
        self,
        logs: Dict[str, Any],
        state: TrainerState,
        args: TrainingArguments,
        is_train: bool,
    ) -> int:
        """Formats current log metrics as Embed message.

        Given a huggingface transformers Trainer callback parameters,
        we create an :class:`discord.Embed` with the metrics as
        key-values. Send the message and returns the message id.
        """
        results_embed = self.client.build_embed(
            kvs=logs,
            title="Results (training)" if is_train else "Results (evaluation)",
            footer=f"Global step: {state.global_step} | Run: {args.run_name}",
        )

        return self.client.send_message(text="", embed=results_embed)

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        logs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Send the logged metrics (without ``total_flos``) to Discord."""
        if self.disabled:
            return

        if state.is_local_process_zero:
            if logs is None:
                return

            # an existing training bar means we are inside training
            is_train = self.training_bar is not None

            # filter on a copy, the caller's dict is left untouched
            shown_logs = {
                key: value for key, value in logs.items() if key != "total_flos"
            }
            self.last_embed_id = self._send_log_results(
                shown_logs, state, args, is_train
            )
# -------------------------------- # ----------------------------------------------------------------------------