#Utils to load model weights.
class Namespace:
    """
    Expose the entries of a dictionary as instance attributes.

    Each key of the mapping given to the constructor is copied onto the
    instance, so ``ns.key`` can be written instead of ``d['key']``.

    Parameters:
    -----------
    adict : dict
        Mapping (or iterable of key/value pairs) whose entries become
        attributes of the new instance.

    Example:
    --------
    >>> ns = Namespace({'key': 'value'})
    >>> ns.key
    'value'
    """

    def __init__(self, adict):
        # vars(self) is the instance __dict__; a bulk update installs every
        # entry as an attribute in one step.
        vars(self).update(adict)
def get_last_checkpoints_path(checkpoints_path):
    """
    Fetches the last checkpoint's path from a given checkpoints directory.

    The directory must contain an index file named ``checkpoint`` with a
    line of the form ``model_checkpoint_path: "<path>"``.

    Parameters:
    -----------
    checkpoints_path : str
        Path to the directory containing checkpoint files.

    Returns:
    --------
    full_checkpoint_path : str
        Full path to the last saved model checkpoint. A relative recorded
        path is joined onto ``checkpoints_path``; an absolute recorded path
        is returned unchanged.

    Raises:
    -------
    FileNotFoundError
        If ``<checkpoints_path>/checkpoint`` does not exist.
    ValueError
        If the index file has no ``model_checkpoint_path`` entry.

    Example:
    --------
    >>> path = "/path/to/checkpoints"
    >>> print(get_last_checkpoints_path(path))
    /path/to/checkpoints/cp-0005.ckpt
    """
    import os

    index_file = os.path.join(checkpoints_path, "checkpoint")
    with open(index_file, 'r') as file:
        for line in file:
            key, sep, value = line.partition(':')
            # Compare the key exactly: a substring test would also match the
            # "all_model_checkpoint_paths" lines and let the last one win.
            if sep and key.strip() == "model_checkpoint_path":
                model_checkpoint_path = value.strip().strip('"')
                break
        else:
            raise ValueError(
                f"No 'model_checkpoint_path' entry found in {index_file}")
    # os.path.join leaves an absolute recorded path untouched instead of
    # prefixing it with the directory.
    return os.path.join(checkpoints_path, model_checkpoint_path)
def get_latest_checkpoint(checkpoint_dir,
                          checkpoint_pattern='cp-*.ckpt.data-00000-of-00001'):
    """
    Retrieves the latest checkpoint file from the specified directory.

    Parameters:
    -----------
    checkpoint_dir : str
        The directory where checkpoint files are stored.
    checkpoint_pattern : str, optional
        Glob-style pattern handed to ``glob_like``. The default matches
        single-shard data files named 'cp-<number>.ckpt.data-00000-of-00001'.

    Returns:
    --------
    latest_checkpoint : str or None
        The entry returned by ``glob_like`` whose file name carries the
        highest ``cp-<number>`` index. Whether that entry is a bare file
        name or a full path depends on ``glob_like`` (not visible here).
        None when no matching entry has a numeric index. When several
        entries share the highest index, the last one in ``glob_like``
        order is returned.
    """
    import os
    import re
    from compare_results_utils import glob_like

    index_re = re.compile(r'cp-(\d+)')
    latest_checkpoint = None
    latest_index = -1
    for candidate in glob_like(checkpoint_dir, checkpoint_pattern):
        # Look only at the file name so a directory component such as
        # ".../cp-3/..." is never mistaken for the checkpoint index.
        match = index_re.search(os.path.basename(candidate))
        if match is None:
            # The glob '*' also admits names like 'cp-best.ckpt...'; skip
            # them instead of failing with an IndexError.
            continue
        index = int(match.group(1))
        # '>=' keeps the last of equal indices, as the stable sort did.
        if index >= latest_index:
            latest_index, latest_checkpoint = index, candidate
    return latest_checkpoint
