# SCARLET / notebooks / 2_post_run_analyses / analysis_humans.py
###################
### Description ###
###################

# This script is for the various analyses of human (GenScot) data.

# Includes: 
# - Plotting heatmap of log likelihoods across N and s (Fig. 2a)
# - Plotting N/s by group (Fig. 2c, Supp. Fig. 2a)
# - Plotting site fits across different categories of CpGs using conditional model (Fig. 2b, Supp. Fig. 1a)
# - Looking at how different categories of sites vary in parameters (Supp. Fig. 1b-c)
# - Outputting a table of quantiles of White p values and |Spearman rho| values for Supp. Table 1

##############
### Author ###
##############
 
# Sam Crofts (sam.crofts@ed.ac.uk)

########################################################################################################################

#Imports
import os
import sys
sys.path.append("../..")   # fix to import modules from root
from src.general_imports import *
from statsmodels.nonparametric.smoothers_lowess import lowess
from scipy.stats import beta

################################
###Plotting heatmap (Fig. 2a)###
################################

#go through each .pk trace in exports/model_outputs/humans/fixed_n_s/log_scale/, compute its
#LOO ELPD, and record it together with the N and s values parsed from the file name
fixed_n_s_dir = '../../exports/model_outputs/humans/fixed_n_s/log_scale/'
heatmap_csv = '../../exports/model_outputs/humans/fixed_n_s/heatmap_likelihoods_log_scale.csv'
files = os.listdir(fixed_n_s_dir)

#collect one dict per trace and build the DataFrame once at the end
#(DataFrame._append is a private pandas API and copies the whole frame on every call)
rows = []
for f in files:
    #only the pickled traces (.pk files) are of interest
    if not f.endswith('.pk'):
        continue

    #the with-block closes the file itself, so no explicit close() is needed
    with open(os.path.join(fixed_n_s_dir, f), 'rb') as file:
        trace = pickle.load(file)

    #file names are '_'-separated: field 4 is a one-letter prefix + integer N,
    #field 5 is a one-letter prefix + float s + the '.pk' extension
    #(assumes every .pk file in this folder follows that naming - TODO confirm)
    fields = f.split('_')
    N = int(fields[4][1:])
    s = float(fields[5][1:-3])

    #PSIS-LOO estimate of the expected log pointwise predictive density
    elpd_loo = az.loo(trace).elpd_loo

    rows.append({'N': N, 's': s, 'elpd_loo': elpd_loo})

df = pd.DataFrame(rows, columns=['N', 's', 'elpd_loo'])

#save df
df.to_csv(heatmap_csv, index=False)

#re-import (the plot below can be regenerated from the CSV without recomputing LOO)
df = pd.read_csv(heatmap_csv)

#unfiltered copy (not used further in this script)
df_backup = df.copy()

#keep only runs with N <= 50000 for the heatmap
#NOTE(review): an older comment here said "N>=50000" - confirm which direction is intended
df = df[df['N'] <= 50000]

plt.figure(figsize=(8.5, 6))
#rows = s, columns = N, colour = ELPD (LOO)
#fmt takes a format spec ('.1g'), not printf syntax; it only takes effect if annot=True
heatmap = sns.heatmap(df.pivot(values='elpd_loo', columns='N', index='s'), cmap='viridis', annot=False, fmt='.1g')
plt.tight_layout()
#save at 600 dpi
plt.savefig('../../exports/figures/heatmap_elpd_log_scale.png', dpi=600)

#####################################
###Plotting N/s by group (Fig. 2c)###
#####################################

#groups whose N/s posteriors are compared
#(use ['female', 'male', 'smokers'] for the sex/smoking comparison)
groups = ['high_acc', 'low_acc']

selection_method = "spearman" #or "white"

#one summary row per group: posterior mean of N/s and its 95% HDI bounds
summary_rows = []
for group in groups:
    file_path = f'../../exports/model_outputs/humans/cohort_analyses/nsites_500_ss_500selection_method_{selection_method}_{group}.pk'
    #load into an ordinary variable instead of creating globals through exec/eval
    with open(file_path, 'rb') as file:
        group_trace = pickle.load(file)

    #plot N/s
    az.plot_posterior(group_trace, var_names='N/s')

    summary = pm.summary(group_trace, hdi_prob=0.95, var_names=['N/s'])
    summary_rows.append({
        'type': group,
        'mean': summary['mean']['N/s'],
        'hdi_2.5%': summary['hdi_2.5%']['N/s'],
        'hdi_97.5%': summary['hdi_97.5%']['N/s'],
    })

