###################
### Description ###
###################
# This script is for the various analyses of human (GenScot) data.
# Includes:
# - Plotting heatmap of log likelihoods across N and s (Fig. 2a)
# - Plotting N/s by group (Fig. 2c, Supp. Fig. 2a)
# - Plotting site fits across different categories of CpGs using conditional model (Fig. 2b, Supp. Fig. 1a)
# - Looking at how different categories of sites vary in parameters (Supp. Fig. 1b-c)
# - Outputting a table of distribution of white p values and spearman rho values for Supp. Table 1
##############
### Author ###
##############
# Sam Crofts (sam.crofts@ed.ac.uk)
########################################################################################################################
#Imports
import os
import sys
sys.path.append("../..") # fix to import modules from root
from src.general_imports import *
from statsmodels.nonparametric.smoothers_lowess import lowess
from scipy.stats import beta
################################
###Plotting heatmap (Fig. 2a)###
################################
# Walk exports/model_outputs/humans/fixed_n_s/log_scale/, load every pickled
# trace (.pk file), and record the N and s values encoded in its file name
# together with the trace's LOO elpd.
fixed_n_s_dir = '../../exports/model_outputs/humans/fixed_n_s/log_scale/'
rows = []
for f in os.listdir(fixed_n_s_dir):
    # only the .pk files hold traces; skip anything else in the folder
    if not f.endswith('.pk'):
        continue
    with open(f'{fixed_n_s_dir}{f}', 'rb') as file:
        trace = pickle.load(file)
    # the 5th and 6th underscore-separated fields are parsed as "N<int>" and
    # "s<float>.pk" respectively
    fields = f.split('_')
    N = int(fields[4][1:])
    s = float(fields[5][1:-3])  # drop the leading letter and the trailing ".pk"
    # calculate LOO
    elpd_loo = az.loo(trace).elpd_loo
    rows.append({'N': N, 's': s, 'elpd_loo': elpd_loo})
# build the frame once instead of growing it row by row
df = pd.DataFrame(rows, columns=['N', 's', 'elpd_loo'])
# save df
df.to_csv('../../exports/model_outputs/humans/fixed_n_s/heatmap_likelihoods_log_scale.csv', index=False)
# re-import (lets the plotting below be re-run without recomputing LOO)
df = pd.read_csv('../../exports/model_outputs/humans/fixed_n_s/heatmap_likelihoods_log_scale.csv')
df_backup = df.copy()
# keep only rows with N <= 50000
df = df[df['N'] <= 50000]
plt.figure(figsize=(8.5, 6))
# rows are s, columns are N, colour is elpd_loo; fmt only matters if annot is switched on
heatmap = sns.heatmap(df.pivot(values='elpd_loo', columns='N', index='s'), cmap='viridis', annot=False, fmt='.1g')
# tight
plt.tight_layout()
# save at 600 dpi
plt.savefig('../../exports/figures/heatmap_elpd_log_scale.png', dpi=600)
#####################################
###Plotting N/s by group (Fig. 2c)###
#####################################
# Groups to compare. Swap in the alternative list to compare other cohorts.
# group_names = ['female', 'male', 'smokers']
group_names = ['high_acc', 'low_acc']
selection_method = "spearman" #or "white"
cohort_dir = '../../exports/model_outputs/humans/cohort_analyses/'
# traces and N/s summaries keyed by group name (replaces exec/eval-built variables)
traces = {}
summaries = {}
rows = []
for var in group_names:
    file_path = f'{cohort_dir}nsites_500_ss_500selection_method_{selection_method}_{var}.pk'
    with open(file_path, 'rb') as file:
        traces[var] = pickle.load(file)
    # plot N/s
    az.plot_posterior(traces[var], var_names='N/s')
    summaries[var] = pm.summary(traces[var], hdi_prob=0.95, var_names=['N/s'])
    rows.append({'type': var,
                 'mean': summaries[var].loc['N/s', 'mean'],
                 'hdi_2.5%': summaries[var].loc['N/s', 'hdi_2.5%'],
                 'hdi_97.5%': summaries[var].loc['N/s', 'hdi_97.5%']})
