Callbacks
Custom callbacks for tracking experiments with aim
AimPlotTracker
Bases: AimCallback
Abstract class for tracking augmentation examples and metric plots on the given data via aim. Inherits from AimCallback.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `datagen` | `tf.data.Dataset` | validation data to predict on | required |
| `labels` | `np.ndarray` | One hot encoded true labels | required |
| `class_names` | `List[str]` | names of classes to display | required |
| `title` | `str` | title of generated plot | `''` |
| `frequency` | `int` | frequency (in epochs) of logging | `5` |
Source code in conftrainer/callbacks/aim_callbacks.py
on_train_begin(*args)
Override on_train_begin to log a batch of augmentation examples to aim
create_plot(y_pred)
predict()
Infer the network on given data and postprocess the predictions
track_metric_plot(epoch)
Create and track the plot with metrics
on_epoch_end(epoch, logs=None)
Create and track the plot
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `epoch` | `int` | index of current epoch | required |
| `logs` | `dict` | aim_logs of the model.fit | `None` |
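The `frequency` parameter of the class suggests that the plot is only tracked on some epochs. The page does not show how the check is written, so the following is a minimal sketch under the assumption that logging happens on every `frequency`-th epoch, counting from 1. The helper name `should_log` is hypothetical and not part of conftrainer.

```python
def should_log(epoch: int, frequency: int = 5) -> bool:
    """Return True on every `frequency`-th epoch.

    `epoch` is 0-based, as Keras passes it to `on_epoch_end`.
    Assumption: the real callback may use a different rule.
    """
    return (epoch + 1) % frequency == 0


# Epoch indices (0-based) on which a 12-epoch run would log.
logged = [epoch for epoch in range(12) if should_log(epoch)]
```

With the default `frequency` of 5, this rule would log after the 5th and 10th epochs.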
ConfusionMatrixTracker
Bases: AimPlotTracker
Track Confusion Matrix on given data after each epoch. Usable for multiclass classification.
create_plot(y_pred)
Calculate a confusion matrix and plot it
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `np.array` | predictions of the network | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `out` | `matplotlib.figure.Figure` | a plot containing confusion matrix |
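To illustrate the computation behind such a plot, here is a minimal NumPy sketch of a confusion matrix built from one hot encoded labels and per-class prediction scores. It is not the conftrainer implementation: the function name, the assumption that `y_pred` holds scores that still need an argmax, and the sample data are all illustrative, and the rendering into a `matplotlib` figure is left out.

```python
import numpy as np


def confusion_matrix(y_true_onehot: np.ndarray, y_pred_scores: np.ndarray) -> np.ndarray:
    """Count (true class, predicted class) pairs; rows are true classes."""
    num_classes = y_true_onehot.shape[1]
    true_idx = y_true_onehot.argmax(axis=1)
    pred_idx = y_pred_scores.argmax(axis=1)
    matrix = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when an index pair repeats.
    np.add.at(matrix, (true_idx, pred_idx), 1)
    return matrix


# Illustrative data: 4 samples, 3 classes.
labels = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]])
scores = np.array([
    [0.8, 0.1, 0.1],
    [0.2, 0.7, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],  # true class 1, predicted class 0
])
cm = confusion_matrix(labels, scores)
```

The diagonal of `cm` holds the correctly classified samples; the single off-diagonal entry comes from the last sample.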
MultilabelReportTracker
Bases: AimPlotTracker
Track class-wise precision and recall after each epoch. Usable for multilabel classification tasks.
create_plot(y_pred)
Compute precision and recall for given predictions & return a heatmap showing those metrics for each class
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `y_pred` | `np.array` | array with binary values | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `out` | `matplotlib.figure.Figure` | heatmap of classwise metrics |
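The class-wise metrics that feed such a heatmap can be sketched with NumPy alone. The snippet below is an illustration, not the library's code: the function name, the choice to report 0 for a class with no predicted or no actual positives, and the sample arrays are assumptions, and the heatmap drawing itself is omitted.

```python
import numpy as np


def classwise_precision_recall(y_true, y_pred) -> np.ndarray:
    """Return a (2, num_classes) array: row 0 is precision, row 1 is recall."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    true_positives = (y_true & y_pred).sum(axis=0)
    predicted = y_pred.sum(axis=0)  # TP + FP per class
    actual = y_true.sum(axis=0)     # TP + FN per class
    # Assumption: a class with an empty denominator gets a score of 0.
    precision = np.divide(true_positives, predicted,
                          out=np.zeros(true_positives.shape, dtype=float),
                          where=predicted > 0)
    recall = np.divide(true_positives, actual,
                       out=np.zeros(true_positives.shape, dtype=float),
                       where=actual > 0)
    return np.stack([precision, recall])


# Illustrative multilabel data: 4 samples, 3 classes.
y_true = [[1, 0, 1], [0, 1, 1], [1, 1, 0], [0, 0, 1]]
y_pred = [[1, 0, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1]]
report = classwise_precision_recall(y_true, y_pred)
```

Each column of `report` corresponds to one class, so the array maps directly onto a two-row heatmap with `class_names` on the other axis.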
MultiBranchPlotTracker
Bases: AimPlotTracker
Track the confusion matrix / metric plots for branches of a multioutput network
track_metric_plot(epoch)
Create and track the plot with metrics
Custom tensorflow callbacks
ClearMemory
Bases: Callback
Collect garbage via gc.collect() and release memory using tf.keras.backend.clear_session() after each epoch. Inherits from tf.keras.callbacks.Callback.
Source code in conftrainer/callbacks/tf_callbacks.py
UpdateTarget
Bases: Callback
Update the weights of the target branch of a siamese network as described in the BYOL paper. Inherits from tf.keras.callbacks.Callback.
Updates the weights of the target network so it's a weighted mix of itself and an online network:
$$\text{Weights}_{target} = \beta \cdot \text{Weights}_{target} + (1 - \beta) \cdot \text{Weights}_{online}$$
The passed model should have the following attributes:
- target_encoder
- online_encoder
- target_projection_head
- online_projection_head
Parameters:
- initial_beta (float): value in (0, 1) used to infer the beta multiplier for mixing the online and target networks. Gradually increases with each step, equalling 1 at the end of the training
- num_steps (int): total number of training steps. Used for inferring the step size of beta
Attributes:
- beta (float): current weight of the target network in the mix
- step_size (float): increment of beta after each step
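The description implies that beta grows by a fixed `step_size` until it reaches 1 on the last training step. A minimal sketch of such a schedule follows, assuming a linear increment of `(1 - initial_beta) / num_steps`; that formula and the function name `beta_schedule` are assumptions, since the page does not show how `step_size` is derived.

```python
def beta_schedule(initial_beta: float, num_steps: int) -> list:
    """Return the beta value after each of `num_steps` updates."""
    # Assumption: a constant increment chosen so beta ends at 1.
    step_size = (1.0 - initial_beta) / num_steps
    beta = initial_beta
    betas = []
    for _ in range(num_steps):
        # Clip to 1.0 to guard against floating point overshoot.
        beta = min(beta + step_size, 1.0)
        betas.append(beta)
    return betas


betas = beta_schedule(initial_beta=0.9, num_steps=4)
```

Starting from 0.9 with four steps, beta rises in increments of roughly 0.025 and ends at approximately 1, at which point the target network stops absorbing the online weights.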
mix_weights(target, online, target_ratio)
staticmethod
Averages weights of given target & online networks by given weight and assigns new weights to the target network
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `target` | `tf.Model` | target network to update the weights of | required |
| `online` | `tf.Model` | network to take weights from and average with the target network's weights | required |
| `target_ratio` | `float` | ratio of the weights of target network in the mix. New weights will be calculated by formula: `target_ratio * target_weights + (1 - target_ratio) * online_weights` | required |
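The formula in the table can be sketched on plain lists of NumPy arrays, which is the shape of data that Keras models return from `get_weights()` and accept in `set_weights()`. The function below is a stand-in named after the documented method, not the conftrainer source; whether the real method goes through `get_weights()`/`set_weights()` is an assumption.

```python
import numpy as np


def mix_weights(target_weights, online_weights, target_ratio):
    """Blend two weight lists layer by layer.

    new = target_ratio * target + (1 - target_ratio) * online
    """
    return [
        target_ratio * t + (1.0 - target_ratio) * o
        for t, o in zip(target_weights, online_weights)
    ]


# Illustrative weights for two "layers" of different shapes.
target = [np.ones((2, 2)), np.zeros(3)]
online = [np.zeros((2, 2)), np.ones(3)]
mixed = mix_weights(target, online, target_ratio=0.75)
```

With a `target_ratio` of 0.75, every mixed value sits three quarters of the way toward the target weights. With a Keras model, the result would then be assigned back to the target network, for example via `target.set_weights(mixed)`.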
on_batch_end(batch, logs=None)
Update the weights of the target network once every self.frequency batches
Helpers related to callbacks
AimCallbackType
Bases: Enum
Enumerate Aim Callbacks based on task type
Source code in conftrainer/callbacks/utils.py
import_callback(value)
classmethod
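The page does not list the enum members or document what `import_callback` returns. One plausible pattern for an enum that maps a task type to a callback class is sketched below. Everything in it is hypothetical: the class name `CallbackType`, the member names, the lookup by member name, and the dotted paths, which point at standard library classes as stand-ins so the example runs without conftrainer installed.

```python
import importlib
from enum import Enum


class CallbackType(Enum):
    """Hypothetical task-type enum; values are dotted import paths."""

    # Stand-in targets from the standard library, not real callbacks.
    MULTICLASS = "collections.Counter"
    MULTILABEL = "collections.OrderedDict"

    @classmethod
    def import_callback(cls, value: str):
        """Resolve a task name to the class its member points at."""
        member = cls[value.upper()]  # raises KeyError for unknown tasks
        module_path, class_name = member.value.rsplit(".", 1)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)


callback_class = CallbackType.import_callback("multiclass")
```

Keeping the import inside the classmethod defers loading of the callback module until a task type is actually requested.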
get_base_callbacks(save_config, params=None, **aim_tracker_args)
Return a list of base callbacks usable for any model. The list includes callbacks for logging and checkpointing the network
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `save_config` | `str` | path to save the model | required |
| `params` | `dict` | parameters to track in aim_callback | `None` |
| `**aim_tracker_args` | `dict` | arguments for tracking heatmaps with aim | `{}` |
Returns:
| Name | Type | Description |
|---|---|---|
| `out` | `list` | callbacks ready to pass to the network |