scMEDAL_for_scRNAseq / scMEDAL / utils / utils_load_model.py
utils_load_model.py
Raw

# Utilities for locating model checkpoints and loading saved model parameters.
class Namespace:
    """
    Attribute-style view over a dictionary.

    Every entry of the supplied dictionary is copied onto the instance, so a
    value can be read as ``ns.key`` instead of ``d['key']``.

    Parameters:
    -----------
    adict : dict
        Dictionary whose entries become attributes of the instance.

    Example:
    --------
    >>> ns = Namespace({'key': 'value'})
    >>> print(ns.key)
    value
    """
    def __init__(self, adict):
        # vars(self) is the instance's attribute dictionary; updating it in
        # one step turns every key of `adict` into an attribute.
        attributes = vars(self)
        attributes.update(adict)


def get_last_checkpoints_path(checkpoints_path):
    """
    Fetches the last checkpoint's path from a given checkpoints directory.

    The directory must contain a text file named ``checkpoint`` with a line of
    the form ``model_checkpoint_path: "<name>"``. Only lines whose key is
    exactly ``model_checkpoint_path`` are used; other keys that merely contain
    that text (for example the ``all_model_checkpoint_paths`` lines that
    TensorFlow writes to the same file) are ignored. If the key occurs more
    than once, the last occurrence wins.

    Parameters:
    -----------
    checkpoints_path : str
        Path to the directory containing checkpoint files.

    Returns:
    --------
    full_checkpoint_path : str
        ``checkpoints_path + "/" + <name>``, i.e. the full path to the last
        saved model checkpoint.

    Raises:
    -------
    OSError
        If the ``checkpoint`` file cannot be opened.
    ValueError
        If the file has no ``model_checkpoint_path`` entry.

    Example:
    --------
    >>> path = "/path/to/checkpoints"
    >>> print(get_last_checkpoints_path(path))
    /path/to/checkpoints/cp-0005.ckpt
    """
    wanted_key = "model_checkpoint_path"
    model_checkpoint_path = None

    # Read the checkpoint file line by line
    with open(checkpoints_path + "/checkpoint", 'r') as file:
        for line in file:
            key, separator, value = line.partition(':')
            # Compare the whole key: a substring test would also accept
            # "all_model_checkpoint_paths" lines.
            if separator and key.strip() == wanted_key:
                # Drop surrounding whitespace/newline, then the quotes.
                model_checkpoint_path = value.strip().strip('"')

    if model_checkpoint_path is None:
        raise ValueError(
            "No 'model_checkpoint_path' entry found in "
            + checkpoints_path + "/checkpoint"
        )

    # Full path
    return checkpoints_path + "/" + model_checkpoint_path

def get_latest_checkpoint(checkpoint_dir):
    """
    Retrieves the latest checkpoint file from the specified directory.

    Args:
        checkpoint_dir (str): The directory where checkpoint files are stored.

    Returns:
        str: The entry returned by ``glob_like`` for the checkpoint with the
        highest number in its file name, or None if there is no numbered
        checkpoint.

    Notes:
        The function searches for checkpoint files matching the pattern
        'cp-*.ckpt.data-00000-of-00001' and ranks them by the integer that
        follows 'cp-' in the file name. Matches whose name has no digits
        after 'cp-' cannot be ranked and are skipped.
    """
    import os
    import re
    from compare_results_utils import glob_like

    # Pattern to match the checkpoint files
    checkpoint_pattern = 'cp-*.ckpt.data-00000-of-00001'
    number_regex = re.compile(r'cp-(\d+)')

    # NOTE(review): glob_like is a project helper; it is assumed to return a
    # list of strings (file names or paths) - confirm against its definition.
    numbered_checkpoints = []
    for checkpoint in glob_like(checkpoint_dir, checkpoint_pattern):
        # Inspect only the file name so that a "cp-<digits>" fragment in a
        # parent directory name is never mistaken for the checkpoint number.
        match = number_regex.search(os.path.basename(checkpoint))
        if match:
            numbered_checkpoints.append((int(match.group(1)), checkpoint))

    if not numbered_checkpoints:
        return None

    # The sort is stable, so among equal numbers the later entry is returned.
    numbered_checkpoints.sort(key=lambda pair: pair[0])
    return numbered_checkpoints[-1][1]


