Skip to content

Network Architectures


Definition of building blocks for networks

get_keras_block(config, input_tensor=None, input_shape=None)

Create a block of connected keras layers based on the given configurations

Parameters:

Name Type Description Default
config KerasModelConfig

configuration to create the layers. Consists of name of the model, its layers and params

required
input_tensor Optional[tf.Tensor]

a tensor of inputs to the model. If not provided, it will be created from input_shape. If that is not provided either, an error is raised

None
input_shape Optional[List[int]]

input shape of the block. Note that the shape doesn't include the batch dimension

None

Returns:

Name Type Description
model tf.keras.models.Model

a functional model

Source code in conftrainer/arch/blocks.py
def get_keras_block(config: KerasModelConfig, input_tensor: Optional[tf.Tensor] = None,
                    input_shape: Optional[List[int]] = None) -> Optional[Model]:
    """
    Create a block of connected keras layers based on the given configurations

    Parameters
    ----------
    config : KerasModelConfig
        configuration to create the layers. Consists of name of the model, its layers and params.
        If None, no block is built and None is returned
    input_tensor : Optional[tf.Tensor]
        a tensor of inputs to the model. If not provided, will be inferred from input_shape.
        If that is not provided either, a ValueError is raised
    input_shape : Optional[List[int]]
        input shape of the block. Note that the shape doesn't include the batch dimension

    Returns
    -------
    model : Optional[tf.keras.models.Model]
        a functional model, or None when config is None

    Raises
    ------
    ValueError
        if neither input_tensor nor input_shape is provided
    """
    if config is None:
        return None
    network_layers = [create_object(layer_config, modules=[keras_layers, conftrainer_auglayers])
                      for layer_config in config.layers]
    if input_tensor is None:
        # Explicit raise instead of assert: asserts are stripped under `python -O`,
        # which would turn this misuse into an obscure downstream failure
        if input_shape is None:
            raise ValueError("Please provide either input shape or input tensor when building a keras block")
        input_tensor = tf.keras.Input(input_shape)
    # Wire the configured layers functionally, starting from the input tensor
    outputs = input_tensor
    for layer in network_layers:
        outputs = layer(outputs)
    model = tf.keras.models.Model(inputs=input_tensor, outputs=outputs, name=config.name)
    return model

get_preprocessor(config, input_shape)

Create a preprocessor with given configuration

Source code in conftrainer/arch/blocks.py
def get_preprocessor(config: PreprocessingConfig, input_shape: List[int]) -> Optional[Model]:
    """
    Create a preprocessor with given configuration

    Builds a functional model chaining a Rescaling layer (affine scale/offset) with a
    Normalization layer (mean/variance standardisation), both parameterised by config.

    Parameters
    ----------
    config : PreprocessingConfig
        scale, offset, mean, variance and name used to build the layers.
        If None, no preprocessor is built and None is returned
    input_shape : List[int]
        input shape of the preprocessor, excluding the batch dimension

    Returns
    -------
    preprocessor : Optional[tf.keras.models.Model]
        a functional model, or None when config is None
    """
    if config is None:
        return
    inputs = tf.keras.Input(input_shape)
    rescaled = keras_layers.Rescaling(scale=config.scale, offset=config.offset)(inputs)
    normalized = keras_layers.Normalization(mean=config.mean, variance=config.variance)(rescaled)

    preprocessor = Model(inputs=inputs, outputs=normalized, name=config.name)
    return preprocessor

Functional models to use

BaseCNN

Bases: Model

Extends keras's native Model. Can be used to build a Functional model or to inherit and write custom loops

Methods

unfreeze_layers() : allows to unfreeze some layers of the network or its subnetworks. Usable for fine-tuning

Source code in conftrainer/arch/models.py
class BaseCNN(Model):
    """
    Thin extension of keras's native Model.

    Suitable both as a base for Functional models and as a parent for subclasses
    that implement custom training loops.

    Methods
    --------
    unfreeze_layers() :
        selectively unfreeze layers of the network (or one of its sub-models)
        for fine-tuning
    """

    def unfreeze_layers(self, freeze_layer_name: Optional[str], freeze_bn: bool,
                        unfreeze_all: bool, submodel_name: str = "backbone"):
        """
        Unfreeze some layers of the network starting from the named layer

        Parameters
        ----------
        freeze_layer_name : str
            name of the layer to start unfreezing from
        freeze_bn : bool
            controls how BatchNormalization layers are handled; when False they are
            treated like any other layer
        unfreeze_all : bool
            whether to unfreeze the whole network
        submodel_name : str = "backbone"
            sub-model to unfreeze instead of the whole network; useful when the
            functional model is composed of several models as blocks
        """
        # Make everything trainable first, then delegate the fine-grained freezing
        self.trainable = True
        target = self.get_layer(submodel_name) if submodel_name else self
        freeze_model_by_layer_name(network=target,
                                   freeze_layer_name=freeze_layer_name,
                                   freeze_bn=freeze_bn,
                                   unfreeze_all=unfreeze_all)