df = pd.DataFrame(summary_rows, columns=['type', 'mean', 'hdi_2.5%', 'hdi_97.5%'])

#Plot: posterior mean with asymmetric error bars spanning the 95% HDI
#(a narrow figure keeps the categories close together)
plt.figure(figsize=(3, 4.5))
plt.errorbar(df['type'], df['mean'],
             yerr=[df['mean']-df['hdi_2.5%'], df['hdi_97.5%']-df['mean']],
             fmt='o', color='royalblue', ecolor='lightblue',
             elinewidth=4, capsize=0)

plt.title('N/s with 95% HPDI')
plt.ylabel('N/s')

#half a category of padding on either side, for any number of groups
ax = plt.gca()
ax.set_xlim(-0.5, len(groups) - 0.5)

plt.tight_layout()
#file name reflects the groups and selection method actually plotted, so switching
#either no longer overwrites a differently-labelled figure
group_tag = '_'.join(groups)
plt.savefig(f'../../exports/figures/group_comparison_{group_tag}_{selection_method}.png', dpi=600)

########################################################
###Plotting sites (using conditional model) (Fig. 2b)###
########################################################

#site categories
types = ["vmp", "dmp", "vmp_dmp", "non_linear_saturating"]

#site and sample size combinations
sites_and_sample_size_combo = ["nsites_200_ss_200", "nsites_200_ss_500",
                            "nsites_500_ss_200"]

traj_dir = '../../exports/model_outputs/humans/trajectory_categories/'

for site_and_sample_size in sites_and_sample_size_combo:

    for site_type in types:

        #model label -> pickled trace; the labels become the row names of the comparison table
        model_files = {
            'trace_1000': f'{traj_dir}full_model_{site_and_sample_size}_{site_type}N_1000s_1.pk',
            'trace_10000': f'{traj_dir}full_model_{site_and_sample_size}_{site_type}N_10000s_1.pk',
            'trace_100000': f'{traj_dir}full_model_{site_and_sample_size}_{site_type}N_100000s_1.pk',
            'trace_null': f'{traj_dir}null_model_{site_and_sample_size}_{site_type}.pk',
            'trace_simple_linear': f'{traj_dir}simple_linear_model_{site_and_sample_size}_{site_type}.pk',
        }

        #bring in traces; each with-block closes its file handle
        traces = {}
        for label, path in model_files.items():
            with open(path, 'rb') as handle:
                traces[label] = pickle.load(handle)

        #compare using LOO (explicit, so an rcParams change cannot silently switch to WAIC)
        df_comp_loo = az.compare(traces, ic='loo')

        #save comparison df
        df_comp_loo.to_csv(f'{traj_dir}loo_comparison_{site_type}_{site_and_sample_size}.csv')

#models in the order they appear along the x-axis (left to right)
desired_order = ['trace_1000', 'trace_10000', 'trace_100000', 'trace_simple_linear', 'trace_null']
traj_dir = '../../exports/model_outputs/humans/trajectory_categories/'

for site_and_sample_size in sites_and_sample_size_combo:

    #bring our data back in for plotting
    for site_type in types:

        compare_df = pd.read_csv(f'{traj_dir}loo_comparison_{site_type}_{site_and_sample_size}.csv', index_col=0)

        #arviz's own comparison plot, kept for quick inspection; only the custom figure below is saved
        az.plot_compare(compare_df)

        #reorder rows to desired_order (a model missing from the CSV would become a NaN row)
        df_ordered = compare_df.reindex(desired_order)

        #tall, narrow figure
        fig, ax = plt.subplots(figsize=(3.5, 15))

        x_positions = np.arange(len(df_ordered))

        #one marker with a +/- se error bar per model; the N = 10000 model is highlighted in orange
        for i, (model_name, row) in enumerate(df_ordered.iterrows()):
            color = 'orange' if model_name == 'trace_10000' else 'black'

            err_container = ax.errorbar(
                x=i,
                y=row['elpd_loo'],
                yerr=row['se'],
                fmt='o',
                color=color,
                markersize=8,
                capsize=0,  # No caps
                linewidth=2,
                elinewidth=2
            )

            #fade only the error bar lines, not the marker
            #(element [2] of the errorbar container holds the bar line collections)
            err_container[2][0].set_alpha(0.3)

        #dashed black (50% alpha) reference line at the highest ELPD
        max_elpd = df_ordered['elpd_loo'].max()
        ax.axhline(max_elpd, color='black', linestyle='--', linewidth=1, alpha=0.5)

        #half a slot of padding at each edge
        ax.set_xlim(-0.5, len(df_ordered)-0.5)

        # Formatting
        ax.set_xticks(x_positions)
        ax.set_xticklabels(desired_order, rotation=45, ha='right')
        ax.set_ylabel('ELPD (LOO)', fontsize=12)
        ax.set_xlabel('Model', fontsize=12)
        ax.set_title('Model Comparison', fontsize=14, fontweight='bold')
        ax.grid(axis='y', alpha=0.3)

        plt.tight_layout()
        # Save the figure at 600 dpi
        plt.savefig(f'../../exports/figures/model_comparison_{site_type}_{site_and_sample_size}.png', dpi=600)
        plt.show()

