# SCARLET / notebooks / 1_model_runs / run_humans_sensitivity_timespans.py
###################
### Description ###
###################

# This script checks how sensitive our N/s estimates are to the timespan of the data.

# Relevant figures: Supp. Figs 3a-b

# Inputs:
# - Anndata object of methylation beta values, with CpGs as obs and samples as vars

# Outputs:
# - Model outputs (traces) saved as pk files for each timespan analysed

##############
### Author ###
##############
 
# Sam Crofts (sam.crofts@ed.ac.uk)

##################################
##################################

#Imports
import sys
sys.path.append("..")   # fix to import modules from root
from src.general_imports import *
import pymc.sampling.jax as pmjax
import jax
import statsmodels.api as sm
import statsmodels.formula.api as smf

# Report the active JAX backend and the visible devices, so the run log
# shows whether the sampler will execute on a GPU.
for _report in (jax.default_backend, jax.devices):
    print(_report())

# Run parameters
n_sites = 100       # number of top-ranked sites kept for each model fit
sample_size = 150   # passed to sample_to_uniform_age for every timespan

# Stem shared by all output filenames of this run.
# NOTE(review): selection_method is not defined anywhere in this script;
# presumably it is provided by `from src.general_imports import *` — confirm.
filename = "".join(
    ("nsites", str(n_sites), "_ss", str(sample_size), "_", selection_method)
)

# Seed both the NumPy and the stdlib random generators
np.random.seed(14)
random.seed(14)

# Fit one model per timespan: each pass keeps only samples whose (zero-based)
# age is at most max_age, re-ranks the sites on that restricted range, fits the
# model to the top n_sites sites and pickles the resulting trace.
for max_age in [5, 10, 15, 20, 30, 40, 50, 60, 70]:

    # Load adata afresh on every pass: the steps below overwrite the ages and
    # subset the object, so each timespan starts again from the full dataset.
    # The context manager closes the handle even if unpickling raises.
    with open('../data/genscot_full_with_site_details.pk', 'rb') as file:
        adata = pickle.load(file)

    #reset ages to start at 0
    adata.var.age = adata.var.age - adata.var.age.min()

    #only take up to max_age
    adata = adata[:, adata.var.age <= max_age]

    #uniform age sampling (project helper; presumably subsamples to
    #sample_size samples spread evenly over age -- confirm in src)
    adata = sample_to_uniform_age(adata, sample_size)

    # Replace 0 and 1 (values at or beyond the [0, 1] bounds are nudged
    # just inside the open interval)
    adata.X = np.where(adata.X <= 0, 0.0001, adata.X)
    adata.X = np.where(adata.X >= 1, 0.9999, adata.X)

    #keep only sites that have a single peak
    adata = adata[(adata.obs["n_peaks"]==1)]

    #set random seed
    # NOTE(review): re-seeding here (after the age subsampling) presumably pins
    # the state seen by any later random step -- confirm this is intended.
    np.random.seed(11)
    random.seed(11)

    #recalculate spearman R based on the restricted range (now that we've trimmed the data)
    df = pd.DataFrame(adata.X, columns=adata.var.index)
    
    #label the rows with the site identifiers
    df.index = adata.obs.index

    #Assign: one correlation per row (site) against the sample ages
    adata.obs['spearman_r_recal'] = df.apply(lambda x: spearmanr(x, adata.var.age)[0], axis=1)
    
    #get absolute value (rank by strength of association, ignoring direction)
    adata.obs['spearman_r_recal'] = np.abs(adata.obs['spearman_r_recal'])

    #only keep sites increasing in variance
    adata = adata[adata.obs["var_change"] > 0]

    #make sure mean methylation is between 0.1 and 0.9
    adata = adata[(adata.obs.mean_meth > 0.1) & (adata.obs.mean_meth < 0.9)]
    
    #get the top n_sites sites (argsort is ascending, so take the last n_sites)
    adata = adata[np.argsort(adata.obs['spearman_r_recal'])[-n_sites:]]

    #make model
    sym_model = make_mcmc_order1(adata)

    with sym_model:
        trace = pmjax.sample_numpyro_nuts(chains=2, progressbar=True, target_accept=0.98, tune=5000, chain_method='vectorized', random_seed=16)

    # NOTE(review): the filename tags the value as "_minage_" although it is
    # max_age. The path is left unchanged here because other scripts may read
    # these files by name -- rename in both places if this is corrected.
    with open('../exports/model_outputs/humans/sensitivity_analyses/timespans/looseretanew_'+filename+'_minage_'+str(max_age)+'.pk', 'wb') as f:
        pickle.dump(trace, f)