# SCARLET / notebooks / 1_model_runs / run_mammals_joint_models.py
###################
### Description ###
###################

# This script is for running SCARLET on data from different mammals, with all 
# mammals included in one single model.

# Relevant figures: Fig. 3d, Supp. Figs. 3d-i

# Inputs:
# - List of Anndata objects of methylation beta values for different mammals, with CpGs as obs and samples as vars
#   - Each Anndata object also contains relevant metadata in .uns and .obs (e.g. lifespan)
# - An Anndata object of human methylation data (replacing the human data in the mammal list to increase sample size), with CpGs as obs and samples as vars

# Outputs:
# - Model outputs (traces) saved as pk files for each type of model run

##############
### Author ###
##############
 
# Sam Crofts (sam.crofts@ed.ac.uk)

##################################
##################################

#Imports
import sys
sys.path.append("..")   # fix to import modules from root
from src.general_imports import *
import pymc.sampling.jax as pmjax
import jax

#Set seed for reproducibility
np.random.seed(12)
random.seed(12)

#set parameters
n_sites = 200
min_sample_size = 100

for model_type in ["free_ns"]:
    
    # bring in mammal data
    file = open('../data/pan_mammal_blood_with_site_details_filtered.pk', 'rb')
    data_list = pickle.load(file)
    file.close()

    #bring in GenScot to replace the human data
    file = open('../data/genscot_full_with_site_details.pk', 'rb')
    humans = pickle.load(file)
    file.close()

    #Repopulate a few variables we need
    humans.obs['variance_second_half_recalc'] = humans.obs['variance_second_half']
    humans.obs['variance_first_half_recalc'] = humans.obs['variance_first_half']
    humans.uns['common_name']= 'Human'
    humans.uns['organism']= 'Homo sapiens'
    humans.uns['sex_maturity']= 13
    humans.uns['lifespan']=122.5

    #randomly sample 1000 humans to ease memory requirements
    human_indices = np.random.choice(humans.shape[1], 1000, replace=False)
    humans = humans[:, human_indices]

    #replace organism='Homo sapiens' in data_list with humans
    j=0
    for data in data_list:
        if data.uns['organism'] == 'Homo sapiens':
            data_list[j] = humans
        j+=1

    #let's remove any samples under the age of sexual maturity
    j=0
    for data in data_list:
        data = data[:,data.var.age >= data.uns['sex_maturity']]
        #replace the data in the list with the new data
        data_list[j] = data
        j+=1

    #just keep animals with more than 100 samples
    data_list = [x for x in data_list if x.shape[1] > min_sample_size]

    #Exclude rhesus macaque because it's not converging at this lower sample size
    data_list = [x for x in data_list if x.uns['common_name'] != 'Rhesus macaque']

    #only keep sites with 1 peak and between 0.1 and 0.9 methylation
    j=0
    for data in data_list:

        #make a var_change_recal column
        data.obs['var_change_recal'] = data.obs['variance_second_half_recalc'] - data.obs['variance_first_half_recalc']

        #make a mean meth recal column
        data.obs['mean_meth_recal'] = data.X.mean(axis=1)

        #only keep if mean methylation is between 0.1 and 0.9
        data = data[(data.obs.mean_meth_recal > 0.1) & (data.obs.mean_meth_recal < 0.9)]

        #maybe change this again to recal if needed - but perhaps too stringent for low sample sizes
        data = data[data.obs.n_peaks == 1]

        #make sure site is increasing in variance
        data = data[data.obs.var_change_recal > 0]

        #replace the data in the list with the new data
        data_list[j] = data
        j+=1

    #let's reset all ages to be relative to the minimum age
    j=0
    for data in data_list:
        data.var.age = data.var.age - data.var.age.min()
        data_list[j] = data
        j+=1

    #don't include 'Sheep' - convergence issues due to small lifepsan range
    data_list = [x for x in data_list if x.uns['common_name'] != 'Sheep']

    #get minimum number of samples
    min_samples = min([data.shape[1] for data in data_list])

    #loop through and uniformly sample min_samples from each animal
    for i in range(len(data_list)):
        data_list[i] = sample_to_uniform_age(data_list[i], min_samples)

    #loop through and calculate spearman_r for each animal, keeping only the top n_sites
    for i in range(len(data_list)):
        
        data = data_list[i]
        #recalculate spearman R based on the restricted range (now that we've trimmed the data)
        df = pd.DataFrame(data.X, columns=data.var.index)
        #make the first column the index
        df.index = data.obs.index
        #Assign
        data.obs['spearman_r'] = df.apply(lambda x: spearmanr(x, data.var.age)[0], axis=1)
        #get absolute value
        data.obs['spearman_r'] = np.abs(data.obs['spearman_r'])
        #get the top top_sites sites
        data_list[i] = data[np.argsort(data.obs['spearman_r'])[-n_sites:]]

    # Combine both data into a single structure
    data_matrix = np.stack([data.X.T for data in data_list], axis=-1)

    # Combine ages for both species
    t_matrix = np.stack([np.tile(np.array(data.var.age)[:, None],n_sites)
                        for data in data_list ], axis=-1)

    #make model
    sym_model = make_mcmc_beta_model_stacked(data_matrix, t_matrix, data_list, min_samples, n_sites, model_type)

    with sym_model:
        trace = pmjax.sample_numpyro_nuts(chains=2, progressbar=True, target_accept=0.95, tune=5000, chain_method='parallel', random_seed=13)

    #calculate log likelihood
    with sym_model:
        pm.compute_log_likelihood(trace)

    #also get posterior predictive
    pm.sample_posterior_predictive(trace, model=sym_model, extend_inferencedata=True, )

    #pickle
    with open('../exports/model_outputs/all_mammals/joint_model_filtered_'+model_type+'.pk', 'wb') as f:
        pickle.dump(trace, f)