unfreeze_layers(freeze_layer_name, freeze_bn, unfreeze_all, submodel_name='backbone')

Unfreeze some layers of the network starting from layer_name

Parameters:

Name Type Description Default
freeze_layer_name str

name of the layer to start unfreezing from

required
freeze_bn bool

whether to unfreeze all the BatchNormalization layers. If false, BatchNormalization layers will be treated as regular layers

required
unfreeze_all bool

whether to unfreeze the whole network

required
submodel_name str

submodel to unfreeze instead of the whole network. Is useful when the functional model consists of several other models as blocks

'backbone'
Source code in conftrainer/arch/models.py
def unfreeze_layers(self, freeze_layer_name: Optional[str], freeze_bn: bool,
                    unfreeze_all: bool, submodel_name: str = "backbone"):
    """
    Unfreeze some layers of the network starting from the named layer

    Parameters
    ----------
    freeze_layer_name : str
        name of the layer to start unfreezing from
    freeze_bn : bool
        controls how BatchNormalization layers are handled; when False they are
        treated like any other layer
    unfreeze_all : bool
        whether to unfreeze the whole network
    submodel_name : str = "backbone"
        sub-model to unfreeze instead of the whole network; useful when the
        functional model is composed of several models as blocks
    """
    # Make everything trainable first, then delegate the fine-grained freezing
    self.trainable = True
    target = self.get_layer(submodel_name) if submodel_name else self
    freeze_model_by_layer_name(network=target,
                               freeze_layer_name=freeze_layer_name,
                               freeze_bn=freeze_bn,
                               unfreeze_all=unfreeze_all)

Classifier

Bases: BaseCNN

CNN classifier

Inherits from tf.keras.models.Model

Consists of 3 parts: - Augmenter - Backbone - Classification Head

Parameters:

Name Type Description Default
config ClassifierConfig

configuration with all necessary parameters. See ClassifierConfig for more details

