###################
### Description ###
###################
# This script runs our conditional model (with fixed values of N and s) on human data.
# Relevant figures: Fig. 2a, Fig. 3a
# Inputs:
# - AnnData object of methylation beta values, with CpGs as obs and samples as var
# Outputs:
# - Model outputs (traces), saved as pickle (.pk) files, one per combination of N and s
##############
### Author ###
##############
# Sam Crofts (sam.crofts@ed.ac.uk)
##################################
##################################
#Imports
import sys
from pathlib import Path
sys.path.append("..") # fix to import modules from root
from src.general_imports import *
import pymc.sampling.jax as pmjax
import jax
# Set seed for reproducibility
np.random.seed(125)

# Set parameters
sample_size = 100  # number of samples passed to sample_to_uniform_age
n_sites = 100  # number of CpG sites kept (largest |Spearman statistic|)

# Grids of fixed N and s values, evenly spaced on a log scale
# (N from 100 to 100,000; s from 0.01 to 10).
n_grid_points = 10
# NOTE: dtype=int truncates rather than rounds (e.g. 4641.59 -> 4641). It is
# left unchanged so the N values, and the output filenames built from them,
# stay the same as before. np.unique sorts each grid and drops any duplicates.
N_grid = np.unique(
    np.logspace(np.log10(100), np.log10(100000), num=n_grid_points, dtype=int)
).tolist()
s_grid = np.unique(
    np.logspace(np.log10(0.01), np.log10(10.0), num=n_grid_points)
).tolist()

# Every (N, s) combination: up to 10 x 10 = 100 model runs.
all_tuples = list(product(N_grid, s_grid))
# Load the AnnData object of methylation beta values (CpG sites as obs,
# samples as var). The context manager closes the file even if loading fails.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('../data/genscot_full_with_site_details.pk', 'rb') as file:
    adata = pickle.load(file)

# Site-level filters, all on precomputed obs columns. A single combined mask
# keeps exactly the rows the sequential filters would:
#   - sites that have a single peak,
#   - sites increasing in variance,
#   - sites whose mean methylation is strictly between 0.1 and 0.9.
site_obs = adata.obs
keep_sites = (
    (site_obs["n_peaks"] == 1)
    & (site_obs["var_change"] > 0)
    & (site_obs.mean_meth > 0.1)
    & (site_obs.mean_meth < 0.9)
)
# .copy() materialises the subset so the in-place edits below act on a real
# AnnData instead of implicitly converting a view.
adata = adata[keep_sites].copy()

# Replace beta values <= 0 with 0.0001 and >= 1 with 0.9999, so all values lie
# strictly inside (0, 1). This presumably avoids exact 0/1 in the beta model;
# confirm against make_cond_beta_model. The filters above do not read X, so
# clipping only the retained sites gives the same values as clipping the full
# matrix, at a fraction of the memory and time.
adata.X = np.where(adata.X <= 0, 0.0001, adata.X)
adata.X = np.where(adata.X >= 1, 0.9999, adata.X)

# Shift ages so that the youngest sample has age 0.
adata.var.age = adata.var.age - adata.var.age.min()

# Keep the n_sites sites with the largest absolute Spearman statistic.
# spearman_stat is overwritten with its absolute value.
adata.obs['spearman_stat'] = np.abs(adata.obs['spearman_stat'])
adata = adata[adata.obs.sort_values('spearman_stat', ascending=False).index]
adata = adata[:n_sites]

# Re-seed both RNGs immediately before the random subsampling step.
np.random.seed(10)
random.seed(10)
# Subsample sample_size samples, uniform with respect to age
# (see sample_to_uniform_age).
adata = sample_to_uniform_age(adata, sample_size)
# Run the conditional model once per fixed (N, s) pair and pickle each trace.
# The output directory is created up front. Otherwise a missing directory
# would only surface when saving, after the first long sampling run, and
# that trace would be lost.
out_dir = Path('../exports/model_outputs/humans/fixed_n_s/log_scale_final')
out_dir.mkdir(parents=True, exist_ok=True)

for N_val, s_val in tqdm(all_tuples):
    sym_model = make_cond_beta_model(adata, N_val, s_val)
    with sym_model:
        # NUTS via numpyro/JAX: 2 chains, 5000 tuning + 1000 kept draws each.
        trace = pmjax.sample_numpyro_nuts(
            chains=2,
            progressbar=True,
            target_accept=0.95,
            tune=5000,
            chain_method='vectorized',
            draws=1000,
        )
        # Attach pointwise log-likelihood to the trace.
        pm.compute_log_likelihood(trace)
        # Also attach posterior predictive samples to the same trace object.
        pm.sample_posterior_predictive(
            trace, model=sym_model, extend_inferencedata=True
        )
    # Filename encodes the run settings, e.g. "nsites100_sample_size_100_N100_s0.01".
    filename = f"nsites{n_sites}_sample_size_{sample_size}_N{N_val}_s{s_val}"
    with open(out_dir / (filename + '.pk'), 'wb') as f:
        pickle.dump(trace, f)