###################
### Description ###
###################

# Runs the conditional model (N and s held at fixed values) on human data.
# Relevant figures: Fig. 2a, Fig. 3a
#
# Inputs:
#   - Anndata object of methylation beta values, with CpGs as obs and
#     samples as vars
# Outputs:
#   - Model outputs (traces) saved as pk files, one per combination of N and s

##############
### Author ###
##############

# Sam Crofts (sam.crofts@ed.ac.uk)

##################################
##################################

# Imports
import sys

sys.path.append("..")  # fix to import modules from root
from src.general_imports import *
import pymc.sampling.jax as pmjax
import jax

# Seed NumPy's global RNG for reproducibility.
np.random.seed(125)

# Number of samples drawn after age-uniform resampling, and number of CpG
# sites kept for the model.
sample_size = 100
n_sites = 100

# Grids of fixed N and s values.
# An earlier hand-picked grid (N: 250 ... 50000, s: 0.10 ... 10.0) was replaced
# by the evenly log-spaced grids below.
# N: 10 log-spaced points from 100 to 100000, generated with an integer dtype.
_n_log_points = np.logspace(np.log10(100), np.log10(100000), num=10, dtype=int)
# s: 10 log-spaced points from 0.01 to 10.0.
_s_log_points = np.logspace(np.log10(0.01), np.log10(10.0), num=10)

# np.unique sorts and drops repeated values, so each grid holds at most 10
# entries and the full N x s product holds at most 10 x 10 combinations.
N_grid = np.unique(_n_log_points).tolist()
s_grid = np.unique(_s_log_points).tolist()
# Local stdlib import: only needed to create the export directory below.
import os

# Every (N, s) pairing of the two grids; one model is fitted per pairing.
all_tuples = list(product(N_grid, s_grid))

# Load adata.
# NOTE: pickle can execute arbitrary code on load; only use trusted files.
with open('../data/genscot_full_with_site_details.pk', 'rb') as data_file:
    adata = pickle.load(data_file)

# Replace exact 0 and 1 (and anything outside [0, 1]) with values just inside
# the open interval.
adata.X = np.where(adata.X <= 0, 0.0001, adata.X)
adata.X = np.where(adata.X >= 1, 0.9999, adata.X)

# Keep only sites that have a single peak.
adata = adata[(adata.obs["n_peaks"] == 1)]

# Only keep sites increasing in variance.
adata = adata[adata.obs["var_change"] > 0]

# Make sure mean methylation is between 0.1 and 0.9.
adata = adata[(adata.obs.mean_meth > 0.1) & (adata.obs.mean_meth < 0.9)]

# Reset ages so that they start at 0.
adata.var.age = adata.var.age - adata.var.age.min()

# Take the top n_sites by absolute Spearman value.
adata.obs['spearman_stat'] = np.abs(adata.obs['spearman_stat'])
adata = adata[adata.obs.sort_values('spearman_stat', ascending=False).index]
adata = adata[:n_sites]

# Re-seed both RNGs immediately before the resampling step.
np.random.seed(10)
random.seed(10)

# Sample to uniform with respect to age.
adata = sample_to_uniform_age(adata, sample_size)

# Create the export directory up front so a missing folder cannot discard a
# finished sampling run when the trace is written.
output_dir = '../exports/model_outputs/humans/fixed_n_s/log_scale_final/'
os.makedirs(output_dir, exist_ok=True)

# Run: fit the conditional model once per fixed (N, s) pair and pickle the
# resulting trace.
for n_value, s_value in tqdm(all_tuples):
    sym_model = make_cond_beta_model(adata, n_value, s_value)

    with sym_model:
        trace = pmjax.sample_numpyro_nuts(
            chains=2,
            progressbar=True,
            target_accept=0.95,
            tune=5000,
            chain_method='vectorized',
            draws=1000,
        )

    with sym_model:
        # Calculate log likelihood (stored on the trace).
        pm.compute_log_likelihood(trace)
        # Also get posterior predictive, appended to the same trace object.
        pm.sample_posterior_predictive(
            trace,
            model=sym_model,
            extend_inferencedata=True,
        )

    # File name encodes the run configuration and the fixed N and s values.
    filename = f"nsites{n_sites}_sample_size_{sample_size}_N{n_value}_s{s_value}"
    with open(output_dir + filename + '.pk', 'wb') as f:
        pickle.dump(trace, f)