required
Source code in conftrainer/arch/models.py
class Classifier(BaseCNN):
    """
    CNN classifier

    Inherits from tf.keras.models.Model

    Consists of 3 parts:
        - Augmenter
        - Backbone
        - Classification Head

    Parameters
    ----------
    config : ClassifierConfig
        configuration with all necessary parameters. See ClassifierConfig for more details
    """

    def __init__(self, config: ClassifierConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self.config = config
        # The augmentation block lives outside `self.classifier`, so it is not exported by save()
        self.augmentor = blocks.get_keras_block(config=config.augmentor, input_shape=config.input_shape)
        if config.prebuilt_path:
            # Reuse a previously exported classifier instead of building a new one
            self.classifier = tf.keras.models.load_model(config.prebuilt_path, compile=False)
            self.backbone = self.classifier.get_layer("backbone")
        else:
            self.backbone = get_backbone(backbone=config.backbone.name,
                                         load_path=config.backbone.load_path,
                                         backbone_args=config.backbone.params,
                                         cut_from=config.backbone.cut_from)
            self.classifier = self.create_classifier()
        self.classifier.compile()
        self.build([None] + config.input_shape)

    def create_classifier(self):
        """Create a classifier consisting of preprocessor, backbone and classification head"""
        inputs = tf.keras.Input(self.config.input_shape, name='input')
        outputs = inputs
        # Bind the property once: every access to `self.preprocessor` builds a fresh model,
        # so evaluating it twice (truthiness test + call) constructed a throwaway duplicate
        preprocessor = self.preprocessor
        if preprocessor is not None:
            outputs = preprocessor(outputs)
        outputs = self.backbone(outputs)
        outputs = self.classification_head(outputs)
        classifier = tf.keras.models.Model(inputs=inputs, outputs=outputs, name=self.config.name)
        return classifier

    @property
    def classification_head(self) -> Model:
        """Classification head consisting of a prediction layer and optionally from hidden layers defined in config"""
        activation = ActivationType(self.config.task_type).name
        head_params_list = generate_pred_layer_params(num_classes=self.config.num_classes,
                                                      activation=activation)
        # NOTE(review): this extends config.mlp_args.layers in place, so every access to the
        # property appends another prediction layer to the shared config. It is accessed
        # exactly once (in create_classifier); keep it that way or make the extension idempotent
        self.config.mlp_args.layers.extend(head_params_list)
        return blocks.get_keras_block(self.config.mlp_args,
                                      input_shape=self.backbone.output_shape[1:])

    @property
    def preprocessor(self) -> Model:
        """Preprocessor consisting of Normalize and Rescaling layers; None when not configured"""
        return blocks.get_preprocessor(self.config.preprocessor,
                                       input_shape=self.config.input_shape)

    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
        """
        Run the model on given inputs

        Parameters
        ----------
        inputs : tf.Tensor
            batch of inputs
        training : bool
            whether to use training or inference mode. In inference mode Augmentation and
            Dropout layers are inactive, and BatchNormalization acts differently

        Returns
        -------
        out : tf.Tensor
            processed outputs
        """
        # assumes config.augmentor is set: a None augmentor would fail here — TODO confirm
        augmented = self.augmentor(inputs, training=training)
        outputs = self.classifier(augmented, training=training)

        return outputs

    def save(self, *args, **kwargs):
        """Override the save method to save the inner classifier only"""
        return self.classifier.save(*args, **kwargs)

classification_head: Model property

Classification head consisting of a prediction layer and optionally from hidden layers defined in config

preprocessor: Model property

Preprocessor consisting of Normalize and Rescaling layers

create_classifier()

Create a classifier consisting of preprocessor, backbone and classification head

Source code in conftrainer/arch/models.py
def create_classifier(self):
    """
    Create a classifier consisting of preprocessor, backbone and classification head

    Returns
    -------
    classifier : tf.keras.models.Model
        functional model wiring input -> (optional preprocessor) -> backbone -> head
    """
    inputs = tf.keras.Input(self.config.input_shape, name='input')
    outputs = inputs
    # Bind the property once: every access to `self.preprocessor` builds a fresh model,
    # so the original `self.preprocessor(...) if self.preprocessor else ...` constructed
    # a throwaway duplicate just for the truthiness check
    preprocessor = self.preprocessor
    if preprocessor is not None:
        outputs = preprocessor(outputs)
    outputs = self.backbone(outputs)
    outputs = self.classification_head(outputs)
    classifier = tf.keras.models.Model(inputs=inputs, outputs=outputs, name=self.config.name)
    return classifier

call(inputs, training=False)

Run the model on given inputs

Parameters:

Name Type Description Default
inputs tf.Tensor

batch of inputs

required
training bool

whether to use training or inference mode. In inference mode Augmentation and Dropout layers are inactive, and BatchNormalization acts differently

False

Returns:

Name Type Description
out tf.Tensor

processed outputs

Source code in conftrainer/arch/models.py
def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
    """
    Run the model on given inputs

    Parameters
    ----------
    inputs : tf.Tensor
        batch of inputs
    training : bool
        training vs inference mode. During inference Augmentation and Dropout layers
        are inactive and BatchNormalization behaves differently

    Returns
    -------
    out : tf.Tensor
        processed outputs
    """
    # Augment first, then feed the augmented batch through the inner classifier
    return self.classifier(self.augmentor(inputs, training=training),
                           training=training)

save(*args, **kwargs)

Override the save method to return classifier only

Source code in conftrainer/arch/models.py
def save(self, *args, **kwargs):
    """Persist only the inner functional classifier, not the wrapping model"""
    # Delegate straight to the wrapped model's save
    inner = self.classifier
    return inner.save(*args, **kwargs)

BYOL

Bases: BaseCNN

Implement a BYOL network

BYOL is a Siamese network consisting of: - Online network -> encoder + projection head + predictor - Target network -> encoder + projection head

BYOL creates 2 different augmentations of the input image and runs them through the online and target networks; the online network's weights are then updated to minimize the difference between its predictions and the target network's predictions. The target network, on the other hand, is updated to hold averaged weights of itself and the online network. This is handled by the UpdateTarget callback

Parameters:

config : BYOLConfig configuration for creating a BYOL network. See BYOLConfig for more details

Methods:

call: augment the input and return their embeddings obtained from online and target network train_step: call the network and update the online network based on loss

Source code in conftrainer/arch/models.py
class BYOL(BaseCNN):
    """
    Implement a BYOL network

    BYOL is a Siamese network consisting of:
        - Online network -> encoder + projection head + predictor
        - Target network -> encoder + projection head

    BYOL creates 2 different augmentations of input image and runs them through online and target
    networks, then the online network's weights are updated to minimize the difference between its
    predictions and the target network's predictions. The target network, on the other hand, is
    updated to hold averaged weights of itself and the online network. This is handled by the
    UpdateTarget callback

    Parameters
    ----------
    config : BYOLConfig
        configuration for creating a BYOL network. See BYOLConfig for more details

    Methods
    -------
    call: augment the input and return their embeddings obtained from online and target network
    train_step: call the network and update the online network based on loss
    """

    def __init__(self, config: BYOLConfig, **kwargs) -> None:
        super().__init__(**kwargs)
        self.config = config
        # Two independent augmentation pipelines produce the two views of each input
        self.augmentor1 = blocks.get_keras_block(config=config.online_aug_args, input_shape=config.input_shape)
        self.augmentor2 = blocks.get_keras_block(config=config.target_aug_args, input_shape=config.input_shape)

        self.online_encoder = get_normalized_encoder(backbone=config.backbone.name,
                                                     load_path=config.backbone.load_path,
                                                     backbone_args=config.backbone.params,
                                                     name="online_encoder")

        self.online_projection_head = blocks.get_keras_block(config.mlp_args,
                                                             input_shape=self.online_encoder.output_shape[1:])
        self.online_predictor = blocks.get_keras_block(config.mlp_args,
                                                       input_shape=self.online_projection_head.output_shape[1:])
        # The target branch mirrors the online branch but has no predictor; its weights
        # are updated by averaging (UpdateTarget callback), not by gradient descent
        self.target_encoder = get_normalized_encoder(backbone=config.backbone.name,
                                                     load_path=config.backbone.load_path,
                                                     backbone_args=config.backbone.params,
                                                     name="target_encoder")
        self.target_projection_head = blocks.get_keras_block(config.mlp_args,
                                                             input_shape=self.target_encoder.output_shape[1:])
        self.build([None, ] + config.input_shape)
        self.loss_tracker = Mean(name="loss")
        self.summary()

    @property
    def metrics(self) -> List[tf.keras.metrics.Metric]:
        """
        Override the metrics property of the base class

        Define a metric tracker to use during training

        Returns
        -------
        out: list
            metric trackers of the network
        """
        return [
            self.loss_tracker
        ]

    @tf.function
    def train_step(self, data: tf.Tensor) -> dict:
        """
        Implementation of the training step for BYOL network

        The method is called during fit method and represents a single training step, during which
        a batch of inputs is processed by the network and the parameters are updated

        Parameters
        ----------
        data: tf.Tensor
            input of the network

        Returns
        -------
        out: dict
            tracked metrics
        """
        with tf.GradientTape() as tape:
            online_1, online_2, target_1, target_2 = self(data, training=True)
            # Symmetric loss over both augmented views
            loss = 1 / 2 * (self.compiled_loss(online_1, target_1) +
                            self.compiled_loss(online_2, target_2))
        # Only the online branch is optimised; build the weight list once and reuse it
        # for both gradient computation and application
        online_weights = (self.online_encoder.trainable_weights +
                          self.online_projection_head.trainable_weights +
                          self.online_predictor.trainable_weights)
        gradients = tape.gradient(loss, online_weights)
        self.optimizer.apply_gradients(zip(gradients, online_weights))

        self.loss_tracker.update_state(loss)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, inputs: tf.Tensor) -> dict:
        """Overwrite the test step: same symmetric loss as training, no weight updates"""
        # Use training=true so the augmentations take place
        online_1, online_2, target_1, target_2 = self(inputs, training=True)
        loss = 1 / 2 * (self.compiled_loss(online_1, target_1)
                        + self.compiled_loss(online_2, target_2))
        self.loss_tracker.update_state(loss)

        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data: tf.Tensor) -> tf.Tensor:
        """Infer on data using online encoder only. Overwrites predict_step method of the parent
        class"""
        return self.online_encoder.predict_step(data)

    def online_forward_pass(self, inputs, training: bool = False) -> tf.Tensor:
        """Online branch forward pass: encoder -> projection head -> predictor"""
        online_features = self.online_encoder(inputs, training=training)
        online_projections = self.online_projection_head(online_features, training=training)
        online_predictions = self.online_predictor(online_projections, training=training)

        return online_predictions

    def target_forward_pass(self, inputs, training: bool = False) -> tf.Tensor:
        """Target branch forward pass; gradients are blocked on both stages"""
        # stop_gradient keeps the target network out of the optimiser's reach:
        # it is only updated by weight averaging (UpdateTarget callback)
        target_features = tf.stop_gradient(self.target_encoder(inputs, training=training))
        target_projections = tf.stop_gradient(self.target_projection_head(target_features,
                                                                          training=training))
        return target_projections

    @tf.function
    def call(self, inputs: tf.Tensor, training: bool = False):
        """
        Implement call method for the BYOL model

        This function is used for fit, predict, evaluate methods

        Parameters
        ----------
        inputs: tf.Tensor
            Inputs to the network.
        training: bool
            Whether to use training (True) or inference (False) mode.
            During inference, augmentation layers are inactive

        Returns
        -------
        out: tuple of 4 tf.Tensors
            Online-network predictions for both augmented views, followed by the
            target (teacher) network projections for the same two views
        """
        # Gradients are not propagated through the augmentation pipelines
        view1 = tf.stop_gradient(self.augmentor1(inputs, training=training))
        view2 = tf.stop_gradient(self.augmentor2(inputs, training=training))
        online_view1 = self.online_forward_pass(view1, training=training)
        online_view2 = self.online_forward_pass(view2, training=training)
        target_view1 = self.target_forward_pass(view1, training=training)
        target_view2 = self.target_forward_pass(view2, training=training)
        return online_view1, online_view2, target_view1, target_view2

    def unfreeze_layers(self, freeze_layer_name: str, freeze_bn: bool, unfreeze_all: bool, **kwargs):
        """Unfreeze the same layers both in online and target networks"""
        # The MLP heads are always left trainable; selective freezing applies to encoders
        for mlp in [self.online_predictor, self.online_projection_head,
                    self.target_projection_head]:
            mlp.trainable = True
        for net in [self.target_encoder, self.online_encoder]:
            freeze_model_by_layer_name(network=net,
                                       freeze_layer_name=freeze_layer_name,
                                       freeze_bn=freeze_bn,
                                       unfreeze_all=unfreeze_all)

    def fit(self, *args, **kwargs):
        """
        Add network specific callbacks before fitting the network

        Registers the UpdateTarget callback (moving-average update of the target
        network) and delegates to the parent fit
        """
        # Mirror the keras default of a single epoch so `epochs * num_training_batches`
        # doesn't fail with None when the caller omits `epochs`
        epochs = kwargs.get("epochs", 1)
        # Number of batches per epoch: the dataset is either the first positional arg or x=...
        # (the previous `len(**kwargs.get("x"))` raised a TypeError: len() takes no kwargs)
        num_training_batches = len(args[0]) if args else len(kwargs["x"])  # num batches
        update_target = UpdateTarget(initial_beta=self.config.initial_beta,
                                     num_steps=epochs * num_training_batches,
                                     frequency=self.config.update_frequency)
        if "callbacks" in kwargs:
            kwargs.get("callbacks").append(update_target)
        else:
            kwargs["callbacks"] = [update_target]
        return super().fit(*args, **kwargs)

    def save(self, *args, **kwargs):
        """Override the save method to save the online encoder only"""
        return self.online_encoder.save(*args, **kwargs)