df = pd.DataFrame(rows, columns=['type', 'mean', 'hdi_2.5%', 'hdi_97.5%'])
# Plot: posterior mean per group with asymmetric error bars spanning the 95% HDI
plt.figure(figsize=(3, 6))
plt.errorbar(df['type'], df['mean'],
             yerr=[df['mean']-df['hdi_2.5%'], df['hdi_97.5%']-df['mean']],
             fmt='o', color='royalblue', ecolor='lightblue',
             elinewidth=4, capsize=0)
# Adjust the plot
plt.title('N/s with 95% HPDI')
plt.ylabel('N/s')
# Reduce space on x-axis between categories. Categories sit at 0..n-1; the
# padding follows the values previously hand-set for two (0.2) and three (0.5) groups.
ax = plt.gca()
x_pad = 0.2 if len(group_names) == 2 else 0.5
ax.set_xlim(-x_pad, len(group_names) - 1 + x_pad)
plt.gcf().set_size_inches(3, 4.5) # Adjust these values to control category spacing
plt.tight_layout()
# NOTE(review): the file name is fixed and mentions "smokers" and "spearman"
# regardless of group_names / selection_method - confirm it matches the run.
plt.savefig('../../exports/figures/group_comparison_smokers_spearman.png', dpi=600)
########################################################
###Plotting sites (using conditional model) (Fig. 2b)###
########################################################
# site types
types = ["vmp", "dmp", "vmp_dmp", "non_linear_saturating"]
# site and sample size combinations
sites_and_sample_size_combo = ["nsites_200_ss_200", "nsites_200_ss_500",
                               "nsites_500_ss_200"]
traj_dir = '../../exports/model_outputs/humans/trajectory_categories/'
# Step 1: for every combination, load the five fitted models, compare them
# with az.compare, and save the comparison table.
for site_and_sample_size in sites_and_sample_size_combo:
    for site_type in types:
        stem = f'{site_and_sample_size}_{site_type}'
        # label -> pickle file name (no separator between site type and "N" in the full-model files)
        model_files = {
            'trace_1000': f'full_model_{stem}N_1000s_1.pk',
            'trace_10000': f'full_model_{stem}N_10000s_1.pk',
            'trace_100000': f'full_model_{stem}N_100000s_1.pk',
            'trace_null': f'null_model_{stem}.pk',
            'trace_simple_linear': f'simple_linear_model_{stem}.pk',
        }
        # bring in traces, closing each file as soon as it has been read
        traces = {}
        for label, file_name in model_files.items():
            with open(traj_dir + file_name, 'rb') as fh:
                traces[label] = pickle.load(fh)
        # compare using loo
        df_comp_loo = az.compare(traces)
        # save comparison df
        df_comp_loo.to_csv(f'{traj_dir}loo_comparison_{site_type}_{site_and_sample_size}.csv')
# Step 2: read the saved tables back in and draw one comparison figure each.
# Define desired order of models along the x-axis
desired_order = ['trace_1000', 'trace_10000', 'trace_100000', 'trace_simple_linear', 'trace_null']
for site_and_sample_size in sites_and_sample_size_combo:
    for site_type in types:
        compare_df = pd.read_csv(f'{traj_dir}loo_comparison_{site_type}_{site_and_sample_size}.csv', index_col=0)
        # arviz's own comparison plot (separate figure, not saved)
        az.plot_compare(compare_df)
        # Reorder the dataframe
        df_ordered = compare_df.reindex(desired_order)
        # Narrow, tall figure
        fig, ax = plt.subplots(figsize=(3.5, 15))
        x_positions = np.arange(len(df_ordered))
        # Plot each model individually so the N=10000 model can be highlighted
        for i, (idx, row) in enumerate(df_ordered.iterrows()):
            color = 'orange' if idx == 'trace_10000' else 'black'
            err_container = ax.errorbar(
                x=i,
                y=row['elpd_loo'],
                yerr=row['se'],
                fmt='o',
                color=color,
                markersize=8,
                capsize=0,  # No caps
                linewidth=2,
                elinewidth=2
            )
            # Set alpha only for error bars, not the point
            err_container[2][0].set_alpha(0.3)  # [2] is the error bar lines
        # Add a black dashed horizontal line at the highest ELPD value
        max_elpd = df_ordered['elpd_loo'].max()
        ax.axhline(max_elpd, color='black', linestyle='--', linewidth=1, alpha=0.5)
        # add a bit of spacing to the edges
        ax.set_xlim(-0.5, len(df_ordered)-0.5)
        # Formatting
        ax.set_xticks(x_positions)
        ax.set_xticklabels(desired_order, rotation=45, ha='right')
        ax.set_ylabel('ELPD (LOO)', fontsize=12)
        ax.set_xlabel('Model', fontsize=12)
        ax.set_title('Model Comparison', fontsize=14, fontweight='bold')
        ax.grid(axis='y', alpha=0.3)
        plt.tight_layout()
        # Save the figure at 600 dpi
        plt.savefig(f'../../exports/figures/model_comparison_{site_type}_{site_and_sample_size}.png', dpi=600)
        plt.show()