def read_model_params(checkpoints_path):
    """
    Reads the model parameters from a checkpoint directory.

    The parameters are loaded from the file in ``checkpoints_path`` whose
    name starts with ``model_params``. If several files match, the one that
    sorts first by name is used.

    Parameters:
    -----------
    checkpoints_path : str
        Path to the directory containing checkpoint files and model parameters.

    Returns:
    --------
    model_params : Namespace
        Namespace object containing all the model parameters.

    Raises:
    -------
    FileNotFoundError
        If the directory contains no file matching ``model_params*``.

    Example:
    --------
    >>> path = "/path/to/checkpoints"
    >>> params = read_model_params(path)
    >>> print(params.layer_size)
    128
    """
    import glob
    import yaml

    # Escape the directory part so characters such as "[" or "*" in the path
    # are matched literally; only the trailing "*" acts as a wildcard.
    pattern = glob.escape(checkpoints_path) + "/model_params*"

    # glob returns matches in arbitrary order; sort for a reproducible choice.
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        raise FileNotFoundError(
            "No file matching 'model_params*' found in {!r}".format(checkpoints_path)
        )
    model_params_path = candidates[0]

    # Open and read the YAML file
    with open(model_params_path, 'r') as file:
        model_params_dict = yaml.safe_load(file)

    return Namespace(model_params_dict)

def parse_value(value):
    """
    Converts the text of a Python literal into the corresponding object.

    Parameters:
    -----------
    value : str
        Text to interpret, e.g. ``"128"``, ``"[1, 2]"`` or ``"relu"``.
        Surrounding whitespace is ignored when parsing.

    Returns:
    --------
    parsed : object
        The evaluated literal (int, float, list, dict, ...) when ``value`` is
        a valid Python literal, otherwise ``str(value)`` unchanged.

    Example:
    --------
    >>> parse_value(" 128")
    128
    >>> parse_value("relu")
    'relu'
    """
    import ast

    # Surrounding whitespace is not part of the literal. Python versions
    # before 3.10 reject leading whitespace in ast.literal_eval.
    candidate = value.strip() if isinstance(value, str) else value
    try:
        # Attempt to parse the value as a Python literal
        return ast.literal_eval(candidate)
    except (ValueError, TypeError, SyntaxError):
        # TypeError covers syntactically valid text that cannot be built,
        # e.g. a dict literal with a list as key. Fall back to the raw text.
        return str(value)

def read_model_params_frompath(model_params_path, replacements=None):
    """
    Reads the model parameters from a YAML file.

    If the loaded YAML content can be wrapped in a Namespace (i.e. it is a
    mapping), that Namespace is returned. Otherwise the content is treated as
    a sequence of ``"key:value"`` strings: each one is split on the first
    ``':'``, the value is converted with ``parse_value``, and a plain dict is
    returned instead.

    Parameters:
    -----------
    model_params_path : str
        Path to the YAML file containing the model parameters.
    replacements : dict, optional
        Only used on the dict fallback path. For every key whose parsed value
        is still a string (it could not be read as a Python literal) and that
        is present in ``replacements``, the value is replaced by
        ``replacements[key]``. ``None`` means no replacements.

    Returns:
    --------
    model_params : Namespace or dict
        Namespace when the YAML content is a mapping, otherwise a dict built
        from the ``"key:value"`` entries.

    Raises:
    -------
    ValueError
        On the fallback path, if an entry contains no ``':'``.

    Example:
    --------
    >>> path = "/path/to/checkpoints/model_params.yaml"
    >>> params = read_model_params_frompath(path, replacements={})
    >>> print(params.layer_size)
    128
    """
    import yaml

    # Open and read the YAML file
    with open(model_params_path, 'r') as file:
        model_params_dict = yaml.safe_load(file)

    try:
        return Namespace(model_params_dict)
    except (TypeError, ValueError):
        # Namespace relies on dict.update(), which raises one of these when
        # the YAML content is not a mapping (e.g. a list of strings).
        print("Returning dict, fix tensorflow inputs")

    if replacements is None:
        replacements = {}

    # Each entry is expected to look like "key:value"; split on the first
    # colon only so that values may themselves contain colons.
    parsed_params = {}
    for item in model_params_dict:
        key, raw_value = item.split(':', 1)
        parsed_params[key] = parse_value(raw_value)

    # Values that stayed strings could not be parsed as literals; substitute
    # the caller-supplied replacement where one exists. Only existing keys
    # are reassigned, so iterating over items() here is safe.
    for key, value in parsed_params.items():
        if isinstance(value, str) and key in replacements:
            print(key, "replaced by", replacements[key])
            parsed_params[key] = replacements[key]
    return parsed_params