metrics: List[type(tf.keras.metrics.Metric)] property

Override the metrics property of the base class

Define a metric tracker to use during training

Returns:

Name Type Description
out list

metric trackers of the network

train_step(data)

Implementation of the training step for BYOL network

The method is called during fit method and represents a single training step, during which a batch of inputs is processed by the network and the parameters are updated

Parameters:

Name Type Description Default
data tf.Tensor

input of the network

required

Returns:

Name Type Description
out dict

tracked metrics

Source code in conftrainer/arch/models.py
@tf.function
def train_step(self, data: tf.Tensor) -> dict:
    """
    Single BYOL training step

    Runs the batch through both branches, computes the symmetric loss over the
    two augmented views and applies gradients to the online branch only.

    Parameters
    ----------
    data: tf.Tensor
        input of the network

    Returns
    -------
    out: dict
        tracked metrics
    """
    with tf.GradientTape() as tape:
        online_1, online_2, target_1, target_2 = self(data, training=True)
        # Symmetric loss: each online view is matched against the corresponding target view
        loss = 0.5 * (self.compiled_loss(online_1, target_1) +
                      self.compiled_loss(online_2, target_2))
    # Only the online branch is optimised; build its weight list once and reuse it
    online_weights = (self.online_encoder.trainable_weights
                      + self.online_projection_head.trainable_weights
                      + self.online_predictor.trainable_weights)
    gradients = tape.gradient(loss, online_weights)
    self.optimizer.apply_gradients(zip(gradients, online_weights))

    self.loss_tracker.update_state(loss)
    return {metric.name: metric.result() for metric in self.metrics}