#############################
#Plotting some of these fits#
#############################
# bring in genscot (used below for the observed-data scatter)
with open('../../data/genscot_full_with_site_details.pk', 'rb') as file:
    genscot = pickle.load(file)
# sample 500 uniform ages
sample_size = 500
genscot = sample_to_uniform_age(genscot, sample_size)
# subtract min age so ages start at zero
genscot.var['age'] = genscot.var['age'] - genscot.var['age'].min()
cpgs = ['cg23310490', 'cg19758448', 'cg24436906', 'cg00327072'] #some example cpgs
site_type = "non_linear_saturating" #change as needed ("vmp", "dmp", "vmp_dmp", "non_linear_saturating")
traj_dir = '../../exports/model_outputs/humans/trajectory_categories/'
with open(f'{traj_dir}adata_final_{site_type}.pk', 'rb') as fh:
    adata = pickle.load(fh)
# loop through the full-model fits (file names differ only in the N value)
for name in ['1000', '10000', '100000']:
    with open(f'{traj_dir}full_model_nsites_200_ss_200_{site_type}N_{name}s_1.pk', 'rb') as fh:
        trace = pickle.load(fh)
    sites_in_trace = trace.observed_data.sites.values
    # reorder a view of adata to match the trace's sites and participants;
    # adata itself is left untouched so every iteration starts from the full object
    adata_ordered = adata[sites_in_trace, trace.observed_data.participants.values]
    t = adata_ordered.var.age.values
    # remove one outlier for the scatter
    # NOTE(review): adata_new / t_new are not used by the scatter below, which
    # plots genscot instead - confirm which data the scatter should show.
    adata_new = adata_ordered[:, adata_ordered.var.index != '203041550027_R06C01']
    t_new = adata_new.var.age.values
    # which of the cpgs are in the trace?
    cpgs_in_trace = [cpg for cpg in cpgs if cpg in sites_in_trace]
    # the N=10000 fit is highlighted in orange
    plot_color = 'darkorange' if name == '10000' else 'black'
    # need more smoothing for 1000 N
    if name == '1000':
        lowess_kwargs = {'frac': 0.6, 'it': 500}
    else:
        lowess_kwargs = {'frac': 0.4, 'it': 100}
    for site_i in cpgs_in_trace:
        # new plot
        plt.figure(figsize=(6, 4))
        # posterior predictive samples for this site, averaged over the first
        # axis (presumably chains - TODO confirm), computed once and reused
        pp_site = trace.posterior_predictive['m-values'].sel(sites=site_i).mean(axis=0)
        y_vals = pp_site.mean(axis=0)
        # 0.5th and 99.5th percentiles across the remaining leading axis
        upper = np.percentile(pp_site, 99.5, axis=0)
        lower = np.percentile(pp_site, 0.5, axis=0)
        # plot upper and lower (raw, unsmoothed)
        plt.scatter(t, upper, color='lightgray', alpha=0.2, s=10)
        plt.scatter(t, lower, color='lightgray', alpha=0.2, s=10)
        # smooth mean and bounds against age; columns are (sorted x, smoothed y)
        ys_mean = lowess(y_vals, t, return_sorted=True, **lowess_kwargs)
        ys_upper = lowess(upper, t, return_sorted=True, **lowess_kwargs)
        ys_lower = lowess(lower, t, return_sorted=True, **lowess_kwargs)
        plt.plot(ys_mean[:,0], ys_mean[:,1], color=plot_color, alpha=0.9, linewidth=3)
        plt.scatter(genscot.var.age.values, genscot[site_i].X.T, color='gray', alpha=0.3, s=30, edgecolor='none')
        # Plot the 99% predictive interval (0.5th-99.5th percentiles);
        # each band is drawn against its own smoothed x values
        plt.plot(ys_upper[:,0], ys_upper[:,1], color=plot_color, alpha=1, linestyle='--', linewidth=3)
        plt.plot(ys_lower[:,0], ys_lower[:,1], color=plot_color, alpha=1, linestyle='--', linewidth=3)
        plt.title(f'Site {site_i} type: {name}')
        # remove border
        sns.despine()
        # save at 600 dpi
        plt.savefig(f'../../exports/figures/site_fit_{site_i}_N_{name}.png', dpi=600)
