# SCARLET/notebooks/1_model_runs/run_humans_cohorts_unconditional.py
###################
### Description ###
###################

# This script is for running SCARLET (unconditional, i.e. without fixed N and s) on human data across various cohorts
# Relevant figures: Fig. 2c, Supp. Fig. 2a

# Inputs:
# - Anndata object of methylation beta values, with CpGs as obs and samples as vars
# - List of non-cell-composition-related CpGs (csv, taken from https://doi.org/10.1038/s42003-024-06609-4)
# - For cohort definitions: smoking status and sex information (in adata.var)
# - For accelerated cohorts: acceleration values (csv, generated according to Dabrowski et al. 2024)

# Outputs:
# - Model outputs (traces) saved as pk files for each cohort

##############
### Author ###
##############
 
# Sam Crofts (sam.crofts@ed.ac.uk)

##################################
##################################

# Imports
import sys
sys.path.append("..")   # fix to import modules from root
from src.general_imports import *
import pymc.sampling.jax as pmjax
import jax

# Print the JAX backend and the visible devices first (backend, then devices),
# so the run log records what hardware the NumPyro sampler will use.
for _probe in (jax.default_backend, jax.devices):
    print(_probe())

# Site-selection strategy applied after each cohort is built:
#   "spearman" -> keep sites with the largest |Spearman statistic| (mean change)
#   "white"    -> keep sites with the smallest White-test p-value (variance change)
selection_method = "white"

# Number of CpG sites passed to the model, and number of samples per cohort.
n_sites, sample_size = 500, 500

# Seed NumPy's global RNG, then Python's `random` module, both with 10.
for _seed in (np.random.seed, random.seed):
    _seed(10)

#cohorts
# ---------------------------------------------------------------------------
# Shared preprocessing. None of these steps depend on `group` or draw random
# numbers, so they run once here instead of re-loading and re-filtering the
# full pickle on every loop iteration. Each cohort below starts from its own
# `.copy()`, so per-cohort mutations (new `acc`/`bias` columns, overwritten
# `spearman_stat`) cannot leak into the next cohort.
# ---------------------------------------------------------------------------

# Load adata
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../data/genscot_full_with_site_details.pk', 'rb') as file:
    base_adata = pickle.load(file)

#import list of sites to keep (sites unrelated to cell composition, taken from https://doi.org/10.1038/s42003-024-06609-4)
keep_sites = pd.read_csv('../data/non_composition_cpgs.csv', index_col=0)
base_adata = base_adata[base_adata.obs.index.isin(keep_sites.index)]

# Replace exact 0 and 1 beta values with 0.0001 / 0.9999
base_adata.X = np.where(base_adata.X <= 0, 0.0001, base_adata.X)
base_adata.X = np.where(base_adata.X >= 1, 0.9999, base_adata.X)

#keep only sites that have a single peak
base_adata = base_adata[(base_adata.obs["n_peaks"] == 1)]

#only keep sites increasing in variance
base_adata = base_adata[base_adata.obs["var_change"] > 0]

#make sure mean methylation is between 0.1 and 0.9
base_adata = base_adata[(base_adata.obs.mean_meth > 0.1) & (base_adata.obs.mean_meth < 0.9)]

#reset ages so that they start at 0
base_adata.var.age = base_adata.var.age - base_adata.var.age.min()

# Materialise any remaining view so the unfiltered dataset it references can be freed.
if base_adata.is_view:
    base_adata = base_adata.copy()

