Skip to content

Callbacks


Custom callbacks for tracking experiments with aim

AimPlotTracker

Bases: AimCallback

Abstract class for tracking augmentation examples and metric plots on given data via aim. Inherits from AimCallback

Parameters:

Name Type Description Default
datagen tf.data.Dataset

validation data to predict on

required
labels np.ndarray

One hot encoded true labels

required
class_names List[str]

names of classes to display

required
title str

title of generated plot

''
frequency int

frequency (in epochs) of logging

5
Source code in conftrainer/callbacks/aim_callbacks.py
class AimPlotTracker(AimCallback):
    """
    Abstract class for tracking augmentation examples and metric plots on given data via aim.
    Inherits from AimCallback

    Parameters
    ----------
    datagen : Union[ImageDatagen, MultiOutputDatagen]
        validation data generator to predict on; must expose ``dataset``,
        ``probs_to_labels``, ``labels``, ``classes`` and ``name``
    title : str
        title of generated plot
    frequency : int
        frequency (in epochs) of logging
    """

    def __init__(
        self,
        datagen: Union[ImageDatagen, MultiOutputDatagen],
        title: str = "",
        frequency: int = 5,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.datagen = datagen
        self.dataset = datagen.dataset
        self.postprocess_labels = datagen.probs_to_labels
        # ground-truth labels postprocessed the same way predictions will be
        self.labels = self.postprocess_labels(datagen.labels)
        self.class_names = datagen.classes
        self.title = title
        self.frequency = frequency

    def on_train_begin(self, *args) -> None:
        """Override on_train_begin to log an example of an augmented batch to aim"""
        super().on_train_begin(*args)
        try:
            augmentor = self.model.get_layer('augmentor')
        except (AttributeError, ValueError) as exc:
            logging.warning(f"Failed to get augmentor layer from the network: {str(exc)}")
            augmentor = None
        if augmentor is None:
            return
        batch, _ = next(self.dataset.take(1).as_numpy_iterator())
        images = augmentor(batch[:32], training=True)
        # Ceiling division: the previous (len + 1) // 4 dropped images when
        # len % 4 > 1 (e.g. 5 images -> one row of 4 axes) and produced zero
        # rows for a single image.
        nrows = (len(images) + 3) // 4
        fig, axes = plt.subplots(nrows=nrows, ncols=4, figsize=(16, 4 * nrows))
        fig.tight_layout(pad=12, w_pad=0, h_pad=0)
        # zip stops at the shorter sequence, so spare axes stay blank
        for img, ax in zip(images, axes.ravel()):
            img = img.numpy().astype(np.float32)
            ax.set_axis_off()
            ax.imshow(img / 255)
        aim_image = Image(fig, caption="Example of augmented batch")
        self.experiment.track(aim_image, name="Augmentation examples")
        plt.close("all")

    def create_plot(self, y_pred: np.ndarray) -> Figure:
        """
        Create a matplotlib.figure.Figure with the plot to track.
        Must be implemented by subclasses.
        """

        raise NotImplementedError

    def predict(self) -> np.ndarray:
        """Infer the network on given data and postprocess the predictions"""
        predictions = self.model.predict(self.dataset, verbose=0)
        return self.postprocess_labels(predictions)

    def track_metric_plot(self, epoch: int) -> None:
        """Create and track the plot with metrics"""
        predictions = self.predict()
        fig = self.create_plot(y_pred=predictions)
        caption = f"{self.title} for epoch {epoch}"
        fig.tight_layout()
        aim_image = Image(fig, caption=caption)
        self.experiment.track(aim_image, name=self.title, epoch=epoch, step=None, context={"subset": self.datagen.name})
        plt.close("all")

    def on_epoch_end(self, epoch: int, logs: Optional[Any] = None) -> None:
        """
        Create and track the plot once per ``self.frequency`` epochs
        (epoch 0 included, since 0 % frequency == 0)

        Parameters
        ----------
        epoch : int
            index of current epoch
        logs : dict, optional, default: None
            aim_logs of the model.fit

        """
        super().on_epoch_end(epoch=epoch, logs=logs)
        if not epoch % self.frequency:
            self.track_metric_plot(epoch=epoch)

on_train_begin(*args)

Override on_train_begin to log an example of an augmented batch to aim

Source code in conftrainer/callbacks/aim_callbacks.py
def on_train_begin(self, *args) -> None:
    """Override on_train_begin to log an example of an augmented batch to aim"""
    super().on_train_begin(*args)
    try:
        augmentor = self.model.get_layer('augmentor')
    except (AttributeError, ValueError) as exc:
        logging.warning(f"Failed to get augmentor layer from the network: {str(exc)}")
        augmentor = None
    if augmentor is None:
        return
    batch, _ = next(self.dataset.take(1).as_numpy_iterator())
    images = augmentor(batch[:32], training=True)
    # Ceiling division: the previous (len + 1) // 4 dropped images when
    # len % 4 > 1 (e.g. 5 images -> one row of 4 axes) and produced zero
    # rows for a single image.
    nrows = (len(images) + 3) // 4
    fig, axes = plt.subplots(nrows=nrows, ncols=4, figsize=(16, 4 * nrows))
    fig.tight_layout(pad=12, w_pad=0, h_pad=0)
    # zip stops at the shorter sequence, so spare axes stay blank
    for img, ax in zip(images, axes.ravel()):
        img = img.numpy().astype(np.float32)
        ax.set_axis_off()
        ax.imshow(img / 255)
    aim_image = Image(fig, caption="Example of augmented batch")
    self.experiment.track(aim_image, name="Augmentation examples")
    plt.close("all")

create_plot(y_pred)

Create a matplotlib.figure.Figure with the plot to track

Source code in conftrainer/callbacks/aim_callbacks.py
def create_plot(self, y_pred: np.ndarray) -> Figure:
    """Build the matplotlib.figure.Figure to track; subclasses must implement this."""
    raise NotImplementedError

predict()

Infer the network on given data and postprocess the predictions

Source code in conftrainer/callbacks/aim_callbacks.py
def predict(self) -> np.ndarray:
    """Run inference on the tracked dataset and postprocess the raw outputs"""
    raw_outputs = self.model.predict(self.dataset, verbose=0)
    processed = self.postprocess_labels(raw_outputs)
    return processed

track_metric_plot(epoch)

Create and track the plot with metrics

Source code in conftrainer/callbacks/aim_callbacks.py
def track_metric_plot(self, epoch: int) -> None:
    """Build the metric figure for the current predictions and log it to aim"""
    figure = self.create_plot(y_pred=self.predict())
    figure.tight_layout()
    tracked_image = Image(figure, caption=f"{self.title} for epoch {epoch}")
    self.experiment.track(
        tracked_image,
        name=self.title,
        epoch=epoch,
        step=None,
        context={"subset": self.datagen.name},
    )
    plt.close("all")

on_epoch_end(epoch, logs=None)

Create and track the plot

Parameters:

Name Type Description Default
epoch int

index of current epoch

required
logs dict

aim_logs of the model.fit

None
Source code in conftrainer/callbacks/aim_callbacks.py
def on_epoch_end(self, epoch: int, logs: Optional[Any] = None) -> None:
    """
    Log the metric plot once per ``self.frequency`` epochs

    Parameters
    ----------
    epoch : int
        index of current epoch
    logs : dict, optional, default: None
        aim_logs of the model.fit

    """
    super().on_epoch_end(epoch=epoch, logs=logs)
    if epoch % self.frequency == 0:
        self.track_metric_plot(epoch=epoch)

ConfusionMatrixTracker

Bases: AimPlotTracker

Track Confusion Matrix on given data after each epoch. Usable for multiclass classification.

Source code in conftrainer/callbacks/aim_callbacks.py
class ConfusionMatrixTracker(AimPlotTracker):
    """Track a confusion matrix on given data after each epoch (multiclass classification)."""

    def __init__(self, title: str = "Confusion matrix", **kwargs) -> None:
        super().__init__(title=title, **kwargs)

    def create_plot(self, y_pred: np.ndarray) -> Figure:
        """
        Build a confusion-matrix figure comparing predictions against the
        stored ground-truth labels

        Parameters
        ----------
        y_pred : np.array
            predictions of the network

        Returns
        -------
        out : matplotlib.figure.Figure
            a plot containing confusion matrix
        """
        return plot_confusion_matrix(y_true=self.labels, y_pred=y_pred,
                                     class_names=self.class_names)

create_plot(y_pred)

Calculate a confusion matrix and plot it

Parameters:

Name Type Description Default
y_pred np.array

predictions of the network

required

Returns:

Name Type Description
out matplotlib.figure.Figure

a plot containing confusion matrix

Source code in conftrainer/callbacks/aim_callbacks.py
def create_plot(self, y_pred: np.ndarray) -> Figure:
    """
    Build a confusion-matrix figure comparing predictions against the
    stored ground-truth labels

    Parameters
    ----------
    y_pred : np.array
        predictions of the network

    Returns
    -------
    out : matplotlib.figure.Figure
        a plot containing confusion matrix
    """
    return plot_confusion_matrix(y_true=self.labels, y_pred=y_pred,
                                 class_names=self.class_names)

MultilabelReportTracker

Bases: AimPlotTracker

Track class-wise precision and recall after each epoch. Usable for multilabel classification tasks.

Source code in conftrainer/callbacks/aim_callbacks.py
class MultilabelReportTracker(AimPlotTracker):
    """Track per-class precision and recall after each epoch (multilabel classification)."""

    def __init__(self, title: str = "Classwise F1Score, Precision & Recall", **kwargs) -> None:
        super().__init__(title=title, **kwargs)

    def create_plot(self, y_pred: np.ndarray) -> Figure:
        """
        Build a heatmap of per-class precision and recall for the given
        binary predictions

        Parameters
        ----------
        y_pred : np.array
            array with binary values

        Returns
        -------
        out : matplotlib.figure.Figure
            heatmap of classwise metrics
        """
        return plot_class_metrics(y_true=self.labels,
                                  y_pred=y_pred,
                                  class_names=self.class_names)

create_plot(y_pred)

Compute precision and recall for given predictions & return a heatmap showing those metrics for each class

Parameters:

Name Type Description Default
y_pred np.array

array with binary values

required

Returns:

Name Type Description
out matplotlib.figure.Figure

heatmap of classwise metrics

Source code in conftrainer/callbacks/aim_callbacks.py
def create_plot(self, y_pred: np.ndarray) -> Figure:
    """
    Build a heatmap of per-class precision and recall for the given
    binary predictions

    Parameters
    ----------
    y_pred : np.array
        array with binary values

    Returns
    -------
    out : matplotlib.figure.Figure
        heatmap of classwise metrics
    """
    return plot_class_metrics(y_true=self.labels,
                              y_pred=y_pred,
                              class_names=self.class_names)

MultiBranchPlotTracker

Bases: AimPlotTracker

Track the confusion matrix / metric plots for branches of a multioutput network

Source code in conftrainer/callbacks/aim_callbacks.py
class MultiBranchPlotTracker(AimPlotTracker):
    """Track the confusion matrix / metric plots for branches of a multioutput network"""

    def create_plot(self, y_pred: np.ndarray) -> List[Figure]:
        """
        Build one figure per output branch: a confusion matrix for
        'multiclass' branches, a class-metrics heatmap otherwise.

        NOTE(review): assumes self.labels, y_pred, self.datagen.task_types and
        self.class_names are parallel sequences in the same branch order — confirm
        the datagen guarantees this ordering.
        """
        figures = []
        for true, pred, task_type, cls_names in zip(self.labels, y_pred, self.datagen.task_types, self.class_names):
            if task_type == 'multiclass':
                fig = plot_confusion_matrix(y_true=true, y_pred=pred, class_names=cls_names, normalize='true')
            else:
                fig = plot_class_metrics(y_true=true, y_pred=pred, class_names=cls_names)
            figures.append(fig)
        return figures

    def track_metric_plot(self, epoch: int) -> None:
        """Create and track the plot with metrics"""
        predictions = self.predict()
        figures = self.create_plot(y_pred=predictions)
        # One aim image per task branch, named after the task
        for fig, task_name in zip(figures, self.datagen.task_names):
            caption = f"{task_name} metrics on epoch {epoch}"
            fig.tight_layout()
            aim_image = Image(fig, caption=caption)
            self.experiment.track(aim_image,
                                  name=task_name,
                                  epoch=epoch,
                                  step=None,
                                  context={"subset": self.datagen.name})
        plt.close("all")

track_metric_plot(epoch)

Create and track the plot with metrics

Source code in conftrainer/callbacks/aim_callbacks.py
def track_metric_plot(self, epoch: int) -> None:
    """Build one metric figure per task branch and log each of them to aim"""
    figures = self.create_plot(y_pred=self.predict())
    for task_name, figure in zip(self.datagen.task_names, figures):
        figure.tight_layout()
        tracked_image = Image(figure, caption=f"{task_name} metrics on epoch {epoch}")
        self.experiment.track(
            tracked_image,
            name=task_name,
            epoch=epoch,
            step=None,
            context={"subset": self.datagen.name},
        )
    plt.close("all")

Custom tensorflow callbacks

ClearMemory

Bases: Callback

Collect garbage via gc.collect() and release memory using tf.keras.backend.clear_session after each epoch. Inherits from tf.keras.callbacks.Callback

Source code in conftrainer/callbacks/tf_callbacks.py
class ClearMemory(Callback):
    """
    Free memory between epochs by running gc.collect() and
    tf.keras.backend.clear_session at the end of every epoch.
    Inherits from tf.keras.callbacks.Callback
    """

    @staticmethod
    def on_epoch_end(*args):
        """Garbage-collect and clear the keras session after each epoch"""
        gc.collect()
        clear_session()

on_epoch_end(*args) staticmethod

Release memory after each epoch

Source code in conftrainer/callbacks/tf_callbacks.py
@staticmethod
def on_epoch_end(*args):
    """Garbage-collect and clear the keras session after each epoch"""
    gc.collect()
    clear_session()

UpdateTarget

Bases: Callback

Update the weights of the target branch of siamese networks as described in BYOL paper Inherits from tf.keras.Callback

Updates the weights of the target network so it's a weighted mix of itself and an online network:

.. math:: Weights_{target} = \beta \cdot Weights_{target} + (1 - \beta) \cdot Weights_{online}

Passed model should have following attributes:

  • target_encoder
  • online_encoder
  • target_projection_head
  • online_projection_head

Parameters:

initial_beta: float in (0,1) — used to infer the beta multiplier for mixing the online and target networks. Beta gradually increases with each step, equalling 1 at the end of the training.

num_steps: int — total number of training steps. Used for inferring the step size of beta.

Attributes:

beta: float — current weight of the target network in the mix.

step_size: float — increment of beta after each step.

Source code in conftrainer/callbacks/tf_callbacks.py
class UpdateTarget(Callback):
    """
    Update the weights of the target branch of siamese networks as described in BYOL paper
    Inherits from tf.keras.Callback

    Updates the weights of the target network so it's a weighted mix of itself and an online
    network:

    .. math::
        Weights_{target} = \\beta \\cdot Weights_{target} + (1 - \\beta) \\cdot Weights_{online}

    Passed model should have following attributes:

    - target_encoder
    - online_encoder
    - target_projection_head
    - online_projection_head

    Parameters:
    ----------
    initial_beta: float in (0,1)
                used to infer beta multiplier for mixing the online and target networks.
                Gradually increases with each step, equalling 1 at the end of the training
    num_steps: int
                total number of training steps. Used for inferring the step size of beta
    frequency: int, default: 1
                update the target weights once per this many batches

    Attributes:
    ----------
    beta: float
            current weight of the target network in the mix

    """

    def __init__(self, initial_beta, num_steps, frequency=1, **kwargs):
        super().__init__(**kwargs)
        self.beta = initial_beta
        self.initial_beta = initial_beta
        self.total_steps = num_steps
        self.frequency = frequency

    @staticmethod
    def mix_weights(target: tf.keras.Model, online: tf.keras.Model, target_ratio: float) -> None:
        """
        Averages weights of given target & online networks by given weight and assigns new
        weights to the target network

        Parameters
        ----------
        target : tf.Model
            target network to update the weights of
        online : tf.Model
            network to take weights from and average with the target networks weights
        target_ratio : float
            ratio of the weights of target network in the mix. New weights will be
            calculated by formula:
            target_ratio * target_weights + (1 - target_ratio) * online_weights
        """

        target_weights = target.get_weights()
        online_weights = online.get_weights()
        # get_weights() returns a flat list of numpy arrays, so a plain zip is
        # sufficient; this also avoids jnp.tree_map, which is deprecated/removed
        # in recent JAX releases.
        new_weights = [target_ratio * t + (1 - target_ratio) * o
                       for t, o in zip(target_weights, online_weights)]
        target.set_weights(weights=new_weights)

    def on_batch_end(self, batch, logs=None):
        """
        Update the weights of the target network once per self.frequency batches
        """
        if not batch % self.frequency:
            self.mix_weights(target=self.model.target_encoder,
                             online=self.model.online_encoder,
                             target_ratio=self.beta)

            self.mix_weights(target=self.model.target_projection_head,
                             online=self.model.online_projection_head,
                             target_ratio=self.beta)
            # NOTE(review): `batch` is the within-epoch batch index and resets
            # to 0 every epoch, so the cosine schedule restarts each epoch
            # unless total_steps equals the steps per epoch — confirm intended.
            cosine_decay_value = 0.5 * (1 + tf.math.cos(pi * batch / self.total_steps))
            self.beta = 1 - (1 - self.initial_beta) * cosine_decay_value
            if self.beta > 1:
                message = f"\n Beta parameter {self.beta} for updating the target network got " \
                          f"greater than 1. Make sure you passed right number of total steps \n " \
                          "=================================================================\n" \
                          "The beta is rolled back to 0.996, but there might be performance issues"
                logging.error(message)
                self.beta = 0.996

mix_weights(target, online, target_ratio) staticmethod

Averages weights of given target & online networks by given weight and assigns new weights to the target network

Parameters:

Name Type Description Default
target tf.Model

target network to update the weights of

required
online tf.Model

network to take weights from and average with the target networks weights

required
target_ratio float

ratio of the weights of target network in the mix. New weights will be calculated by formula: target_ratio * target_weights + (1 - target_ratio) * online_weights

required
Source code in conftrainer/callbacks/tf_callbacks.py
@staticmethod
def mix_weights(target: tf.keras.Model, online: tf.keras.Model, target_ratio: float) -> None:
    """
    Averages weights of given target & online networks by given weight and assigns new
    weights to the target network

    Parameters
    ----------
    target : tf.Model
        target network to update the weights of
    online : tf.Model
        network to take weights from and average with the target networks weights
    target_ratio : float
        ratio of the weights of target network in the mix. New weights will be
        calculated by formula:
        target_ratio * target_weights + (1 - target_ratio) * online_weights
    """

    target_weights = target.get_weights()
    online_weights = online.get_weights()
    # get_weights() returns a flat list of numpy arrays, so a plain zip is
    # sufficient; this also avoids jnp.tree_map, which is deprecated/removed
    # in recent JAX releases.
    new_weights = [target_ratio * t + (1 - target_ratio) * o
                   for t, o in zip(target_weights, online_weights)]
    target.set_weights(weights=new_weights)

on_batch_end(batch, logs=None)

Update the weights of the target network once per self.frequency batches

Source code in conftrainer/callbacks/tf_callbacks.py
def on_batch_end(self, batch, logs=None):
    """
    Update the weights of the target network once per self.frequency batches
    """
    if not batch % self.frequency:
        # Move both target branches toward their online counterparts,
        # keeping a self.beta fraction of the current target weights
        self.mix_weights(target=self.model.target_encoder,
                         online=self.model.online_encoder,
                         target_ratio=self.beta)

        self.mix_weights(target=self.model.target_projection_head,
                         online=self.model.online_projection_head,
                         target_ratio=self.beta)
        # NOTE(review): `batch` is the within-epoch batch index and resets to 0
        # every epoch, so this cosine schedule restarts each epoch unless
        # total_steps equals the steps per epoch — confirm intended.
        cosine_decay_value = 0.5 * (1 + tf.math.cos(pi * batch / self.total_steps))
        self.beta = 1 - (1 - self.initial_beta) * cosine_decay_value
        if self.beta > 1:
            message = f"\n Beta parameter {self.beta} for updating the target network got " \
                      f"greater than 1. Make sure you passed right number of total steps \n " \
                      "=================================================================\n" \
                      "The beta is rolled back to 0.996, but there might be performance issues"
            logging.error(message)
            self.beta = 0.996

AimCallbackType

Bases: Enum

Enumerate Aim Callbacks based on task type

Source code in conftrainer/callbacks/utils.py
class AimCallbackType(Enum):
    """Map task-type strings to the names of the matching aim callback classes"""

    ConfusionMatrixTracker = "multiclass"
    MultilabelReportTracker = "multilabel"
    MultiBranchPlotTracker = "multibranch"

    @classmethod
    def import_callback(cls, value: str) -> Callback:
        """Resolve ``value`` to an enum member and return the matching callback class"""
        member = cls(value)
        return getattr(aim_callbacks, member.name)

import_callback(value) classmethod

Import an aim callback class depending on provided enum value

Source code in conftrainer/callbacks/utils.py
@classmethod
def import_callback(cls, value: str) -> Callback:
    """Resolve ``value`` to an enum member and return the matching callback class"""
    member = cls(value)
    return getattr(aim_callbacks, member.name)

get_base_callbacks(save_config, params=None, **aim_tracker_args)

Return a list of base callbacks usable for any model. The list includes callbacks for logging and checkpointing the network

Parameters:

Name Type Description Default
save_config str

path to save the model

required
params dict

parameters to track in aim_callback

None
**aim_tracker_args dict

arguments for tracking heatmaps with aim

{}

Returns:

Name Type Description
out list

callbacks ready to pass to the network

Source code in conftrainer/callbacks/utils.py
def get_base_callbacks(
        save_config: SaveLogConfig,
        params: Optional[dict] = None,
        **aim_tracker_args
) -> List[type(Callback)]:
    """
    Return a list of base callbacks usable for any model. The list includes callbacks for
    logging and checkpointing the network

    Parameters
    ----------
    save_config : SaveLogConfig
        configuration with the model folder, checkpoint path and aim paths
    params : dict, optional, default: None
        parameters to track in aim_callback; ``params["callbacks_config"]`` may
        supply ``monitor_base`` / ``mode_base`` for the checkpoint callback
    **aim_tracker_args : dict
        arguments for tracking heatmaps with aim; ``task_type`` selects the
        aim callback class

    Returns
    -------
    out : list
        callbacks ready to pass to the network
    """
    if params is None:
        params = {}
    model_folder = save_config.model_folder
    log_save_path = os.path.join(model_folder, "logs.csv")
    csv_logger = CSVLogger(log_save_path, append=True)
    # Fall back to keras' own ModelCheckpoint defaults so a missing
    # "callbacks_config" (e.g. params=None) no longer raises a KeyError.
    callbacks_config = params.get("callbacks_config", {})
    checkpoint = ModelCheckpoint(filepath=save_config.best_checkpoint_path,
                                 monitor=callbacks_config.get("monitor_base", "val_loss"),
                                 mode=callbacks_config.get("mode_base", "auto"),
                                 save_best_only=True)
    clear_memory = ClearMemory()
    task_type = aim_tracker_args.pop("task_type", None)
    try:
        callback_type = AimCallbackType.import_callback(task_type)
    except ValueError:
        # Unknown (or absent) task type: use the plain AimCallback and drop
        # tracker-specific arguments it would not accept
        logging.warning(f"Couldn't import aim callback for {task_type}. Falling back to default aim callback")
        callback_type = AimCallback
        aim_tracker_args = {}
    aim_callback = callback_type(repo=save_config.aim_log_path,
                                 experiment=save_config.aim_experiment_name,
                                 **aim_tracker_args)
    # NOTE(review): relies on aim's private `_run` attribute; check for a
    # public hparams API before bumping the aim version.
    aim_callback._run["hparams"] = params

    callbacks = [
        csv_logger,
        aim_callback,
        checkpoint,
        clear_memory,
    ]

    return callbacks