#################################################################
#Looking at how different categories of sites vary in parameters#
#################################################################
# bring in traces (full model, N=10000) for each site category
traj_dir = '../../exports/model_outputs/humans/trajectory_categories/'
category_files = {
    'VMP_DMP': 'full_model_nsites_200_ss_200_vmp_dmpN_10000s_1.pk',
    'VMP': 'full_model_nsites_200_ss_200_vmpN_10000s_1.pk',
    'DMP': 'full_model_nsites_200_ss_200_dmpN_10000s_1.pk',
    'Sat': 'full_model_nsites_200_ss_200_non_linear_saturatingN_10000s_1.pk',
}
categories = {}
for category, file_name in category_files.items():
    with open(traj_dir + file_name, 'rb') as fh:
        categories[category] = pickle.load(fh)
# get average pm and pu for each category and plot
# these are traces, so need to get pm and pu from posterior
param_frames = []
for category, trace in categories.items():
    omega = trace.posterior['omega'].values
    eta = trace.posterior['eta'].values
    p = trace.posterior['p'].values
    site = trace.observed_data.sites.values
    # now, eta = pu / (pu + pm) and omega = pm + pu
    # so, pm = omega * (1 - eta) and pu = omega * eta
    # (named *_vals so the module alias `pm` used earlier in the script is not overwritten)
    pm_vals = omega * (1 - eta)
    pu_vals = omega * eta
    # mean across the first two axes (chains and draws), leaving one value per
    # site; assumes the remaining axis lines up with `site` - TODO confirm
    param_frames.append(pd.DataFrame({
        'category': category,
        'pm': pm_vals.mean(axis=(0, 1)),
        'pu': pu_vals.mean(axis=(0, 1)),
        'p': p.mean(axis=(0, 1)),
        'eta': eta.mean(axis=(0, 1)),
        'omega': omega.mean(axis=(0, 1)),
        'site': site,
    }))
df_params = pd.concat(param_frames, ignore_index=True)
# bring in genscot and merge data to get site annotations
with open('../../data/genscot_full_with_site_details.pk', 'rb') as file:
    genscot = pickle.load(file)
obs = genscot.obs
# merge on site
df_params = df_params.merge(obs[['beta', 'var_change']], left_on='site', right_index=True)
# now boxplot of pm and pu by category and by whether beta is above or below 0
decr_beta = df_params['beta'] < 0
incr_beta = df_params['beta'] >= 0
# (title suffix, file-name tag, row mask) for the two beta directions
beta_groups = [('< 0', 'decr', decr_beta), ('>= 0', 'incr', incr_beta)]
# how many in each category
print("Counts by category and beta direction:")
print(pd.crosstab(df_params['category'], decr_beta))
# convert to percentages
df_params['pm'] = df_params['pm'] * 100
df_params['pu'] = df_params['pu'] * 100
for beta_label, beta_tag, mask in beta_groups:
    df_plot = df_params[mask]
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    sns.boxplot(x='category', y='pm', data=df_plot)
    plt.title('pm by category for beta ' + beta_label)
    plt.subplot(1, 2, 2)
    sns.boxplot(x='category', y='pu', data=df_plot)
    plt.title('pu by category for beta ' + beta_label)
    plt.tight_layout()
    # save at 600 dpi
    plt.savefig('../../exports/figures/pm_pu_by_category_beta_' + beta_tag + '.png', dpi=600)
# now also get the sum of pm and pu by category, and ratio, and plot these too
df_params['sum'] = df_params['pm'] + df_params['pu']
df_params['ratio'] = df_params['pu'] / df_params['pm']
# now plot these
for beta_label, beta_tag, mask in beta_groups:
    df_plot = df_params[mask]
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    sns.boxplot(x='category', y='sum', data=df_plot)
    plt.title('pm + pu by category for beta ' + beta_label)
    plt.subplot(1, 2, 2)
    sns.boxplot(x='category', y='ratio', data=df_plot)
    plt.title('pu / pm by category for beta ' + beta_label)
    plt.tight_layout()
    # save at 600 dpi
    plt.savefig('../../exports/figures/sum_ratio_by_category_beta_' + beta_tag + '.png', dpi=600)
