# SCARLET / notebooks / 1_model_runs / run_humans_fixed_n_s.py
###################
### Description ###
###################

# This script runs our conditional model (with N and s held fixed) on human methylation data, once for every (N, s) pair in a grid.
# Relevant figures: Fig. 2a, Fig. 3a

# Inputs: 
# - Anndata object of methylation beta values, with CpGs as obs and samples as vars

# Outputs:
# - Model outputs (traces, extended with log-likelihood and posterior predictive samples) pickled as .pk files, one per (N, s) combination

##############
### Author ###
##############
 
# Sam Crofts (sam.crofts@ed.ac.uk)

##################################
##################################

#Imports
import sys
from pathlib import Path
sys.path.append("..")   # fix to import modules from root
from src.general_imports import *
import pymc.sampling.jax as pmjax
import jax

# Set seed for reproducibility.
# NOTE(review): nothing in this section draws random numbers, and the global
# RNGs are re-seeded again right before the age subsampling, so this seed has
# no visible effect on the results.
np.random.seed(125)

# Set parameters
sample_size = 100  # target sample count passed to sample_to_uniform_age
n_sites = 100      # number of CpG sites kept (top by |spearman_stat|)

# Alternative linearly spaced grids (not used):
# N_grid = [250, 500, 750, 1000, 2500, 5000, 7500, 10000, 25000, 50000]
# s_grid  = [0.10, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]

# Log-spaced grids: 10 points for N over [100, 100000] and 10 points for s over
# [0.01, 10]. For N, dtype=int TRUNCATES rather than rounds (e.g. 4641.59 ->
# 4641). np.unique sorts the values and drops any duplicates the cast creates.
# .tolist() yields plain Python ints/floats, which are used in output filenames.
N_grid = np.unique(np.logspace(np.log10(100), np.log10(100000), num=10, dtype=int)).tolist()
s_grid = np.unique(np.logspace(np.log10(0.01), np.log10(10.0), num=10)).tolist()

# Cartesian product of the two grids: one (N, s) tuple per model run, i.e.
# len(N_grid) * len(s_grid) = 10 x 10 = 100 combinations (N varies slowest).
all_tuples = list(product(N_grid, s_grid))

# Load the AnnData object (CpGs as obs, samples as vars).
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only ever point this at trusted, locally produced data.
with open('../data/genscot_full_with_site_details.pk', 'rb') as file:
    adata = pickle.load(file)

# Keep only sites that have a single peak
adata = adata[(adata.obs["n_peaks"]==1)]

# Only keep sites increasing in variance
adata = adata[adata.obs["var_change"] > 0]

# Make sure mean methylation is between 0.1 and 0.9
adata = adata[(adata.obs.mean_meth > 0.1) & (adata.obs.mean_meth < 0.9)]

# Materialise the filtered view once. The in-place edits below (X, var, obs)
# then modify an independent object. Without this, AnnData would implicitly
# convert the view to an actual object on the first write, with an
# ImplicitModificationWarning.
adata = adata.copy()

# Replace values <= 0 with 0.0001 and values >= 1 with 0.9999, presumably
# because a beta likelihood is undefined at exactly 0 and 1.
# This runs after the filters above, which only read .obs and never .X, so it
# yields the same values for the retained sites while touching far less data
# than clipping the full matrix.
adata.X = np.where(adata.X <= 0, 0.0001, adata.X)
adata.X = np.where(adata.X >= 1, 0.9999, adata.X)

# Reset ages so that they start at 0
adata.var["age"] = adata.var["age"] - adata.var["age"].min()

# Take the top n_sites by absolute Spearman statistic.
# Note: the 'spearman_stat' column is overwritten with its absolute value.
adata.obs['spearman_stat'] = np.abs(adata.obs['spearman_stat'])
adata = adata[adata.obs.sort_values('spearman_stat', ascending=False).index]
adata = adata[:n_sites]

# Re-seed both global RNGs (Python's `random` module and NumPy's legacy global
# state) immediately before subsampling. The selected samples then do not
# depend on anything that ran earlier in the script.
# NOTE(review): presumably sample_to_uniform_age draws from one or both of
# these RNGs — confirm against its implementation.
random.seed(10)
np.random.seed(10)

# Subsample to `sample_size` samples, intended to be uniform with respect to age.
adata = sample_to_uniform_age(adata, sample_size)

# Fit the conditional model once per (N, s) pair and pickle each trace.

# Seed for the NUTS sampler and for the posterior-predictive draws. PyMC does
# not read NumPy's legacy global RNG (np.random.seed). Without an explicit
# seed, every run would draw fresh OS entropy and the traces would not be
# reproducible.
SAMPLER_SEED = 10

# Create the output directory before any sampling. Otherwise a missing folder
# only surfaces at pickle time, after the first (long) run has finished.
out_dir = Path('../exports/model_outputs/humans/fixed_n_s/log_scale_final')
out_dir.mkdir(parents=True, exist_ok=True)

for N_value, s_value in tqdm(all_tuples):

    sym_model = make_cond_beta_model(adata, N_value, s_value)

    with sym_model:
        trace = pmjax.sample_numpyro_nuts(
            draws=1000,
            tune=5000,
            chains=2,
            target_accept=0.95,
            chain_method='vectorized',
            progressbar=True,
            random_seed=SAMPLER_SEED,
        )

        # Calculate the pointwise log-likelihood. With PyMC's default
        # extend_inferencedata=True, it is added to `trace` itself.
        pm.compute_log_likelihood(trace)

        # Also draw posterior predictive samples and add them to `trace`.
        pm.sample_posterior_predictive(
            trace,
            model=sym_model,
            extend_inferencedata=True,
            random_seed=SAMPLER_SEED,
        )

    # Filename format is unchanged (the f-string matches the original str()
    # concatenation), so downstream readers keep working.
    filename = f"nsites{n_sites}_sample_size_{sample_size}_N{N_value}_s{s_value}"

    with open(out_dir / (filename + '.pk'), 'wb') as f:
        pickle.dump(trace, f)