#############################
#Plotting some of these fits#
#############################
    
#bring in genscot (with per-site annotations) for the observed-data scatter
with open('../../data/genscot_full_with_site_details.pk', 'rb') as file:
    genscot = pickle.load(file)

#sample 500 participants (sample_to_uniform_age is a project helper; per its name it targets
#a uniform age distribution)
sample_size = 500
genscot = sample_to_uniform_age(genscot, sample_size)

#shift ages so the youngest sampled participant sits at 0
genscot.var.age = genscot.var.age - genscot.var.age.min()

cpgs = ['cg23310490', 'cg19758448', 'cg24436906', 'cg00327072'] #some example cpgs

site_type = "non_linear_saturating" #change as needed ("vmp", "dmp", "vmp_dmp", "non_linear_saturating")

traj_dir = '../../exports/model_outputs/humans/trajectory_categories/'

with open(f'{traj_dir}adata_final_{site_type}.pk', 'rb') as file:
    adata_all = pickle.load(file)

#lowess smoothing per N/s model: N/s = 1000 needs more smoothing than the others
lowess_kwargs = {'1000': {'frac': 0.6, 'it': 500}}
default_lowess_kwargs = {'frac': 0.4, 'it': 100}

#loop through N/s values
for name in ['1000', '10000', '100000']:

    with open(f'{traj_dir}full_model_nsites_200_ss_200_{site_type}N_{name}s_1.pk', 'rb') as file:
        trace = pickle.load(file)

    #subset the full adata to this trace's sites and participants, in the trace's order, so that
    #t lines up with the posterior predictive draws; subsetting from adata_all each time avoids
    #depending on the previous trace's selection
    sites_in_trace = trace.observed_data.sites.values
    adata = adata_all[sites_in_trace, trace.observed_data.participants.values]
    t = adata.var.age.values

    #NOTE(review): an outlier-filtered view (participant '203041550027_R06C01' removed) was built
    #here "for the scatter" but never used - the scatter below plots genscot. Confirm whether that
    #participant should be excluded from the genscot scatter instead.

    #which of the cpgs are in the trace?
    cpgs_in_trace = [cpg for cpg in cpgs if cpg in sites_in_trace]

    #the N/s = 10000 model is drawn in orange, the others in black
    plot_color = 'darkorange' if name == '10000' else 'black'
    smoothing = lowess_kwargs.get(name, default_lowess_kwargs)

    for site_i in cpgs_in_trace:

        #new plot
        plt.figure(figsize=(6, 4))

        #posterior predictive m-values for this site, averaged over the leading dimension
        #(presumably chains - TODO confirm the dim order of posterior_predictive)
        site_draws = trace.posterior_predictive['m-values'].sel(sites=site_i).mean(axis=0)
        y_vals = site_draws.mean(axis=0)
        #0.5th / 99.5th percentiles, i.e. a central 99% band
        #NOTE(review): percentiles are taken after averaging over the leading dim; with several
        #chains this narrows the band compared with pooling all chain x draw samples - confirm
        upper = np.percentile(site_draws, 99.5, axis=0)
        lower = np.percentile(site_draws, 0.5, axis=0)

        #plot upper and lower
        plt.scatter(t, upper, color='lightgray', alpha=0.2, s=10)
        plt.scatter(t, lower, color='lightgray', alpha=0.2, s=10)

        #smooth the mean and the band edges against age
        ys_mean = lowess(y_vals, t, return_sorted=True, **smoothing)
        ys_upper = lowess(upper, t, return_sorted=True, **smoothing)
        ys_lower = lowess(lower, t, return_sorted=True, **smoothing)

        plt.plot(ys_mean[:,0], ys_mean[:,1], color=plot_color, alpha=0.9, linewidth=3)

        #observed values of the sampled participants
        plt.scatter(genscot.var.age.values, genscot[site_i].X.T, color='gray', alpha=0.3, s=30, edgecolor='none')

        #smoothed 99% predictive band; each edge is plotted against its own lowess x values
        plt.plot(ys_upper[:,0], ys_upper[:,1], color=plot_color, alpha=1, linestyle='--', linewidth=3)
        plt.plot(ys_lower[:,0], ys_lower[:,1], color=plot_color, alpha=1, linestyle='--',  linewidth=3)

        plt.title(f'Site {site_i} type: {name}')

        #remove border
        sns.despine()

        #save at 600 dpi
        plt.savefig(f'../../exports/figures/site_fit_{site_i}_N_{name}.png', dpi=600)