# now also get difference between p and eta by category, and plot these too
# (computed before p is rescaled to % below, so both are on the same scale)
df_params['p_minus_eta'] = np.abs(df_params['p'] - df_params['eta'])
# ignoring categories, plot scatter of |p - eta| vs. beta
plt.figure(figsize=(6, 5))
plt.scatter(df_params['beta'], df_params['p_minus_eta'], alpha=0.5)
plt.title('Absolute difference between p and eta vs. beta')
plt.xlabel('Beta')
plt.ylabel('Absolute difference between p and eta')
plt.tight_layout()
# save at 600 dpi
plt.savefig('../../exports/figures/p_minus_eta_vs_beta.png', dpi=600)
# plot p_minus_eta vs sum of pm and pu, coloured by category
plt.figure(figsize=(6, 5))
sns.scatterplot(x='sum', y='p_minus_eta', hue='category', data=df_params, alpha=0.7)
plt.title('Absolute difference between p and eta vs. (pm + pu)')
plt.xlabel('pm + pu')
plt.ylabel('Absolute difference between p and eta')
plt.tight_layout()
# save at 600 dpi
plt.savefig('../../exports/figures/p_minus_eta_vs_sum_pm_pu.png', dpi=600)
# also plot boxplots of p in % by category and beta direction
df_params['p'] = df_params['p'] * 100
for beta_label, beta_tag, mask in beta_groups:
    df_plot = df_params[mask]
    plt.figure(figsize=(6, 5))
    sns.boxplot(x='category', y='p', data=df_plot)
    plt.title('p by category for beta ' + beta_label)
    plt.tight_layout()
    # save at 600 dpi
    plt.savefig('../../exports/figures/p_by_category_beta_' + beta_tag + '.png', dpi=600)
###################################################################################
#Table of distribution of white p values and spearman rho values for supplementary#
###################################################################################
# bring in genscot to get the per-site statistics stored in genscot.obs
with open('../../data/genscot_full_with_site_details.pk', 'rb') as file:
    genscot = pickle.load(file)
# cumulative distribution of white p values for all sites in genscot:
# linear x-axis on the left, log x-axis on the right of the same figure
fig, (ax_linear, ax_log) = plt.subplots(1, 2, figsize=(10, 5))
sns.ecdfplot(genscot.obs['white_pval'], label='All sites', color='blue', ax=ax_linear)
ax_linear.set_title('Cumulative distribution of white p values')
ax_linear.set_xlabel('White p value')
ax_linear.set_ylabel('Cumulative proportion')
# as above but log scale on x axis
sns.ecdfplot(genscot.obs['white_pval'], label='All sites', color='blue', ax=ax_log)
ax_log.set_xscale('log')
ax_log.set_title('Cumulative distribution of white p values (log scale)')
ax_log.set_xlabel('White p value (log scale)')
ax_log.set_ylabel('Cumulative proportion')
fig.tight_layout()
# cumulative distribution of spearman rho values for all sites
plt.figure(figsize=(6, 5))
sns.ecdfplot(genscot.obs['spearman_stat'], label='All sites', color='blue')
plt.title('Cumulative distribution of Spearman rho values')
plt.xlabel('Spearman rho')
plt.ylabel('Cumulative proportion')
plt.tight_layout()
# boxplot of spearman rho values
plt.figure(figsize=(6, 5))
sns.boxplot(y=genscot.obs['spearman_stat'], color='lightblue')
plt.title('Boxplot of Spearman rho values for all sites')
plt.tight_layout()
# instead, just make a table of quantiles of white p values and spearman rho values
# (spearman rho is summarised by absolute value)
quantiles = [0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.99]
white_pval_quantiles = genscot.obs['white_pval'].quantile(quantiles)
spearman_rho_quantiles = np.abs(genscot.obs['spearman_stat']).quantile(quantiles)
# combine into a df
quantile_df = pd.DataFrame({'quantile': quantiles, 'white_pval': white_pval_quantiles.values, 'spearman_rho': spearman_rho_quantiles.values})
print(quantile_df)
# save as csv
quantile_df.to_csv('../../exports/figures/white_pval_spearman_rho_quantiles.csv', index=False)