###################
### Description ###
###################
# This script is for running SCARLET on data from different mammals, with each species run separately.
# Relevant figures: Fig. 3b, Fig. 3c, Supp. Fig. 3c
# Inputs:
# - List of AnnData objects of methylation beta values for different mammals, with CpGs as obs and samples as vars
# - Each AnnData object also contains relevant metadata in .uns and .obs (e.g. lifespan)
# - An AnnData object of human methylation data (replacing the human data in the mammal list to increase sample size), with CpGs as obs and samples as vars
# Outputs:
# - Model outputs (traces) saved as pk files for each type of model run
##############
### Author ###
##############
# Sam Crofts (sam.crofts@ed.ac.uk)
##################################
##################################
# Imports
import sys
sys.path.append("..") # fix to import modules from root
from src.general_imports import *
from statsmodels.regression.linear_model import OLS
from statsmodels.stats.diagnostic import het_white
import pymc.sampling.jax as pmjax
import jax
import statsmodels.api as sm
import statsmodels.formula.api as smf
# Log which JAX backend and devices are visible, so the run output shows
# whether a GPU is being used (this only reports; it does not enforce anything)
print(jax.default_backend(), jax.devices(), sep="\n")

# --- Run configuration ---
# How sites are chosen for the model:
#   "spearman" -> ranked by change in mean methylation with age
#   "white"    -> ranked by change in variance with age
selection_method = "spearman"
# Number of top-ranked sites kept per species
n_sites = 200
# Species need more than this many samples to be modelled
min_sample_size = 50

# Fix NumPy's global random state so random sub-sampling is reproducible
np.random.seed(12)
# Load the input data.
# NOTE: pickle can execute arbitrary code while loading - only point these
# paths at files you trust.
# Context managers guarantee the file handles are closed even if loading fails.

# Pan-mammal blood methylation data (one object per species)
with open('../data/pan_mammal_blood_with_site_details_filtered.pk', 'rb') as file:
    data_list = pickle.load(file)

# GenScot human data, used further down to replace the human entry of data_list
with open('../data/genscot_full_with_site_details.pk', 'rb') as file:
    humans = pickle.load(file)
# Fill in the fields on the human object that the later steps read.
# The "*_recalc" variance columns are plain copies of the existing columns.
for source_column, recalc_column in (
    ('variance_second_half', 'variance_second_half_recalc'),
    ('variance_first_half', 'variance_first_half_recalc'),
):
    humans.obs[recalc_column] = humans.obs[source_column]

# Species-level metadata (sex_maturity / lifespan presumably in years - TODO confirm)
for uns_key, uns_value in (
    ('common_name', 'Human'),
    ('organism', 'Homo sapiens'),
    ('sex_maturity', 13),
    ('lifespan', 122.5),
):
    humans.uns[uns_key] = uns_value
# Swap every 'Homo sapiens' entry of data_list for the GenScot object (in place)
for position, entry in enumerate(data_list):
    if entry.uns['organism'] == 'Homo sapiens':
        data_list[position] = humans
# Drop samples younger than the species' age of sexual maturity.
# Samples are the second axis, so the age mask is applied to the columns;
# the list is updated in place with the filtered objects.
data_list[:] = [
    species[:, species.var.age >= species.uns['sex_maturity']]
    for species in data_list
]
# Down-sample the human data: 250 samples via sample_to_uniform_age, then a
# random subset of at most 50000 sites
for j, data in enumerate(data_list):
    if data.uns['organism'] == 'Homo sapiens':
        # Sample from the list entry rather than the global `humans` object:
        # the entry has already had samples below sexual maturity removed,
        # whereas resampling from `humans` would silently undo that filter.
        human_subset = sample_to_uniform_age(data, 250)
        # Never request more sites than exist (np.random.choice with
        # replace=False raises in that case); with >= 50000 sites available
        # the call is unchanged.
        n_sites_to_keep = min(50000, len(human_subset.obs.index))
        kept_sites = np.random.choice(human_subset.obs.index, n_sites_to_keep, replace=False)
        data_list[j] = human_subset[kept_sites]
# Site-level filtering, applied to every species
for position, species in enumerate(data_list):
    # Change in variance between the two halves (second minus first)
    species.obs['var_change_recal'] = (
        species.obs['variance_second_half_recalc'] - species.obs['variance_first_half_recalc']
    )
    # Mean methylation of each site across the samples
    species.obs['mean_meth_recal'] = species.X.mean(axis=1)

    # Read .obs again only after the writes above: writing to .obs may replace
    # the underlying table when the object is a view
    site_table = species.obs
    keep_site = (
        (site_table.mean_meth_recal > 0.1)      # not almost unmethylated ...
        & (site_table.mean_meth_recal < 0.9)    # ... nor almost fully methylated
        & (site_table.n_peaks == 1)             # drop bimodal sites
        & (site_table.var_change_recal > 0)     # variance must be increasing
    )
    data_list[position] = species[keep_site]