for group in ["smokers", "non_smokers", "male", "female", "high_acc", "low_acc"]:

    # NOTE: there is no underscore between the sample size and "selection_method";
    # kept as-is so existing output paths (and anything reading them) stay unchanged.
    filename = f"nsites_{n_sites}_ss_{sample_size}selection_method_{selection_method}_{group}"

    # Independent copy of the preprocessed data for this cohort.
    adata = base_adata.copy()

    #set random seed (reset per cohort so each cohort is reproducible on its own)
    np.random.seed(10)
    random.seed(10)

    # Smokers are always built, because other cohorts are age-matched to them.
    # sample_to_uniform_age presumably draws from the seeded RNG, so this call
    # stays unconditional to keep the RNG state seen by later draws unchanged.
    smokers = adata[:, adata.var.weighted_smoke > 0.25].copy()

    #uniformly sample sample_size
    smokers = sample_to_uniform_age(smokers, sample_size)

    if group in ("male", "female"):

        # Non-smoking males, age-matched to the smoker cohort
        males = adata[:, adata.var.sex == 'M'].copy()
        males = males[:, males.var.weighted_smoke < 0.25]
        males = match_ages(males, smokers)

        # Non-smoking females, age-matched to the male cohort. Always built
        # (even for "male") so the RNG sequence is identical for both groups.
        females = adata[:, adata.var.sex == 'F'].copy()
        females = females[:, females.var.weighted_smoke < 0.25]
        females = match_ages(females, males)

        # Previously compared against "males", which never matched, so the
        # "male" cohort silently used the female samples.
        adata = males if group == "male" else females

    elif group == "smokers":

        adata = smokers

    elif group == "non_smokers":

        # Non-smokers, age-matched to the smoker cohort
        adata = adata[:, adata.var.weighted_smoke < 0.25]
        adata = match_ages(adata, smokers)

    elif group in ("high_acc", "low_acc"):

        #bring in acceleration values (csv) (generated according to Dabrowski et al. 2024)
        acc_df = pd.read_csv('../data/genscot_acc_and_bias.csv')

        # Look up each sample (adata.var index) by Basename in one indexed pass.
        # Keeping the first row per Basename mirrors the previous `.values[0]`
        # lookup; a sample missing from the CSV raises KeyError.
        acc_lookup = acc_df.drop_duplicates('Basename').set_index('Basename')
        matched = acc_lookup.loc[adata.var.index]
        adata.var['acc'] = matched['acc_wave3'].to_numpy()
        adata.var['bias'] = matched['bias_wave3'].to_numpy()

        # Take the `sample_size` people with the highest acceleration
        adata_high_acc = adata[:, adata.var.sort_values('acc', ascending=False).index]
        adata_high_acc = adata_high_acc[:, :sample_size]

        # Take those with a low acceleration (practically defined as < -0.2,
        # so there are still enough participants to individually match)
        adata_low_acc = adata[:, adata.var.acc < -0.2]

        # Age-match to high_acc. Always run so the RNG sequence does not
        # depend on which of the two groups is being built.
        adata_low_acc = match_ages(adata_low_acc, adata_high_acc)

        adata = adata_high_acc if group == "high_acc" else adata_low_acc

    else:
        raise ValueError(f"Unknown cohort group: {group!r}")

    if selection_method == "spearman":
        # Rank sites by absolute Spearman statistic, largest first
        adata.obs['spearman_stat'] = np.abs(adata.obs['spearman_stat'])
        adata = adata[adata.obs.sort_values('spearman_stat', ascending=False).index]
        adata = adata[:n_sites]

    elif selection_method == "white":
        # Take the sites with the lowest White-test p-values
        adata = adata[adata.obs.sort_values('white_pval', ascending=True).index]
        adata = adata[:n_sites]

    else:
        raise ValueError(f"Unknown selection_method: {selection_method!r}")

    #make model
    sym_model = make_mcmc_order1(adata)

    with sym_model:
        trace = pmjax.sample_numpyro_nuts(chains=2, progressbar=True, target_accept=0.95, tune=5000, chain_method='vectorized')

    #also get posterior predictive (added to `trace` in place)
    pm.sample_posterior_predictive(trace, model=sym_model, extend_inferencedata=True)

    with open('../exports/model_outputs/humans/cohort_analyses/' + filename + '.pk', 'wb') as f:
        pickle.dump(trace, f)