#################################################################
#Looking at how different categories of sites vary in parameters#
#################################################################

#bring in the N/s = 10000 traces (per the file names: 200 sites, sample size 200) for each
#site category, keyed by the label used in the plots
traj_dir = '../../exports/model_outputs/humans/trajectory_categories/'
category_files = {
    'VMP_DMP': 'vmp_dmp',
    'VMP': 'vmp',
    'DMP': 'dmp',
    'Sat': 'non_linear_saturating',
}
categories = {}
for label, site_type in category_files.items():
    with open(f'{traj_dir}full_model_nsites_200_ss_200_{site_type}N_10000s_1.pk', 'rb') as file:
        categories[label] = pickle.load(file)

#per-site posterior means of pm, pu, p, eta and omega for every category
#(one DataFrame per category, concatenated once - no per-row appends)
frames = []
for category, trace in categories.items():
    omega = trace.posterior['omega'].values
    eta = trace.posterior['eta'].values
    p = trace.posterior['p'].values

    #eta = pu / (pu + pm) and omega = pm + pu, so pm = omega * (1 - eta) and pu = omega * eta
    #(named *_draws so the `pm` module alias from src.general_imports is not shadowed)
    pm_draws = omega * (1 - eta)
    pu_draws = omega * eta

    #mean across chains and draws - assumes posterior arrays are (chain, draw, site); TODO confirm
    frames.append(pd.DataFrame({
        'category': category,
        'pm': pm_draws.mean(axis=(0, 1)),
        'pu': pu_draws.mean(axis=(0, 1)),
        'p': p.mean(axis=(0, 1)),
        'eta': eta.mean(axis=(0, 1)),
        'omega': omega.mean(axis=(0, 1)),
        'site': trace.observed_data.sites.values,
    }))

df_params = pd.concat(frames, ignore_index=True)

#bring in genscot to get site annotations
with open('../../data/genscot_full_with_site_details.pk', 'rb') as file:
    genscot = pickle.load(file)

obs = genscot.obs

#merge on site (inner join: sites without an annotation row are dropped)
df_params = df_params.merge(obs[['beta', 'var_change']], left_on='site', right_index=True)

#split sites by direction of beta (a missing beta falls in neither group)
decr_beta = df_params['beta'] < 0
incr_beta = df_params['beta'] >= 0

#how many in each category
print("Counts by category and beta direction:")
print(pd.crosstab(df_params['category'], decr_beta))

#(title suffix, file-name tag, row mask) for each beta direction; labels are tied to the mask
#itself rather than inferred by comparing DataFrames, which mislabels when both subsets are equal
beta_groups = [('< 0', 'decr', decr_beta), ('>= 0', 'incr', incr_beta)]


def _boxplots_by_beta(data, panels, file_stem, figsize, groups):
    """Draw and save one figure of per-category boxplots for each beta-direction subset.

    data:      DataFrame with a 'category' column plus the plotted columns
    panels:    list of (column, title_prefix) pairs, one subplot each, left to right
    file_stem: figure saved as ../../exports/figures/{file_stem}_by_category_beta_{tag}.png
    figsize:   matplotlib figure size
    groups:    list of (title_suffix, file_tag, boolean row mask)
    """
    for sign_label, file_tag, mask in groups:
        subset = data[mask]
        plt.figure(figsize=figsize)
        for panel_i, (column, title) in enumerate(panels, start=1):
            plt.subplot(1, len(panels), panel_i)
            sns.boxplot(x='category', y=column, data=subset)
            plt.title(f'{title} by category for beta {sign_label}')
        plt.tight_layout()
        #save at 600 dpi
        plt.savefig(f'../../exports/figures/{file_stem}_by_category_beta_{file_tag}.png', dpi=600)


