Skip to content

Trainer


Trainer class to train and finetune networks

BaseTrainer

Base abstraction for a trainer. Takes network, configs and datagens, allows to fit the network based on configs

Source code in conftrainer/training/trainer.py
class BaseTrainer:
    """Base abstraction for a trainer.

    Takes a network, a training config and datagens, and fits the network
    according to the config. Subclasses must provide ``metrics`` and ``loss``.
    """

    def __init__(self, network: BaseCNN, config: BaseTrainConfig, datagens: Dict[str, BaseImageDatagen]):
        self.config = config
        self.datagens = datagens
        self.network = network
        # Callbacks shared by every training phase: the base callbacks plus
        # any persistent callbacks declared in the config.
        self.static_callbacks = self.base_callbacks + self.get_callbacks(config.callbacks_config.persistent)
        # Running count of completed epochs so consecutive train() calls
        # resume the epoch numbering instead of restarting at zero.
        self.epochs_trained = 0

    @property
    @abstractmethod
    def metrics(self):
        """Create metrics based on config"""

    @property
    @abstractmethod
    def loss(self):
        """Loss(es) to use during compilation"""

    @property
    def loss_weights(self):
        """Loss weights. Only meaningful for multi-loss setups. Defaults to None"""
        return None

    @property
    def base_callbacks(self) -> Iterable:
        """Prepare the base callbacks that are the same for pretraining and finetuning phases"""
        # ``or None`` normalizes a falsy task_type (e.g. empty string) to None.
        aim_tracker_args = {"task_type": self.config.net_config.task_type or None,
                            "datagen": self.datagens.get("val")}

        return get_base_callbacks(save_config=self.config.save_config,
                                  params=self.config.dict(),
                                  **aim_tracker_args)

    @staticmethod
    def get_callbacks(callback_config_list: List[ObjInitConfig]) -> List[callbacks.Callback]:
        """Instantiate callbacks from the given object-init configs.

        Each config is resolved against the project ``callbacks`` module and
        ``tf_callbacks``.
        """
        return [create_object(config=config, modules=[callbacks, tf_callbacks])
                for config in callback_config_list]

    def prepare_train_args(self, conf: OptimizationConfig) -> Tuple:
        """Prepare the training phase parameters based on given config.

        Parameters
        ----------
        conf : OptimizationConfig
            phase configuration with learning rate, schedule and optimizer
            settings

        Returns
        -------
        tuple
            ``(optimizer, extra_callbacks)`` built from the config
        """
        # The learning rate may be a constant or a schedule object.
        lr = get_lr(learning_rate=conf.learning_rate, schedule=conf.schedule,
                    **conf.schedule_kwargs)
        optimizer = get_optimizer(name=conf.optimizer, learning_rate=lr,
                                  **conf.optimizer_kwargs)

        # Non-persistent callbacks are recreated for every training phase.
        extra_callbacks = self.get_callbacks(self.config.callbacks_config.other)

        return optimizer, extra_callbacks

    def train(self, conf: CNNOptimizationConfig) -> callbacks.History:
        """
        Train a network based on given config

        Parameters
        ----------
        conf : CNNOptimizationConfig
            training configuration to determine learning rate, batch size and other training
            parameters

        Returns
        -------
        history : callbacks.History
            callbacks object containing information about training process
        """
        # Configure which layers are trainable for this phase before compiling.
        self.network.unfreeze_layers(freeze_bn=conf.freeze_bn,
                                     unfreeze_all=conf.unfreeze_all,
                                     freeze_layer_name=conf.freeze_layer_name,
                                     submodel_name=conf.submodel_name)

        optimizer, extra_callbacks = self.prepare_train_args(conf=conf)
        self.network.compile(loss=self.loss, optimizer=optimizer,
                             metrics=self.metrics, loss_weights=self.loss_weights)
        self.network.summary()
        # (Re)build the tf datasets with this phase's batch size.
        create_datasets(datagens=self.datagens.values(),
                        batch_size=conf.batch_size)
        validation_datagen = self.datagens.get("val")
        validation_data = validation_datagen.dataset if validation_datagen else None
        # Offset epochs by epochs_trained so a later phase resumes the epoch
        # numbering (keeps schedules and logged epoch indices consistent).
        history = self.network.fit(self.datagens['train'].dataset,
                                   validation_data=validation_data,
                                   epochs=conf.epochs + self.epochs_trained,
                                   callbacks=self.static_callbacks + extra_callbacks,
                                   initial_epoch=self.epochs_trained,
                                   verbose=1,
                                   **conf.fit_kwargs
                                   )
        # history.history['loss'] only covers the epochs run by THIS fit call
        # (accounting for early stopping), so accumulate rather than assign —
        # plain '=' would discard earlier phases' count and break resuming.
        self.epochs_trained += len(history.history['loss'])
        # Bare compile() resets the compiled state — NOTE(review): presumably
        # to drop the phase-specific optimizer before the next phase; confirm.
        self.network.compile()

        return history

