###################
### Description ###
###################

# This script is for running SCARLET on data from different mammals, with all
# mammals included in one single model.

# Relevant figures: Fig. 3d, Supp. Figs. 3d-i

# Inputs:
# - List of Anndata objects of methylation beta values for different mammals,
#   with CpGs as obs and samples as vars
# - Each Anndata object also contains relevant metadata in .uns and .obs
#   (e.g. lifespan)
# - An Anndata object of human methylation data (replacing the human data in
#   the mammal list to increase sample size), with CpGs as obs and samples as
#   vars

# Outputs:
# - Model outputs (traces) saved as pk files for each type of model run

##############
### Author ###
##############

# Sam Crofts (sam.crofts@ed.ac.uk)

##################################
##################################

#Imports
import sys
sys.path.append("..")  # fix to import modules from root
from src.general_imports import *
import pymc.sampling.jax as pmjax
import jax

#Set seed for reproducibility
np.random.seed(12)
random.seed(12)

#set parameters
n_sites = 200            # number of top age-correlated CpG sites kept per species
min_sample_size = 100    # species must have MORE than this many samples
n_human_samples = 1000   # size of the random human subsample (eases memory)

for model_type in ["free_ns"]:

    # bring in mammal data
    # NOTE(review): pickle executes arbitrary code on load - only use trusted files
    with open('../data/pan_mammal_blood_with_site_details_filtered.pk', 'rb') as file:
        data_list = pickle.load(file)

    #bring in GenScot to replace the human data
    with open('../data/genscot_full_with_site_details.pk', 'rb') as file:
        humans = pickle.load(file)

    #Repopulate a few variables we need
    humans.obs['variance_second_half_recalc'] = humans.obs['variance_second_half']
    humans.obs['variance_first_half_recalc'] = humans.obs['variance_first_half']
    humans.uns['common_name'] = 'Human'
    humans.uns['organism'] = 'Homo sapiens'
    humans.uns['sex_maturity'] = 13
    humans.uns['lifespan'] = 122.5

    #randomly sample humans (without replacement) to ease memory requirements
    human_indices = np.random.choice(humans.shape[1], n_human_samples, replace=False)
    humans = humans[:, human_indices]

    #replace organism='Homo sapiens' in data_list with humans
    data_list = [
        humans if data.uns['organism'] == 'Homo sapiens' else data
        for data in data_list
    ]

    #let's remove any samples under the age of sexual maturity
    data_list = [
        data[:, data.var.age >= data.uns['sex_maturity']]
        for data in data_list
    ]

    #just keep animals with more than min_sample_size samples
    data_list = [x for x in data_list if x.shape[1] > min_sample_size]

    #Exclude rhesus macaque because it's not converging at this lower sample size
    data_list = [x for x in data_list if x.uns['common_name'] != 'Rhesus macaque']

    #only keep sites with 1 peak, between 0.1 and 0.9 mean methylation,
    #and with increasing variance
    for idx, data in enumerate(data_list):
        #make a var_change_recal column (second-half minus first-half variance)
        data.obs['var_change_recal'] = (
            data.obs['variance_second_half_recalc']
            - data.obs['variance_first_half_recalc']
        )
        #make a mean meth recal column (per-site mean across samples)
        data.obs['mean_meth_recal'] = data.X.mean(axis=1)
        #only keep if mean methylation is between 0.1 and 0.9
        data = data[(data.obs.mean_meth_recal > 0.1) & (data.obs.mean_meth_recal < 0.9)]
        #maybe change this again to recal if needed - but perhaps too stringent
        #for low sample sizes
        data = data[data.obs.n_peaks == 1]
        #make sure site is increasing in variance
        data = data[data.obs.var_change_recal > 0]
        #replace the data in the list with the new data
        data_list[idx] = data

    #let's reset all ages to be relative to the minimum age
    for idx, data in enumerate(data_list):
        data.var.age = data.var.age - data.var.age.min()
        data_list[idx] = data

    #don't include 'Sheep' - convergence issues due to small lifespan range
    data_list = [x for x in data_list if x.uns['common_name'] != 'Sheep']

    #get minimum number of samples
    min_samples = min(data.shape[1] for data in data_list)

    #loop through and uniformly sample min_samples from each animal
    data_list = [sample_to_uniform_age(data, min_samples) for data in data_list]

    #loop through and calculate spearman_r for each animal, keeping only the
    #top n_sites
    for idx, data in enumerate(data_list):
        #recalculate spearman R based on the restricted range (now that we've
        #trimmed the data)
        df = pd.DataFrame(data.X, columns=data.var.index)
        #use the site names as the row index
        df.index = data.obs.index
        #ages do not change per site, so look them up once
        ages = data.var.age
        #Assign
        data.obs['spearman_r'] = df.apply(lambda x: spearmanr(x, ages)[0], axis=1)
        #get absolute value
        data.obs['spearman_r'] = np.abs(data.obs['spearman_r'])
        #get the top n_sites sites (argsort is ascending, so take the tail)
        # NOTE(review): NaN correlations sort last and would be picked as "top"
        # sites - confirm none occur after filtering
        data_list[idx] = data[np.argsort(data.obs['spearman_r'])[-n_sites:]]

    #every species must contribute exactly n_sites sites, otherwise the
    #stacked arrays below cannot line up
    for data in data_list:
        if data.shape[0] != n_sites:
            raise ValueError(
                f"{data.uns['common_name']} has {data.shape[0]} sites after "
                f"filtering, expected {n_sites}"
            )

    # Combine the data for all species into a single structure:
    # (samples, sites, species)
    data_matrix = np.stack([data.X.T for data in data_list], axis=-1)

    # Combine ages for all species, repeating each sample's age across sites
    t_matrix = np.stack(
        [np.tile(np.array(data.var.age)[:, None], n_sites) for data in data_list],
        axis=-1,
    )

    #make model
    sym_model = make_mcmc_beta_model_stacked(
        data_matrix, t_matrix, data_list, min_samples, n_sites, model_type
    )

    with sym_model:
        trace = pmjax.sample_numpyro_nuts(
            chains=2,
            progressbar=True,
            target_accept=0.95,
            tune=5000,
            chain_method='parallel',
            random_seed=13,
        )

    with sym_model:
        #calculate log likelihood
        pm.compute_log_likelihood(trace)
        #also get posterior predictive
        pm.sample_posterior_predictive(
            trace,
            model=sym_model,
            extend_inferencedata=True,
        )

    #pickle
    output_path = (
        '../exports/model_outputs/all_mammals/joint_model_filtered_'
        + model_type
        + '.pk'
    )
    with open(output_path, 'wb') as f:
        pickle.dump(trace, f)