Utils
Helper functions
predict_and_report(network, gen, save_path='./', csv_save_path=None, report=True, save_df=True, batch_size=32, input_shape=None)
Run inference with the network on the given dataset and save the predictions to a CSV file. Optionally, a separate report with precision, recall and F1-score for each class is also saved.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `network` | `tf.keras.Model` | Network to predict with. | required |
| `gen` | `src.datasets.datagen.ImageDatagen` | Dataset generator. Used to infer the dataset, features and filenames. | required |
| `save_path` | `str` | Root directory in which to save the prediction file. | `'./'` |
| `csv_save_path` | `Optional[str]` | Save path for the CSV file. If not provided, it is inferred from the `save_path` directory. | `None` |
| `report` | `bool` | Whether to save a separate classification report. For more info, see `sklearn.metrics.classification_report`. | `True` |
| `save_df` | `bool` | Whether to save a dataframe containing the predictions, sorted by the number of mispredictions (from best to worst). | `True` |
| `batch_size` | `int` | Batch size to use during inference. | `32` |
| `input_shape` | `Optional[List[int]]` | Input shape of the network. | `None` |
Source code in conftrainer/utils.py
save_yaml(obj, filepath, default_flow_style=False)
Save a dict-like object to a YAML file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `obj` | `object` | Object to save. | required |
| `filepath` | `str` | Path to the file to create. | required |
| `default_flow_style` | `bool` | Corresponding parameter of `yaml.dump`. If `True`, nested dicts are written in flow (JSON-like) style, e.g. `A: {key: value, ...}`; otherwise block style is used. | `False` |
Source code in conftrainer/utils.py
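The effect of `default_flow_style` can be seen with a minimal sketch. It assumes `save_yaml` simply opens the file and forwards the object to PyYAML's `yaml.dump`; the real implementation is not shown here, and the example config is made up. Note that with `True`, PyYAML writes the top-level mapping in flow style as well, not only the nested dicts.

```python
import os
import tempfile

import yaml  # PyYAML


def save_yaml(obj, filepath, default_flow_style=False):
    """Sketch: dump a dict-like object to a YAML file via yaml.dump (assumed behavior)."""
    with open(filepath, "w", encoding="utf-8") as f:
        yaml.dump(obj, f, default_flow_style=default_flow_style)


# Illustrative config; keys are sorted by yaml.dump by default.
config = {"model": {"name": "resnet50", "classes": 10}}

path = os.path.join(tempfile.mkdtemp(), "config.yaml")
save_yaml(config, path)  # block style: one key per line, nesting by indentation
with open(path, encoding="utf-8") as f:
    text = f.read()
print(text)

# Flow style for comparison: JSON-like braces for every mapping
flow_text = yaml.dump(config, default_flow_style=True)
print(flow_text)
```

Block style is usually easier to diff and edit by hand, which is presumably why `False` is the default.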
json_load(path, encoding='utf-8')
Load a JSON file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str` | Path of the JSON file. | required |
| `encoding` | `str` | Encoding of the JSON file. | `'utf-8'` |

Returns:

| Name | Type | Description |
|---|---|---|
| `obj` | `Dict \| List` | Contents of the loaded JSON file. |
Source code in conftrainer/utils.py
json_save(path, obj)
Save a dict to a JSON file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str` | Path of the JSON file. | required |
| `obj` | `Dict` | Object to save. | required |
Source code in conftrainer/utils.py
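`json_save` and `json_load` are meant to be used as a round trip. The sketch below assumes they are thin wrappers around the standard library's `json.dump` and `json.load`. Details such as indentation or the encoding used when writing are not documented above; the UTF-8 write encoding here is an assumption.

```python
import json
import os
import tempfile
from typing import Dict, List, Union


def json_save(path: str, obj: Dict) -> None:
    """Sketch: serialize obj to path as JSON (UTF-8 assumed)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f)


def json_load(path: str, encoding: str = "utf-8") -> Union[Dict, List]:
    """Sketch: read and parse a JSON file with the given encoding."""
    with open(path, encoding=encoding) as f:
        return json.load(f)


# Round trip through a temporary file; the payload is illustrative.
path = os.path.join(tempfile.mkdtemp(), "classes.json")
json_save(path, {"labels": ["cat", "dog"], "num_classes": 2})
loaded = json_load(path)
print(loaded["num_classes"])
```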
plot_class_metrics(y_true, y_pred, class_names)
Plot a heatmap showing precision, recall and F1-score for each class.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `y_true` | `np.ndarray` | True labels. | required |
| `y_pred` | `np.ndarray` | Predicted labels. | required |
| `class_names` | `List[str]` | Names of the classes to display. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `out` | `Figure` | Heatmap with class-wise metrics. |
Source code in conftrainer/utils.py
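The heatmap displays three per-class quantities. For class `c`:

- precision = TP / (TP + FP)
- recall = TP / (TP + FN)
- F1 = 2 · precision · recall / (precision + recall)

The plain-Python sketch below computes the same numbers. It assumes integer-encoded labels, where index `i` corresponds to `class_names[i]`, and reports 0.0 when a denominator is zero. It only illustrates the definitions. It is not the helper's actual code, which may rely on scikit-learn, and the name `per_class_metrics` is hypothetical.

```python
from typing import Dict, List, Sequence


def per_class_metrics(
    y_true: Sequence[int], y_pred: Sequence[int], class_names: List[str]
) -> Dict[str, Dict[str, float]]:
    """Hypothetical helper: precision, recall and F1-score for each class index."""
    metrics = {}
    for idx, name in enumerate(class_names):
        pairs = list(zip(y_true, y_pred))
        tp = sum(1 for t, p in pairs if t == idx and p == idx)  # correct hits
        fp = sum(1 for t, p in pairs if t != idx and p == idx)  # false alarms
        fn = sum(1 for t, p in pairs if t == idx and p != idx)  # misses
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        metrics[name] = {"precision": precision, "recall": recall, "f1-score": f1}
    return metrics


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
m = per_class_metrics(y_true, y_pred, ["cat", "dog", "bird"])
for name, values in m.items():
    print(name, {k: round(v, 2) for k, v in values.items()})
```

Each row of the resulting table corresponds to one row of the heatmap.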
create_object(config, modules)
Create an object from the given config. The function first imports the class from the given modules, then initializes an object with the given parameters. If several modules define a class with the same name, the class from the first matching module is imported and initialized.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `ObjInitConfig` | Configuration to create an object. Contains the name and arguments. | required |
| `modules` | `List` | Modules to attempt an import from. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `out` | | Initialized object. |
Source code in conftrainer/utils.py
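The "first module wins" lookup can be sketched with `importlib`. This is a sketch under several assumptions. `ObjInitConfig` is replaced by a hypothetical dataclass with `name` and `params` fields, since its real fields are not documented here. `modules` is assumed to be a list of importable module names; the real function may accept module objects instead. Raising `ImportError` when nothing matches is also an assumption.

```python
import importlib
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ObjInitConfig:
    """Hypothetical stand-in for the real config: class name plus constructor kwargs."""
    name: str
    params: Dict[str, Any] = field(default_factory=dict)


def create_object(config: ObjInitConfig, modules: List[str]) -> Any:
    """Sketch: instantiate config.name from the first module that defines it."""
    for module_name in modules:
        module = importlib.import_module(module_name)
        cls = getattr(module, config.name, None)
        if cls is not None:  # first match wins; later modules are never consulted
            return cls(**config.params)
    raise ImportError(f"{config.name!r} not found in any of {modules}")


# 'collections' has no Fraction, so the lookup falls through to 'fractions'.
frac = create_object(ObjInitConfig("Fraction", {"numerator": 3, "denominator": 4}),
                     ["collections", "fractions"])
print(frac)

# Both modules expose OrderedDict; the one from 'collections' is used because it comes first.
od = create_object(ObjInitConfig("OrderedDict"), ["collections", "typing"])
print(type(od))
```

Because the lookup stops at the first hit, the order of `modules` determines which of several same-named classes is used.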
plot_confusion_matrix(y_true, y_pred, class_names=None, **kwargs)
Calculate and plot a confusion matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `y_true` | `np.ndarray` | True labels. | required |
| `y_pred` | `np.ndarray` | Predictions of the network. | required |
| `class_names` | `List[str]` | Names of the classes to display on the plot. | `None` |
Returns:
| Name | Type | Description |
|---|---|---|
| `out` | `matplotlib.figure.Figure` | A plot containing the confusion matrix. |
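The matrix being plotted can be computed in a few lines. This sketch assumes integer-encoded labels and uses the same orientation as `sklearn.metrics.confusion_matrix`: rows are true classes and columns are predicted classes. Whether `plot_confusion_matrix` uses this orientation, or how it forwards `**kwargs`, is not documented above. The helper name `confusion_matrix` here is hypothetical.

```python
from typing import List, Sequence


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> List[List[int]]:
    """Hypothetical helper: entry [i][j] counts samples of true class i predicted as class j."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


cm = confusion_matrix([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0], num_classes=3)
for row in cm:
    print(row)
```

Correct predictions accumulate on the diagonal. Off-diagonal cells show which classes are confused with which.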