metrics abstractmethod property

Create metrics based on config

loss abstractmethod property

Loss(es) to use during compilation

loss_weights property

Define loss weights. Usable only for setting with multiple losses. Defaults to None

base_callbacks: Iterable property

Prepare the base callbacks that are the same for pretraining and finetuning phases

get_callbacks(callback_config_list) staticmethod

Create callbacks from the given list of callback configurations

Source code in conftrainer/training/trainer.py
@staticmethod
def get_callbacks(callback_config_list: List[ObjInitConfig]) -> List[callbacks.Callback]:
    """Instantiate callbacks from the given object-init configs.

    Each config is resolved against the project ``callbacks`` module and
    ``tf_callbacks``.
    """
    return [create_object(config=config, modules=[callbacks, tf_callbacks])
            for config in callback_config_list]

prepare_train_args(conf)

Prepare the training phase parameters based on given config

Source code in conftrainer/training/trainer.py
def prepare_train_args(self, conf: OptimizationConfig) -> Tuple:
    """Prepare the training phase parameters based on given config.

    Returns
    -------
    tuple
        ``(optimizer, extra_callbacks)`` built from the phase config
    """
    # The learning rate may be a constant or a schedule object.
    lr = get_lr(learning_rate=conf.learning_rate, schedule=conf.schedule,
                **conf.schedule_kwargs)
    optimizer = get_optimizer(name=conf.optimizer, learning_rate=lr,
                              **conf.optimizer_kwargs)

    # Non-persistent callbacks are recreated for every training phase.
    extra_callbacks = self.get_callbacks(self.config.callbacks_config.other)

    return optimizer, extra_callbacks

train(conf)

Train a network based on given config

Parameters:

Name Type Description Default
conf CNNOptimizationConfig

training configuration to determine learning rate, batch size and other training parameters

required

Returns:

Name Type Description
history callbacks.History

callbacks object containing information about training process