test_step(inputs)

Overwrite the test step

Source code in conftrainer/arch/models.py
def test_step(self, inputs: tf.Tensor) -> dict:
    """Evaluation step: same symmetric BYOL loss as training, but no weight updates"""
    # training=True keeps the augmentation branches active during evaluation
    online_1, online_2, target_1, target_2 = self(inputs, training=True)
    symmetric_loss = 0.5 * (self.compiled_loss(online_1, target_1)
                            + self.compiled_loss(online_2, target_2))
    self.loss_tracker.update_state(symmetric_loss)
    return {metric.name: metric.result() for metric in self.metrics}

predict_step(data)

Infer on data using online encoder only. Overwrites predict_step method of the parent class

Source code in conftrainer/arch/models.py
def predict_step(self, data: tf.Tensor) -> tf.Tensor:
    """Infer using the online encoder only (overrides the parent's predict_step)"""
    encoder = self.online_encoder
    return encoder.predict_step(data)

online_forward_pass(inputs, training=False)

Online network forward pass

Source code in conftrainer/arch/models.py
def online_forward_pass(self, inputs, training: bool = False) -> tf.Tensor:
    """Run inputs through the online branch: encoder -> projection head -> predictor"""
    features = self.online_encoder(inputs, training=training)
    projections = self.online_projection_head(features, training=training)
    return self.online_predictor(projections, training=training)