def read_model_params(checkpoints_path):
    """
    Reads the model parameters from a checkpoint directory.

    The parameters are loaded with ``yaml.safe_load`` from the first file
    (in sorted order) whose name starts with ``model_params``.

    Parameters:
    -----------
    checkpoints_path : str
        Path to the directory containing checkpoint files and model parameters.

    Returns:
    --------
    model_params : Namespace
        Namespace object containing all the model parameters.

    Raises:
    -------
    FileNotFoundError
        If no ``model_params*`` file exists in ``checkpoints_path``.

    Example:
    --------
    >>> path = "/path/to/checkpoints"
    >>> params = read_model_params(path)
    >>> print(params.layer_size)
    128
    """
    import yaml
    import glob
    import os

    # Escape the directory so glob metacharacters in it ('[', '*', '?')
    # are matched literally; only the file-name part is a pattern.
    pattern = os.path.join(glob.escape(checkpoints_path), "model_params*")
    # glob returns entries in arbitrary order; sorting makes the choice
    # reproducible when several files match.
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        raise FileNotFoundError(
            f"No 'model_params*' file found in {checkpoints_path}")
    with open(candidates[0], 'r') as file:
        model_params_dict = yaml.safe_load(file)
    return Namespace(model_params_dict)
def parse_value(value):
    """
    Convert a text value to the Python literal it spells, if any.

    Parameters:
    -----------
    value : str
        Text such as ``"0.1"``, ``"[64, 64]"`` or ``"relu"``.

    Returns:
    --------
    parsed : object
        ``ast.literal_eval(value)`` when *value* is a valid Python literal,
        otherwise ``str(value)``.

    Example:
    --------
    >>> parse_value("[64, 64]")
    [64, 64]
    >>> parse_value("relu")
    'relu'
    """
    import ast
    try:
        # Attempt to parse the value as a Python literal
        return ast.literal_eval(value)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # ValueError/SyntaxError: not a literal. TypeError: valid syntax that
        # cannot be built, e.g. "{[1]: 2}" (unhashable key).
        # MemoryError/RecursionError: pathologically large or nested input.
        # In every case fall back to the raw text.
        return str(value)
def read_model_params_frompath(model_params_path, replacements=None):
    """
    Reads the model parameters from a YAML file.

    Parameters:
    -----------
    model_params_path : str
        Path to the YAML file containing the model parameters.
    replacements : dict, optional
        Used only on the fallback path (see Returns). For every key present
        here whose parsed value is still a plain string, that value is
        replaced by ``replacements[key]``. Defaults to no replacements.

    Returns:
    --------
    model_params : Namespace or dict
        A Namespace when the loaded YAML content can be turned into one
        (e.g. a mapping). Otherwise, the content is treated as a list of
        ``"key:value"`` strings: a dict is built from it, with each value
        converted by ``parse_value`` and the replacements applied.

    Raises:
    -------
    ValueError
        If a fallback entry is not a string containing ':'.

    Example:
    --------
    >>> params = read_model_params_frompath("/path/to/model_params.yaml", {})
    >>> print(params.layer_size)
    128
    """
    import yaml

    # Open and read the YAML file
    with open(model_params_path, 'r') as file:
        model_params_dict = yaml.safe_load(file)
    try:
        return Namespace(model_params_dict)
    except (TypeError, ValueError):
        # Namespace needs a mapping (or key/value pairs); anything else,
        # e.g. a YAML list of "key:value" strings, is parsed by hand below.
        pass
    print("Returning dict, fix tensorflow inputs")
    if replacements is None:
        replacements = {}
    parsed = {}
    for item in model_params_dict:
        if not isinstance(item, str) or ':' not in item:
            raise ValueError(
                f"Cannot parse model parameter entry {item!r}; "
                "expected 'key:value'")
        key, raw_value = item.split(':', 1)
        parsed[key.strip()] = parse_value(raw_value.strip())
    for key, value in list(parsed.items()):
        # Values still left as strings could not be parsed as literals
        # (e.g. object names); substitute the caller-provided object.
        if isinstance(value, str) and key in replacements:
            print(key, "replaced by", replacements[key])
            parsed[key] = replacements[key]
    return parsed