Source code in conftrainer/training/trainer.py
def train(self, conf: CNNOptimizationConfig) -> callbacks.History:
    """
    Train a network based on given config

    Parameters
    ----------
    conf : CNNOptimizationConfig
        training configuration to determine learning rate, batch size and other training
        parameters

    Returns
    -------
    history : callbacks.History
        callbacks object containing information about training process
    """
    # Configure which layers are trainable for this phase before compiling.
    self.network.unfreeze_layers(freeze_bn=conf.freeze_bn,
                                 unfreeze_all=conf.unfreeze_all,
                                 freeze_layer_name=conf.freeze_layer_name,
                                 submodel_name=conf.submodel_name)

    optimizer, extra_callbacks = self.prepare_train_args(conf=conf)
    self.network.compile(loss=self.loss, optimizer=optimizer,
                         metrics=self.metrics, loss_weights=self.loss_weights)
    self.network.summary()
    # (Re)build the tf datasets with this phase's batch size.
    create_datasets(datagens=self.datagens.values(),
                    batch_size=conf.batch_size)
    validation_data = None
    validation_datagen = self.datagens.get("val", None)
    if validation_datagen:
        validation_data = validation_datagen.dataset
    # Offset epochs by epochs_trained so a later phase resumes the epoch
    # numbering instead of restarting from zero.
    history = self.network.fit(self.datagens['train'].dataset,
                               validation_data=validation_data,
                               epochs=conf.epochs + self.epochs_trained,
                               callbacks=self.static_callbacks + extra_callbacks,
                               initial_epoch=self.epochs_trained,
                               verbose=1,
                               **conf.fit_kwargs
                               )
    # NOTE(review): history.history['loss'] covers only THIS fit call's
    # epochs, so '=' discards earlier phases' count — '+=' looks intended;
    # verify against multi-phase (pretrain + finetune) usage.
    self.epochs_trained = len(history.history['loss'])
    # NOTE(review): bare compile() resets the compiled state — presumably to
    # drop the phase-specific optimizer before the next phase; confirm.
    self.network.compile()

    return history

Trainer

Bases: BaseTrainer

A wrapper around tensorflow network that trains and finetunes it based on given config

Source code in conftrainer/training/trainer.py
class Trainer(BaseTrainer):
    """A wrapper around tensorflow network that trains and finetunes it based on given config"""

    @property
    def metrics(self):
        """Create metrics based on config"""
        metric_config = self.config.loss_and_metrics.metrics
        metrics = get_metrics(metric_config)
        # get_metrics returns a mapping; compile() expects a plain list here.
        return list(metrics.values())

    @property
    def loss(self):
        """Loss(es) to use during compilation"""
        loss_config = self.config.loss_and_metrics
        loss = get_loss(name=loss_config.loss, **loss_config.loss_kwargs)
        return loss

    def predict_and_report(self):
        """Make predictions on train, test and val datasets, and save them alongside the model"""
        for gen in self.datagens.values():
            # Only datagens that carry labels can produce a report.
            if gen.labels is not None:
                predict_and_report(network=self.network,
                                   gen=gen,
                                   save_path=self.config.save_config.model_folder,
                                   report=True,
                                   save_df=True)

metrics property

Create metrics based on config

loss property

Loss(es) to use during compilation

predict_and_report()

Make predictions on train, test and val datasets, and save them alongside the model

Source code in conftrainer/training/trainer.py
def predict_and_report(self):
    """Make predictions on train, test and val datasets, and save them alongside the model"""
    for gen in self.datagens.values():
        # Only datagens that carry labels can produce a report.
        if gen.labels is not None:
            predict_and_report(network=self.network,
                               gen=gen,
                               save_path=self.config.save_config.model_folder,
                               report=True,
                               save_df=True)

MultiOutputTrainer

Bases: BaseTrainer

Trainer to work with multi output networks. Parses metrics and losses and wraps them into dictionaries

Source code in conftrainer/training/trainer.py
class MultiOutputTrainer(BaseTrainer):
    """Trainer to work with multi output networks. Parses metrics and losses and wraps them into dictionaries"""

    @property
    def metrics(self):
        """Per-branch metric lists, keyed by branch name."""
        return {
            branch.name: list(get_metrics(branch.loss_and_metrics.metrics).values())
            for branch in self.config.net_config.branches
        }

    @property
    def loss(self):
        """Loss(es) to use during compilation"""
        return {
            branch.name: get_loss(name=branch.loss_and_metrics.loss,
                                  **branch.loss_and_metrics.loss_kwargs)
            for branch in self.config.net_config.branches
        }

    @property
    def loss_weights(self):
        """Parse loss weights for each branch and wrap them into a dictionary"""
        return {
            branch.name: branch.loss_and_metrics.loss_weight
            for branch in self.config.net_config.branches
        }

loss property

Loss(es) to use during compilation

loss_weights property

Parse loss weights for each branch and wrap them into a dictionary