Skip to content

Utils


Helper functions

predict_and_report(network, gen, save_path='./', csv_save_path=None, report=True, save_df=True, batch_size=32, input_shape=None)

Run inference with the network on the given dataset and save predictions in a csv file. If needed, a separate report with f1-score, precision and recall for each class will be saved.

Parameters:

Name Type Description Default
network tf.keras.Model

network to predict with

required
gen src.datasets.datagen.ImageDatagen

dataset generator. Will be used to infer dataset, features and filenames

required
save_path str

root directory to save the prediction file in

'./'
report bool

whether to save separate classification report. For more info, check sklearn.metrics.classification_report

True
save_df bool

whether to save a dataframe containing predictions. It will be filtered by the number of mispredictions (from best to worst)

True
input_shape Optional[List[int]]

input shape of the network

None
batch_size int

batch size to use during inference

32
csv_save_path Optional[str]

save path for csv. If not provided, will be inferred from save_path directory

None
Source code in conftrainer/utils.py
def predict_and_report(network: Any, gen: Any, save_path: str = "./",
                       csv_save_path: Optional[str] = None, report: bool = True,
                       save_df: bool = True, batch_size: int = 32,
                       input_shape: Optional[List[int]] = None) -> None:
    """
    Run inference with the network on the given dataset and save predictions to a csv file.
    If needed, a separate report with f1-score, precision and recall for each class is saved.

    Parameters
    ----------
    network : tf.keras.Model
        network to predict with
    gen : src.datasets.datagen.ImageDatagen
        dataset generator. Will be used to infer dataset, features and filenames
    save_path : str
        root directory to save the prediction file in
    csv_save_path : Optional[str]
        save path for csv. If not provided, will be inferred from the save_path directory
    report : bool, default: True
        whether to save a separate classification report. For more info,
        check sklearn.metrics.classification_report
    save_df : bool, default: True
        whether to save a dataframe containing predictions. It will be filtered by the number of
        mispredictions (from best to worst)
    batch_size : int, default: 32
        batch size to use during inference
    input_shape : Optional[List[int]]
        input shape of the network
    """
    dataset = gen.create_ds(shape=input_shape, batch_size=batch_size, training=False)
    dest_path = os.path.join(save_path, gen.name + "_predictions")
    y_prob = network.predict(dataset, verbose=1)
    # A multi-branch network returns a list of arrays (one per output head);
    # a regular network returns a single np.ndarray.
    is_multibranch = isinstance(y_prob, list)
    y_true = gen.labels
    y_pred = gen.probs_to_labels(y_prob)
    if report and (y_true is not None):
        # Convert ground truth only when it exists: calling probs_to_labels(None)
        # unconditionally would fail for unlabeled datasets.
        y_true_processed = gen.probs_to_labels(y_true)
        if is_multibranch:
            # One classification report per task/output head.
            report_data = {
                name: classification_report(labels,
                                            preds,
                                            target_names=classes,
                                            output_dict=True)
                for labels, preds, classes, name
                in zip(y_true_processed, y_pred, gen.classes, gen.task_names)
            }
        else:
            report_data = classification_report(y_true_processed,
                                                y_pred,
                                                target_names=gen.classes,
                                                output_dict=True)
        report_path = dest_path + "_report.json"
        with open(report_path, "w+", encoding="utf-8") as file:
            json.dump(report_data, file)

    if save_df:
        filenames = np.array(gen.filepaths)
        col_names = gen.classes
        if is_multibranch:
            # Flatten per-task class lists into one column list and stack the
            # per-head probability arrays side by side.
            col_names = [name for task_classes in col_names for name in task_classes] if col_names else None
            y_prob = np.concatenate(y_prob, axis=1)
        df_pred = pd.DataFrame(y_prob, columns=col_names)
        df_pred["Path"] = filenames
        if not csv_save_path:
            csv_save_path = dest_path + ".csv"
        df_pred.to_csv(csv_save_path, index=False)

save_yaml(obj, filepath, default_flow_style=False)

Save the dict-like object into YAML file

Parameters:

Name Type Description Default
obj object

object to save

required
filepath str

path to the file to create

required
default_flow_style bool

corresponding parameter for yaml.dump. If True, nested dicts will be written in JSON style (A: {key: value, ...}) , otherwise block style will be used:

False
Source code in conftrainer/utils.py
def save_yaml(obj: object, filepath: str, default_flow_style: bool = False) -> None:
    """
    Write a dict-like object to a YAML file.

    Parameters
    ----------
    obj : object
        object to serialize
    filepath : str
        path of the YAML file to create
    default_flow_style : bool
        forwarded to yaml.dump. If True, nested dicts are written in JSON-like
        flow style (A: {key: value, ...}); otherwise block style is used
    """
    # Keep insertion order of keys instead of yaml's default alphabetical sort.
    with open(filepath, "w+", encoding="utf-8") as yaml_file:
        yaml.dump(obj, yaml_file, default_flow_style=default_flow_style, sort_keys=False)

json_load(path, encoding='utf-8')

Load json file

Parameters:

Name Type Description Default
path str

path of json file

required
encoding str

encoding of json file

'utf-8'

Returns:

Name Type Description
obj Dict | List

loaded json file

