Skip to content

Inference


inference utils

get_filepaths(data_dir, extensions)

Recursively collect the paths of all files under the given directory whose names end with one of the given extensions.

Source code in conftrainer/inference/utils.py
def get_filepaths(data_dir: str, extensions: List[str]) -> List[str]:
    """Recursively collect the paths of files under ``data_dir`` that end with one of ``extensions``.

    Args:
        data_dir: Root directory to walk; files in all subdirectories are included.
        extensions: Filename suffixes to keep, e.g. ``[".jpg", ".png"]``. Matching is a
            case-sensitive ``str.endswith`` check. A single string is accepted and treated
            as a one-element list. ``None`` keeps every file; an empty list keeps none.

    Returns:
        Paths of the matching files, each joined onto ``data_dir``. Directories are visited
        in sorted order and files within a directory are sorted, so the result is
        reproducible across filesystems.
    """
    if isinstance(extensions, str):
        # A bare string would otherwise be iterated character by character, matching any
        # file whose name ends with any single one of those characters.
        extensions = [extensions]
    # Materialise once: a tuple is reusable for every file (a generator would be exhausted
    # after the first check) and str.endswith accepts a tuple of suffixes directly.
    suffixes = None if extensions is None else tuple(extensions)

    paths = []
    for root, dir_names, fn_list in os.walk(data_dir):
        # Sorting dir_names in place makes os.walk descend in a deterministic order; the raw
        # order is whatever the filesystem returns.
        dir_names.sort()
        paths.extend(os.path.join(root, fn) for fn in sorted(fn_list)
                     if suffixes is None or fn.endswith(suffixes))

    return paths

infer_and_report(network, data_dir, csv_save_path, batch_size, class_names=None, input_shape=None, extensions=None)

Run the network on the files found in the given directory and save the predictions to a CSV file.

Source code in conftrainer/inference/utils.py
def infer_and_report(network: Model, data_dir: str, csv_save_path: str, batch_size: int,
                     class_names: str = None, input_shape: List[int] = None,
                     extensions: List[str] = None) -> None:
    """Run ``network`` on the files found under ``data_dir`` and save the predictions.

    Args:
        network: Model to run inference with.
        data_dir: Directory searched recursively (via ``get_filepaths``) for input files.
        csv_save_path: Output path handed to ``predict_and_report``, which is called with
            ``save_df=True`` and ``report=False``.
        batch_size: Batch size forwarded to ``predict_and_report``.
        class_names: Forwarded to ``ImageDatagen`` as ``classes``. NOTE(review): annotated
            as ``str``, but the plural name suggests a sequence of class names — confirm.
        input_shape: Input shape without the batch dimension. When ``None`` it is read from
            ``network.input``.
        extensions: Filename suffixes to include; ``None`` includes every file.

    Raises:
        ValueError: If ``input_shape`` is not given and the network's input has no defined
            non-batch dimension (all are ``None``, or there are none).
    """
    if input_shape is None:
        input_shape = list(network.input.get_shape()[1:])  # skip batch dimension
        # Explicit raise rather than ``assert``: assertions are stripped under ``python -O``,
        # which would let an undefined shape silently reach the data generator.
        if all(dim is None for dim in input_shape):
            raise ValueError(f"Please provide either a valid input shape or a network "
                             f"with defined input shape. Found: {input_shape}")

    filepaths = get_filepaths(data_dir=data_dir, extensions=extensions)
    datagen = ImageDatagen(filepaths=filepaths, labels=None, classes=class_names, shape=input_shape)
    # labels=None: unlabeled inference data, so only predictions are saved (report=False).
    predict_and_report(network=network,
                       gen=datagen,
                       csv_save_path=csv_save_path,
                       save_df=True,
                       report=False,
                       input_shape=input_shape,
                       batch_size=batch_size)

load_and_infer(network_path, *args, **kwargs)

Load a saved Keras model (without compiling it) and pass it, together with the remaining arguments, to infer_and_report.

Source code in conftrainer/inference/utils.py
def load_and_infer(network_path: str, *args, **kwargs):
    """Load a saved Keras model from disk and run ``infer_and_report`` with it.

    Args:
        network_path: Path passed to ``tf.keras.models.load_model``. The model is loaded with
            ``compile=False``, so no optimizer or loss configuration is restored.
        *args: Positional arguments forwarded to ``infer_and_report`` after ``network``
            (i.e. ``data_dir``, ``csv_save_path``, ``batch_size``, ...).
        **kwargs: Keyword arguments forwarded to ``infer_and_report``.
    """
    network = tf.keras.models.load_model(network_path, compile=False)
    # ``network`` must be passed positionally. The call ``f(network=network, *args)``
    # binds ``*args`` to the leading parameters first, starting with ``network``. Any
    # forwarded positional argument therefore raised "got multiple values for argument
    # 'network'".
    infer_and_report(network, *args, **kwargs)