# SCARLET / notebooks / 1_model_runs / run_mammals_separate_models.py
###################
### Description ###
###################

# This script is for running SCARLET on data from different mammals, with each species run separately.

# Relevant figures: Fig. 3b, Fig. 3c, Supp. Fig. 3c

# Inputs:
# - List of Anndata objects of methylation beta values for different mammals, with CpGs as obs and samples as vars
#   - Each Anndata object also contains relevant metadata in .uns and .obs (e.g. lifespan)
# - An Anndata object of human methylation data (replacing the human data in the mammal list to increase sample size), with CpGs as obs and samples as vars

# Outputs:
# - Model outputs (traces) saved as pk files for each type of model run

##############
### Author ###
##############
 
# Sam Crofts (sam.crofts@ed.ac.uk)

##################################
##################################

# Imports
import sys
sys.path.append("..")   # fix to import modules from root
from src.general_imports import *
from statsmodels.regression.linear_model import OLS
from statsmodels.stats.diagnostic import het_white
import pymc.sampling.jax as pmjax
import jax
import statsmodels.api as sm
import statsmodels.formula.api as smf

# Report the active JAX backend and the visible devices, so the log shows
# whether the run is actually on GPU (this only prints; it does not enforce it)
for describe_jax in (jax.default_backend, jax.devices):
    print(describe_jax())

# Site-selection criterion:
#   "spearman" = rank sites by mean change with age
#   "white"    = rank sites by variance change with age
selection_method = "spearman"

# Remaining run parameters
n_sites = 200          # number of top-ranked sites kept per species
min_sample_size = 50   # species need more than this many samples to be kept

# Seed NumPy's global random generator for reproducibility
np.random.seed(12)

# Load the input data.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# these files must only ever be trusted, locally generated exports.
# The `with` blocks guarantee the file handles are closed even if unpickling fails.

# Mammal data: list of per-species Anndata objects
with open('../data/pan_mammal_blood_with_site_details_filtered.pk', 'rb') as handle:
    data_list = pickle.load(handle)

# GenScot human data, used to replace the human entry of the mammal list
with open('../data/genscot_full_with_site_details.pk', 'rb') as handle:
    humans = pickle.load(handle)

# Repopulate the per-site columns that later steps read from .obs,
# copying them from the existing GenScot variance columns
for recalc_column, source_column in (
    ('variance_second_half_recalc', 'variance_second_half'),
    ('variance_first_half_recalc', 'variance_first_half'),
):
    humans.obs[recalc_column] = humans.obs[source_column]

# Species-level metadata that later steps read from .uns
for uns_key, uns_value in (
    ('common_name', 'Human'),
    ('organism', 'Homo sapiens'),
    ('sex_maturity', 13),
    ('lifespan', 122.5),
):
    humans.uns[uns_key] = uns_value

# Swap every 'Homo sapiens' entry of data_list for the GenScot data set
for position, entry in enumerate(data_list):
    if entry.uns['organism'] == 'Homo sapiens':
        data_list[position] = humans

# Drop samples younger than the species' age at sexual maturity
for position, entry in enumerate(data_list):
    mature_samples = entry.var.age >= entry.uns['sex_maturity']
    # samples are the second axis, so subset the columns and store the result back
    data_list[position] = entry[:, mature_samples]

# Down-sample the human data: 250 samples (via sample_to_uniform_age, a helper
# defined elsewhere) and a random subset of at most 50000 sites.
for j, data in enumerate(data_list):
    if data.uns['organism'] != 'Homo sapiens':
        continue

    # Sample from the list entry, which has already been restricted to
    # sexually mature samples, rather than from the unfiltered `humans`
    # object -- otherwise the maturity filter would be silently undone.
    sampled = sample_to_uniform_age(data, 250)

    # Randomly keep 50000 sites; cap at the number of available sites so that
    # np.random.choice(..., replace=False) cannot fail on a smaller data set.
    n_keep = min(50000, len(sampled.obs.index))
    data_list[j] = sampled[np.random.choice(sampled.obs.index, n_keep, replace=False)]

# Site-level filtering, applied to every species
for position, entry in enumerate(data_list):

    # Variance change: second-half variance minus first-half variance.
    # (.obs is re-read on every access on purpose: assigning to a view's .obs
    # can turn the view into a real copy, which would leave a cached reference stale.)
    entry.obs['var_change_recal'] = (
        entry.obs['variance_second_half_recalc'] - entry.obs['variance_first_half_recalc']
    )

    # Mean methylation of each site across all samples
    entry.obs['mean_meth_recal'] = entry.X.mean(axis=1)

    # A site is kept only if all of the following hold:
    #   - mean methylation strictly between 0.1 and 0.9
    #   - not bimodal (exactly one peak)
    #   - variance is increasing (positive variance change)
    keep_site = (
        (entry.obs.mean_meth_recal > 0.1)
        & (entry.obs.mean_meth_recal < 0.9)
        & (entry.obs.n_peaks == 1)
        & (entry.obs.var_change_recal > 0)
    )

    # Store the filtered data back into the list
    data_list[position] = entry[keep_site]