target_forward_pass(inputs, training=False)

Target network forward pass

Source code in conftrainer/arch/models.py
def target_forward_pass(self, inputs, training: bool = False) -> tf.Tensor:
    """Run inputs through the target branch; gradients are blocked on both stages"""
    # stop_gradient keeps the target network out of the optimiser's reach: the class
    # docstring notes it is updated by weight averaging, not gradient descent
    features = tf.stop_gradient(self.target_encoder(inputs, training=training))
    projections = tf.stop_gradient(
        self.target_projection_head(features, training=training))
    return projections

call(inputs, training=False)

Implement call method for the BYOL model

This function is used for fit, predict, evaluate methods

Parameters:

Name Type Description Default
inputs tf.Tensor

Inputs to the network.

required
training bool

Whether to use training (True) or inference (False) mode. During inference, augmentation layers are inactive

False

Returns:

Name Type Description
out tuple of 4 tf.Tensors

Online-network predictions for both augmented views, followed by the target (teacher) network projections for the same two views.

Source code in conftrainer/arch/models.py
@tf.function
def call(self, inputs: tf.Tensor, training: bool = False):
    """
    Implement call method for the BYOL model

    This function is used for fit, predict, evaluate methods
    Parameters
    ----------
    inputs: tf.Tensor
            Inputs to the network.
    training: bool
            Whether to use training (True) or inference (False) mode.
            During inference, augmentation layers are inactive

    Returns
    -------
    out: tuple of 4 tf.Tensors
        Online-network predictions for both augmented views, followed by the
        target (teacher) network projections for the same two views
    """
    # Two independently augmented views of the batch; gradients are not
    # propagated through the augmentation pipelines
    view1 = tf.stop_gradient(self.augmentor1(inputs, training=training))
    view2 = tf.stop_gradient(self.augmentor2(inputs, training=training))
    online_view1 = self.online_forward_pass(view1, training=training)
    online_view2 = self.online_forward_pass(view2, training=training)
    target_view1 = self.target_forward_pass(view1, training=training)
    target_view2 = self.target_forward_pass(view2, training=training)
    return online_view1, online_view2, target_view1, target_view2

unfreeze_layers(freeze_layer_name, freeze_bn, unfreeze_all, **kwargs)

Unfreeze the same layers both in online and target networks