#convert to percentages
df_params['pm'] = df_params['pm'] * 100
df_params['pu'] = df_params['pu'] * 100

#boxplots of pm and pu by category, separately for each beta direction
_boxplots_by_beta(df_params, [('pm', 'pm'), ('pu', 'pu')], 'pm_pu', (10, 5), beta_groups)

#sum of pm and pu, and their ratio, by category
df_params['sum'] = df_params['pm'] + df_params['pu']
df_params['ratio'] = df_params['pu'] / df_params['pm']

_boxplots_by_beta(df_params, [('sum', 'pm + pu'), ('ratio', 'pu / pm')], 'sum_ratio', (10, 5), beta_groups)

#absolute difference between p and eta for each site
df_params['p_minus_eta'] = np.abs(df_params['p'] - df_params['eta'])

#ignoring categories, plot scatter of |p - eta| vs. beta
plt.figure(figsize=(6, 5))
plt.scatter(df_params['beta'], df_params['p_minus_eta'], alpha=0.5)
plt.title('Absolute difference between p and eta vs. beta')
plt.xlabel('Beta')
plt.ylabel('Absolute difference between p and eta')
plt.tight_layout()
#save at 600 dpi
plt.savefig('../../exports/figures/p_minus_eta_vs_beta.png', dpi=600)

#plot |p - eta| vs sum of pm and pu, coloured by category
plt.figure(figsize=(6, 5))
sns.scatterplot(x='sum', y='p_minus_eta', hue='category', data=df_params, alpha=0.7)
plt.title('Absolute difference between p and eta vs. (pm + pu)')
plt.xlabel('pm + pu')
plt.ylabel('Absolute difference between p and eta')
plt.tight_layout()
#save at 600 dpi
plt.savefig('../../exports/figures/p_minus_eta_vs_sum_pm_pu.png', dpi=600)

#boxplots of p (in %) by category and beta direction
df_params['p'] = df_params['p'] * 100
_boxplots_by_beta(df_params, [('p', 'p')], 'p', (6, 5), beta_groups)

###################################################################################
#Table of quantiles of White p values and Spearman rho values for supplementary#
###################################################################################

#bring in genscot to get the per-site statistics
with open('../../data/genscot_full_with_site_details.pk', 'rb') as file:
    genscot = pickle.load(file)

white_pvals = genscot.obs['white_pval']
spearman_rho = genscot.obs['spearman_stat']

#cumulative distribution of White p values for all sites: linear x axis (left), log x axis (right)
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
sns.ecdfplot(white_pvals, label='All sites', color='blue')
plt.title('Cumulative distribution of white p values')
plt.xlabel('White p value')
plt.ylabel('Cumulative proportion')

plt.subplot(1, 2, 2)
sns.ecdfplot(white_pvals, label='All sites', color='blue')
plt.xscale('log')
plt.title('Cumulative distribution of white p values (log scale)')
plt.xlabel('White p value (log scale)')
plt.ylabel('Cumulative proportion')
plt.tight_layout()

#cumulative distribution of Spearman rho values for all sites
plt.figure(figsize=(6, 5))
sns.ecdfplot(spearman_rho, label='All sites', color='blue')
plt.title('Cumulative distribution of Spearman rho values')
plt.xlabel('Spearman rho')
plt.ylabel('Cumulative proportion')

#boxplot of spearman rho values
plt.figure(figsize=(6, 5))
sns.boxplot(y=spearman_rho, color='lightblue')
plt.title('Boxplot of Spearman rho values for all sites')

#table of quantiles of White p values and Spearman rho values
quantiles = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99]
white_pval_quantiles = white_pvals.quantile(quantiles)
#quantiles of |rho| (sign ignored), although the output column is named 'spearman_rho'
spearman_rho_quantiles = spearman_rho.abs().quantile(quantiles)
#combine into a df
quantile_df = pd.DataFrame({'quantile': quantiles, 'white_pval': white_pval_quantiles.values, 'spearman_rho': spearman_rho_quantiles.values})
print(quantile_df)

#save as csv
quantile_df.to_csv('../../exports/figures/white_pval_spearman_rho_quantiles.csv', index=False)