# Re-base ages per species so that the youngest sample sits at age 0.
# The entries are modified in place, so nothing needs to be written back to the list.
for entry in data_list:
    youngest_age = entry.var['age'].min()
    entry.var['age'] = entry.var['age'] - youngest_age

# Species to run the separate models on
names = ["Macaca mulatta", "Tursiops truncatus", "Ovis aries"]

# Keep a species only if it has more than min_sample_size samples
# and its organism name is one of `names`
data_list = [
    entry
    for entry in data_list
    if entry.shape[1] > min_sample_size and entry.uns['organism'] in names
]

# Fail fast on a mistyped selection method: otherwise no site selection would
# take place and the model would silently be fitted on every remaining site.
if selection_method not in ("spearman", "white"):
    raise ValueError(
        f"Unknown selection_method {selection_method!r}; expected 'spearman' or 'white'"
    )

# For every species: pick the top n_sites sites with the chosen selection
# method, fit the model on them and pickle the resulting trace.
for animal_data in data_list:

    # Species metadata (lifespan is read here but not used further in this loop)
    lifespan = animal_data.uns['lifespan']
    name = animal_data.uns['organism']

    #print name
    print(name)

    #if bos taurus, remove a few problem samples (IDs starting with these prefixes)
    if name == 'Bos taurus':
        problem_prefixes = ('204027420028', '204027420029')
        # one boolean mask over all samples instead of re-slicing once per bad sample
        keep_sample = np.array(
            [not sample_id.startswith(problem_prefixes) for sample_id in animal_data.var.index],
            dtype=bool,
        )
        animal_data = animal_data[:, keep_sample]

    ###recalculate CpG attributes based on the restricted range (now that we've trimmed the data)###
    ages = animal_data.var.age.values

    if selection_method == "spearman":
        # Table of beta values with sites as rows and samples as columns
        betas = pd.DataFrame(animal_data.X, columns=animal_data.var.index)
        betas.index = animal_data.obs.index

        # Absolute Spearman correlation between each site's methylation and age
        abs_rho = betas.apply(lambda site_values: spearmanr(site_values, ages)[0], axis=1).abs()
        animal_data.obs['spearman_r'] = abs_rho

        # Keep the n_sites sites with the largest |rho|. Sorting a plain array
        # with NaNs ranked last (|rho| >= 0 > -1) avoids relying on argsort of
        # a pandas Series, which does not give usable positions when NaNs are
        # present (e.g. a site with constant methylation).
        ranking = np.argsort(abs_rho.fillna(-1.0).to_numpy())
        animal_data = animal_data[ranking[-n_sites:]]

    elif selection_method == "white":
        # The age split is identical for every site, so compute it once
        median_age = np.median(ages)
        younger_half = ages < median_age
        older_half = ages >= median_age

        beta_pvals = []
        white_pvals = []
        var_changes = []

        # Per site: variance change between age halves, OLS of methylation on
        # age, and White's test for heteroscedasticity on the OLS residuals
        for site in animal_data.obs.index:
            meth = animal_data[site].X.T[:, 0]

            # variance in the older half minus variance in the younger half
            var_changes.append(np.var(meth[older_half]) - np.var(meth[younger_half]))

            fit = smf.ols('meth ~ age', data=pd.DataFrame({'age': ages, 'meth': meth})).fit()
            beta_pvals.append(fit.pvalues["age"])

            _, p_value_white, _, _ = sm.stats.diagnostic.het_white(fit.resid, fit.model.exog)
            white_pvals.append(p_value_white)

        # Write each column once instead of cell by cell (lists follow .obs order)
        animal_data.obs['beta_pval_recal'] = beta_pvals
        animal_data.obs['white_pval_recal'] = white_pvals
        animal_data.obs['var_change_recal'] = var_changes

        # Keep the n_sites sites with the smallest White-test p-value
        top_sites = animal_data.obs.sort_values('white_pval_recal', ascending=True).index[:n_sites]
        animal_data = animal_data[top_sites]

    #make model
    sym_model = make_mcmc_order1(animal_data)

    with sym_model:
        #for difficult ones, let's tune 5000
        trace = pmjax.sample_numpyro_nuts(progressbar=True, target_accept=0.98, tune=5000, chain_method='parallel', random_seed=18, chains=2)

    #also get posterior predictive
    # pm.sample_posterior_predictive(trace, model=sym_model, extend_inferencedata=True, )

    # Name the output after the selection method actually used, so a "white"
    # run is not saved under (and does not overwrite) the "spearman" file name.
    # For selection_method == "spearman" the path is unchanged.
    output_path = (
        '../exports/model_outputs/all_mammals/'
        f'{selection_method}_filtered_separate_model_min_sample_size_{min_sample_size}'
        f'_nsites_{n_sites}_{name}.pk'
    )
    with open(output_path, 'wb') as f:
        pickle.dump(trace, f)