Source code in conftrainer/utils.py
def json_load(path: str, encoding: str = 'utf-8') -> Union[Dict, List]:
    """
    Read and deserialize a JSON file.

    Parameters
    ----------
    path : str
        path of the json file
    encoding : str
        text encoding used when reading the file

    Returns
    -------
    obj : Dict | List
        deserialized content of the file
    """
    with open(path, 'r', encoding=encoding) as json_file:
        return json.load(json_file)

json_save(path, obj)

Save dict to json

Parameters:

Name Type Description Default
path str

path of json file

required
obj Dict

input

required
Source code in conftrainer/utils.py
def json_save(path: str, obj: Union[Dict, List[Dict]]) -> None:
    """
    Serialize a dict (or a list of dicts) to a JSON file.

    Parameters
    ----------
    path : str
        path of the json file to create
    obj : Dict
        object to serialize
    """
    with open(path, 'w', encoding='utf-8') as json_file:
        json.dump(obj, json_file)

plot_class_metrics(y_true, y_pred, class_names)

Plot a heatmap showing f1 score, precision and recall for each class

Parameters:

Name Type Description Default
y_true np.ndarray

true and predicted labels

required
y_pred np.ndarray

true and predicted labels

required
class_names List[str]

names of classes to display

required

Returns:

Name Type Description
out Figure

heatmap with classwise metrics

Source code in conftrainer/utils.py
def plot_class_metrics(y_true: np.ndarray, y_pred: np.ndarray, class_names: List[str]) -> Figure:
    """
    Plot a heatmap showing f1 score, precision and recall for each class

    Parameters
    ----------
    y_true, y_pred : np.ndarray
        true and predicted labels
    class_names : List[str]
        names of classes to display

    Returns
    -------
    out : Figure
        heatmap with classwise metrics
    """
    f1 = f1_score(y_true=y_true, y_pred=y_pred, average=None)
    precision = precision_score(y_true=y_true, y_pred=y_pred, average=None)
    recall = recall_score(y_true=y_true, y_pred=y_pred, average=None)
    # Order classes by ascending f1 so the weakest classes come first.
    order = np.argsort(f1)
    metrics = np.stack([f1[order], precision[order], recall[order]], axis=0).round(3)
    metric_names = ["f1_score", "precision", "recall"]

    fig, ax = plt.subplots(figsize=(12, len(class_names)))
    # metrics is (3, n_classes); transpose so classes run down the y axis.
    ax.imshow(metrics.T, cmap="Wistia")

    # Tick labels: metric names on x, (reordered) class names on y.
    ax.set_xticks(np.arange(len(metric_names)), labels=metric_names)
    ax.set_yticks(np.arange(len(class_names)), labels=np.array(class_names)[order])
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Write each cell's value on top of the heatmap.
    n_metrics = len(metrics)
    for row in range(len(class_names)):
        for col in range(n_metrics):
            ax.text(col, row, metrics[col, row],
                    ha="center", va="center", color="b")
    return fig

create_object(config, modules)

Create an object based on the given config. First the function imports the class from the given modules, then initializes an object with the given params. Note that if several modules have the same class, the first one will be imported and initialized.

Parameters:

Name Type Description Default
config ObjInitConfig

configuration to create an object. Contains name and arguments

required
modules List

modules to attempt an import from

required

Returns:

Name Type Description
out

initialized object

Source code in conftrainer/utils.py
def create_object(config: ObjInitConfig, modules: List):
    """
    Instantiate an object described by a config.

    The class named in the config is looked up in the given modules, in order,
    and initialized with the configured arguments. If several modules define a
    class with the same name, the first match wins.

    Parameters
    ----------
    config : ObjInitConfig
        configuration to create an object. Contains name and arguments
    modules : List
        modules to attempt an import from

    Returns
    -------
    out:
        initialized object

    Raises
    ------
    ValueError
        if modules is empty or none of them provide the requested class
    """
    if not modules:
        raise ValueError("Please provide a non-empty list of modules to attempt import from")

    for module in modules:
        candidate = getattr(module, config.name, None)
        if candidate:
            return candidate(**config.args)

    module_names = [module.__name__ for module in modules]
    raise ValueError(f"Please provide valid class from {module_names}. Current config: {config}")

plot_confusion_matrix(y_true, y_pred, class_names=None, **kwargs)

Calculate and plot confusion matrix

Parameters:

Name Type Description Default
y_true np.ndarray

true labels

required
y_pred np.ndarray

predictions of the network

required
class_names List[str]

names of classes to display on the plot

None

Returns:

Name Type Description
out matplotlib.figure.Figure

a plot containing confusion matrix

Source code in conftrainer/utils.py
def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray,
                          class_names: Optional[List[str]] = None, **kwargs) -> Figure:
    """
    Calculate and plot confusion matrix

    Parameters
    ----------
    y_true : np.ndarray
        true labels
    y_pred : np.ndarray
        predictions of the network
    class_names : Optional[List[str]]
        names of classes to display on the plot
    **kwargs
        forwarded to sklearn.metrics.confusion_matrix (e.g. normalize)

    Returns
    -------
    out : matplotlib.figure.Figure
        a plot containing confusion matrix
    """
    matrix = confusion_matrix(y_true=y_true,
                              y_pred=y_pred,
                              **kwargs)
    # Round for readability when a normalized matrix is requested via kwargs.
    matrix = np.round(matrix, 2)
    fig, ax = plt.subplots(1, 1, figsize=(16, 12))
    display = ConfusionMatrixDisplay(matrix, display_labels=class_names)
    display.plot(ax=ax, xticks_rotation="vertical")
    return fig