Source code in conftrainer/arch/models.py
def unfreeze_layers(self, freeze_layer_name: str, freeze_bn: bool, unfreeze_all: bool, **kwargs):
    """Apply the same unfreezing policy to both the online and target networks"""
    # The MLP heads are always left trainable; selective freezing applies to encoders only
    self.online_predictor.trainable = True
    self.online_projection_head.trainable = True
    self.target_projection_head.trainable = True
    for encoder in (self.target_encoder, self.online_encoder):
        freeze_model_by_layer_name(network=encoder,
                                   freeze_layer_name=freeze_layer_name,
                                   freeze_bn=freeze_bn,
                                   unfreeze_all=unfreeze_all)

fit(*args, **kwargs)

Add network specific callbacks before fitting the network

Source code in conftrainer/arch/models.py
def fit(self, *args, **kwargs):
    """
    Add network specific callbacks before fitting the network

    Registers the UpdateTarget callback (moving-average update of the target
    network) and delegates to the parent fit
    """
    # Mirror the keras default of a single epoch so `epochs * num_training_batches`
    # doesn't fail with None when the caller omits `epochs`
    epochs = kwargs.get("epochs", 1)
    # Number of batches per epoch: the dataset is either the first positional arg or x=...
    # (the previous `len(**kwargs.get("x"))` raised a TypeError: len() takes no kwargs)
    num_training_batches = len(args[0]) if args else len(kwargs["x"])  # num batches
    update_target = UpdateTarget(initial_beta=self.config.initial_beta,
                                 num_steps=epochs * num_training_batches,
                                 frequency=self.config.update_frequency)
    if "callbacks" in kwargs:
        kwargs.get("callbacks").append(update_target)
    else:
        kwargs["callbacks"] = [update_target]
    return super().fit(*args, **kwargs)

save(*args, **kwargs)

Override the save method to save the online encoder only

Source code in conftrainer/arch/models.py
def save(self, *args, **kwargs):
    """Persist only the online encoder; the rest of BYOL is training scaffolding"""
    encoder = self.online_encoder
    return encoder.save(*args, **kwargs)

multibranch_network(config)

Create a multibranch network based on given configuration

Parameters:

Name Type Description Default
config MultiBranchNetworkConfig

configuration for creating augmentor, preprocessor, backbone and output branches of the network

required

Returns:

Name Type Description
out BaseCNN

keras functional with multiple outputs

Source code in conftrainer/arch/models.py
def multibranch_network(config: MultiBranchNetworkConfig) -> BaseCNN:
    """
    Create a multibranch network based on given configuration

    Parameters
    ----------
    config : MultiBranchNetworkConfig
        configuration for creating augmentor, preprocessor, backbone and output
        branches of the network

    Returns
    -------
    out : BaseCNN
        keras functional model with one output per configured branch
    """
    augmentor = get_keras_block(config.augmentor, input_shape=config.input_shape)
    backbone = get_backbone(load_path=config.backbone.load_path, backbone=config.backbone.name,
                            backbone_args=config.backbone.params, add_pooling=False,
                            cut_from=config.backbone.cut_from)

    if config.mlp_args:
        # Stack an optional shared MLP on top of the backbone and re-wrap as one model
        hidden = get_keras_block(config=config.mlp_args, input_shape=backbone.output_shape[1:])
        backbone = tf.keras.models.Model(backbone.input, hidden(backbone.output), name='backbone')
    preprocessor = get_preprocessor(config.preprocessor, input_shape=config.input_shape)

    # One prediction head per branch, each with a task-specific output activation
    feature_shape = backbone.output_shape[1:]
    branches = []
    for branch_cfg in config.branches:
        head_activation = ActivationType(branch_cfg.task_type).name
        branch_cfg.mlp.layers.extend(
            generate_pred_layer_params(num_classes=branch_cfg.num_classes,
                                       activation=head_activation))
        branches.append(get_keras_block(branch_cfg.mlp, input_shape=feature_shape))

    # NOTE(review): the input shape is hard-coded to three unknown dims rather than
    # config.input_shape — presumably to accept variable-size inputs; confirm
    inp = tf.keras.layers.Input(shape=(None, None, None), name='input')
    out = inp
    if augmentor is not None:
        out = augmentor(out)
    if preprocessor is not None:
        out = preprocessor(out)
    out = backbone(out)
    output = [head(out) for head in branches]

    return BaseCNN(inputs=inp, outputs=output, name=config.name)