# Shift ages within each species so that the youngest sample has age 0
for position, species in enumerate(data_list):
    youngest_age = species.var['age'].min()
    species.var['age'] = species.var['age'] - youngest_age
    data_list[position] = species
# Species to run in this script
names = ["Macaca mulatta", "Tursiops truncatus", "Ovis aries"]
# Keep a species only if it has more than min_sample_size samples
# AND it is one of the species listed above
data_list = [
    candidate
    for candidate in data_list
    if candidate.shape[1] > min_sample_size and candidate.uns['organism'] in names
]
# Fail fast on a mistyped selection method: the if/elif below has no fallback,
# so an unrecognised value would otherwise skip site selection silently and
# fit the model on every remaining site.
if selection_method not in ("spearman", "white"):
    raise ValueError(
        f"selection_method must be 'spearman' or 'white', got {selection_method!r}"
    )

# Fit one model per species left in data_list and save its trace
for animal_data in data_list:
    name = animal_data.uns['organism']
    print(name)

    # Bos taurus only: drop the problem samples whose IDs start with these prefixes
    if name == 'Bos taurus':
        problem_prefixes = ('204027420028', '204027420029')
        # One boolean mask instead of re-slicing the object once per sample
        keep_sample = np.array(
            [not sample_id.startswith(problem_prefixes) for sample_id in animal_data.var.index],
            dtype=bool,
        )
        animal_data = animal_data[:, keep_sample]

    ### Recalculate CpG attributes on the restricted range (the data has been trimmed) ###
    ### and keep the n_sites best-ranked sites                                         ###
    ages = animal_data.var.age.values

    if selection_method == "spearman":
        # One row per site, one column per sample
        meth_table = pd.DataFrame(
            animal_data.X, index=animal_data.obs.index, columns=animal_data.var.index
        )
        # Absolute Spearman correlation between each site's methylation and age
        abs_spearman = meth_table.apply(lambda row: spearmanr(row, ages)[0], axis=1).abs()
        animal_data.obs['spearman_r'] = abs_spearman
        # Rank ascending and keep the last n_sites (the strongest correlations).
        # Undefined correlations (NaN, e.g. a constant site) are ranked lowest so
        # they cannot displace real sites from the selection.
        ascending_order = np.argsort(abs_spearman.fillna(-np.inf).to_numpy())
        animal_data = animal_data[ascending_order[-n_sites:]]

    elif selection_method == "white":
        # The age split is the same for every site, so compute it once
        median_age = np.median(ages)
        in_first_half = ages < median_age
        in_second_half = ages >= median_age

        beta_pvals = []
        white_pvals = []
        var_changes = []
        # OLS of methylation on age plus White's heteroscedasticity test, per site
        for site in animal_data.obs.index:
            site_meth = animal_data[site].X.T[:, 0]
            site_frame = pd.DataFrame({'age': ages, 'meth': site_meth})
            model = smf.ols('meth ~ age', data=site_frame).fit()
            _, p_value_white, _, _ = sm.stats.diagnostic.het_white(model.resid, model.model.exog)
            beta_pvals.append(model.pvalues["age"])
            white_pvals.append(p_value_white)
            # Variance among samples at/above the median age minus variance below it
            var_changes.append(np.var(site_meth[in_second_half]) - np.var(site_meth[in_first_half]))

        animal_data.obs['beta_pval_recal'] = beta_pvals
        animal_data.obs['white_pval_recal'] = white_pvals
        animal_data.obs['var_change_recal'] = var_changes

        # Smallest White-test p-values first; keep the first n_sites
        top_sites = animal_data.obs.sort_values('white_pval_recal', ascending=True).index[:n_sites]
        animal_data = animal_data[top_sites]

    # Build and sample the model for this species
    sym_model = make_mcmc_order1(animal_data)
    with sym_model:
        # Long tuning phase (5000 steps) for difficult posteriors
        trace = pmjax.sample_numpyro_nuts(progressbar=True, target_accept=0.98, tune=5000, chain_method='parallel', random_seed=18, chains=2)
        # Posterior predictive sampling is currently disabled; to enable:
        # pm.sample_posterior_predictive(trace, model=sym_model, extend_inferencedata=True)

    # The selection method is part of the file name so "white" runs are not
    # saved under (and do not overwrite) the "spearman" label; the path for
    # selection_method == "spearman" is unchanged.
    output_path = (
        f'../exports/model_outputs/all_mammals/{selection_method}_filtered_separate_model'
        f'_min_sample_size_{min_sample_size}_nsites_{n_sites}_{name}.pk'
    )
    with open(output_path, 'wb') as f:
        pickle.dump(trace, f)