###################
### Description ###
###################
# This script is for running our model on data from different mammals.
# Relevant figures:
# - Plotting N/s vs lifespan across mammals (Fig. 3b, Supp. Fig. 3c)
# - Plotting pm and pu vs lifespan across mammals (Fig. 3c)
# - Joint model comparison (Fig. 3d)
# - Examples of site fits from joint model (Supp. Figs 3d-i)
# Inputs:
# - Model outputs (traces) from 'run_mammals_joint_models.py' and 'run_mammals_separate_models.py'
##############
### Author ###
##############
# Sam Crofts (sam.crofts@ed.ac.uk)
##################################
##################################
#imports
import sys
sys.path.append("../..") # fix to import modules from root
from src.general_imports import *
import matplotlib.pyplot as plt
import scipy.stats as stats
import os
from statsmodels.nonparametric.smoothers_lowess import lowess
################################################################
###Running each mammal separately (Fig. 3b-c, Supp. Fig. 2e) ###
################################################################
# Load every per-species ("separate") model trace from the model output
# directory into trace_list, tagging each trace with its species name.
# NOTE(review): pickle.load executes arbitrary code from the file; this is
# only safe because these are locally generated model outputs.
separate_trace_dir = '../../exports/model_outputs/all_mammals/'
separate_trace_prefix = 'spearman_separate_model_min_sample_size_50_nsites_200_'
trace_list = []
# sorted() makes the trace order reproducible (os.listdir order is arbitrary)
for file in sorted(os.listdir(separate_trace_dir)):
    if not (file.startswith("spearman_separate") and file.endswith(".pk")):
        continue
    name = file.replace('.pk', '')
    # species name is whatever follows the fixed model prefix (spaces kept)
    this_trace = name.split(separate_trace_prefix)[1]
    # replace spaces with underscores to get the per-trace variable name
    name = name.replace(' ', '_')
    with open(os.path.join(separate_trace_dir, file), 'rb') as trace_file:
        loaded_trace = pickle.load(trace_file)
    loaded_trace.animal = this_trace
    # keep the module-level variable per trace that the script has always
    # defined, but without building and exec-ing source strings
    globals()[name] = loaded_trace
    # add the trace to trace_list
    trace_list.append(loaded_trace)
# Sanity check: one posterior plot of N/s per species trace.
for trace in trace_list:
    az.plot_posterior(trace, var_names='N/s')
    # sample size is read from the length of the observed 'participants' axis
    n_participants = trace.observed_data.participants.shape[0]
    # title carries the species name and its sample size
    plt.title(f'{trace.animal} (n={n_participants})')
# Bring in the lifespan data: a pickled list whose entries expose an `.uns`
# mapping with 'organism', 'lifespan' and 'common_name' keys.
with open('../../data/pan_mammal_blood_with_site_details.pk', 'rb') as lifespan_file:
    data_list = pickle.load(lifespan_file)
# Build the animal / lifespan / common-name table in one go rather than
# appending row by row through the private DataFrame._append.
lifespan_df = pd.DataFrame(
    [
        {
            'animal': dataset.uns['organism'],
            'lifespan': dataset.uns['lifespan'],
            'common_name': dataset.uns['common_name'],
        }
        for dataset in data_list
    ],
    columns=['animal', 'lifespan', 'common_name'],
)
# Attach lifespan and common name to each trace, matched on scientific name.
for trace in trace_list:
    # rows of lifespan_df for this animal; the first match is used
    is_this_animal = lifespan_df['animal'] == trace.animal
    trace.lifespan = lifespan_df.loc[is_this_animal, 'lifespan'].values[0]
    trace.common_name = lifespan_df.loc[is_this_animal, 'common_name'].values[0]
# Make a dataframe of N/s (posterior mean and HDI bounds), lifespan and
# sample size, one row per converged species.
# remove sheep (ovis aries) as it didn't converge due to small range of ages
non_converged_orgs = ['Ovis aries']
result_records = []
for trace in trace_list:
    sci_name = trace.animal
    sample_size = trace.observed_data.participants.shape[0]
    # skip species listed as non-converged (matched on scientific name)
    if sci_name in non_converged_orgs:
        continue
    # summarise once per trace instead of once per extracted statistic
    ns_summary = az.summary(trace, var_names='N/s').iloc[0]
    result_records.append({
        'organism': trace.common_name,
        'N_on_s': ns_summary['mean'],
        'lifespan': trace.lifespan,
        'lower': ns_summary['hdi_3%'],
        'upper': ns_summary['hdi_97%'],
        'sample_size': sample_size,
    })
# same column order the row-by-row append used to produce
result_df = pd.DataFrame(
    result_records,
    columns=['organism', 'N_on_s', 'lifespan', 'sample_size', 'lower', 'upper'],
)
# log10-transformed columns for lifespan, N/s and the HDI bounds
result_df['log_lifespan'] = np.log10(result_df['lifespan'])
result_df['log_n_s'] = np.log10(result_df['N_on_s'])
result_df['log_lower'] = np.log10(result_df['lower'])
result_df['log_upper'] = np.log10(result_df['upper'])
###Plotting log lifespan vs log N/s with a regression line using seaborn (Fig. 3b)###
yval = 'log_n_s'
# Keep rows whose 'organism' value is not in non_converged_orgs and whose
# sample size is at least 100, then renumber the rows from 0.
# NOTE(review): 'organism' holds common names here while non_converged_orgs
# lists scientific names, so the first condition likely never excludes a row.
is_converged = ~result_df['organism'].isin(non_converged_orgs)
has_enough_samples = result_df['sample_size'] >= 100
result_df = result_df[is_converged & has_enough_samples].reset_index(drop=True)
# scatter plus regression line on a 9x6 figure
plt.figure(figsize=(9,6))
sns.regplot(x='log_lifespan', y=yval, data=result_df)
# label every point with its organism name
for x_pos, y_pos, organism_label in zip(result_df['log_lifespan'], result_df[yval], result_df['organism']):
    plt.text(x_pos, y_pos, organism_label, fontsize=14)
# ordinary least-squares fit; slope, interval, r2 and p go in the title
slope, intercept, r_value, p_value, std_err = stats.linregress(result_df['log_lifespan'], result_df[yval])
# NOTE(review): the interval uses the normal 1.96 multiplier; with few
# species a t-based multiplier would be wider - confirm which is intended.
ci_half_width = 1.96*std_err
plt.title(f'slope: {slope:.2f} 95% CI: {slope-ci_half_width:.2f} to {slope+ci_half_width:.2f} r2: {r_value**2:.2f} p: {p_value:.2f}')
# axis titles and tick label sizes
plt.ylabel('Log N/s')
plt.xlabel('Log lifespan', fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
#save
plt.savefig('../../exports/figures/scaling_white_text.png', dpi=600)
########################################################################################################################
###As above, but with pm and pu instead (and separating sites into increasing and decreasing mean with age) (Fig. 3c)###
########################################################################################################################
# Per-species mean pm and pu, split by whether a site's value
# pu/(pu+pm) lies above ("inc") or below ("dec") its posterior-mean p.
pm_pu_records = []
for trace in trace_list:
    #now, omega = pm+pu and eta = pu/(pm+pu)
    species = trace.common_name
    #if ovis aries, skip
    if trace.animal == 'Ovis aries':
        continue
    sample_size = trace.observed_data.participants.shape[0]
    #if sample size < 100, skip
    if sample_size < 100:
        continue
    # posterior means, re-indexed 0..n-1 so the three series align by position
    omega = az.summary(trace, var_names='omega')['mean'].reset_index(drop=True)
    eta = az.summary(trace, var_names='eta')['mean'].reset_index(drop=True)
    p = az.summary(trace, var_names='p')['mean'].reset_index(drop=True)
    #now get pm and pu
    pm = omega * (1 - eta)
    pu = omega * eta
    df = pd.DataFrame({
        'p_m': pm,
        'p_u': pu,
        'p': p
    })
    df['final_val'] = df['p_u'] / (df['p_u']+df['p_m'])
    # sign of (final_val - p) decides the increasing / decreasing group;
    # rows with exactly zero change fall in neither group
    df['change'] = df['final_val'] - df['p']
    increasing_sites = df[df['change'] > 0]
    decreasing_sites = df[df['change'] < 0]
    pm_pu_records.append({
        'organism': trace.common_name,
        # N/s is taken from the matching row of result_df
        'N_on_s': result_df.loc[result_df['organism'] == species, 'N_on_s'].values[0],
        'lifespan': lifespan_df[lifespan_df['animal']==trace.animal]['lifespan'].values[0],
        'mean_pm_inc': increasing_sites['p_m'].mean(),
        'mean_pm_dec': decreasing_sites['p_m'].mean(),
        'mean_pu_inc': increasing_sites['p_u'].mean(),
        'mean_pu_dec': decreasing_sites['p_u'].mean(),
    })
# build the frame once instead of appending rows via the private _append
new_result_df = pd.DataFrame(
    pm_pu_records,
    columns=['organism', 'N_on_s', 'lifespan', 'mean_pm_inc', 'mean_pm_dec', 'mean_pu_inc', 'mean_pu_dec'],
)
# log10-transformed columns used for the regressions
new_result_df['log_lifespan'] = np.log10(new_result_df['lifespan'])
new_result_df['log_n_s'] = np.log10(new_result_df['N_on_s'])
new_result_df['log_pm_inc'] = np.log10(new_result_df['mean_pm_inc'])
new_result_df['log_pm_dec'] = np.log10(new_result_df['mean_pm_dec'])
new_result_df['log_pu_inc'] = np.log10(new_result_df['mean_pu_inc'])
new_result_df['log_pu_dec'] = np.log10(new_result_df['mean_pu_dec'])
yval = 'log_pm_inc'
# Start a fresh figure: without this, regplot draws onto whatever axes are
# still current (the previous figure when this file is run as a script).
plt.figure()
# plot log lifespan vs the selected log pm/pu column with a regression line
sns.regplot(x='log_lifespan', y=yval, data=new_result_df)
# point labels are not drawn for this figure (presumably deliberate, given
# the 'notext' output filename)
# ordinary least-squares fit; slope, interval, r2 and p go in the title
slope, intercept, r_value, p_value, std_err = stats.linregress(new_result_df['log_lifespan'], new_result_df[yval])
#add slope with 95%CI and r2 to title
plt.title(f'slope: {slope:.2f} 95% CI: {slope-1.96*std_err:.2f} to {slope+1.96*std_err:.2f} r2: {r_value**2:.2f} p: {p_value:.2f}')
#y-axis title
plt.ylabel(yval)
#x-axis title
plt.xlabel('Log lifespan')
#save
plt.savefig('../../exports/figures/pm_inc_notext.png', dpi=600)
########################
#3 way model comparison#
########################
# Open the three pickled joint-model traces; `with` closes each file handle
# even if unpickling fails.
with open('../../exports/model_outputs/all_mammals/joint_model_single_ns.pk', 'rb') as joint_trace_file:
    single_ns = pickle.load(joint_trace_file)
with open('../../exports/model_outputs/all_mammals/joint_model_scaled_ns.pk', 'rb') as joint_trace_file:
    scaled_ns = pickle.load(joint_trace_file)
with open('../../exports/model_outputs/all_mammals/joint_model_free_ns.pk', 'rb') as joint_trace_file:
    free_ns = pickle.load(joint_trace_file)
# LOO-based comparison of the three joint models
df_comp_loo = az.compare({"single_ns": single_ns, "scaled_ns": scaled_ns, "free": free_ns}, ic='loo')
# move the comparison table's index into an ordinary 'index' column
compare_df = df_comp_loo.reset_index()
#save compare_df for later
compare_df.to_csv('../../exports/model_outputs/all_mammals/scaling_model_comparison.csv')
# read it back in; because to_csv also wrote the row index, the reloaded
# frame has one extra unnamed leading column
compare_df = pd.read_csv('../../exports/model_outputs/all_mammals/scaling_model_comparison.csv')
# Rows are plotted in the order they appear in compare_df (no reordering).
df_ordered = compare_df
# Wide, short figure: one row per model
fig, ax = plt.subplots(figsize=(12, 4))
y_positions = np.arange(len(df_ordered))
# One point per model at its ELPD, with a horizontal bar of +/- its 'se' value
for row_pos, (elpd_value, elpd_se) in enumerate(zip(df_ordered['elpd_loo'], df_ordered['se'])):
    color = 'black'
    err_container = ax.errorbar(
        x=elpd_value,
        y=row_pos,
        xerr=elpd_se,
        fmt='o',
        color=color,
        markersize=12,
        capsize=0,
        capthick=0,
        linewidth=5,
        elinewidth=5
    )
    # Fade only the error bar, not the marker ([2] holds the bar line collections)
    err_container[2][0].set_alpha(0.3)
# Black dashed vertical reference line at the highest ELPD value
max_elpd = df_ordered['elpd_loo'].max()
ax.axvline(max_elpd, color='black', linestyle='--', linewidth=1, alpha=0.5)
# Formatting
ax.set_yticks(y_positions)
ax.set_xlabel('ELPD (LOO)', fontsize=12)
ax.set_ylabel('Model', fontsize=12)
ax.set_title('Model Comparison', fontsize=14, fontweight='bold')
ax.invert_yaxis()  # first row at the top
ax.grid(axis='x', alpha=0.3)
plt.tight_layout()
# Save the figure at 600 dpi
plt.savefig('../../exports/figures/scaling_model_comparison.png', dpi=600)
plt.show()
###################################
#Plotting site fits for each model#
###################################
# Load the joint-model AnnData list
with open('../../data/joint_model_scaling_adata.pk', 'rb') as adata_file:
    adata_list = pickle.load(adata_file)

# Species shown in the site-fit plots: the label used on the traces'
# `species` coordinate and the matching adata.uns['organism'] value.
site_fit_species = 'Cattle'
site_fit_organism = 'Bos taurus'
n_sites_to_plot = 20

# Pick the observed data for that organism once (it does not depend on the
# model). Fail loudly if it is missing rather than plotting whichever
# dataset an earlier loop happened to leave in `data`.
matching_adata = [adata for adata in adata_list if adata.uns['organism'] == site_fit_organism]
if not matching_adata:
    raise ValueError(f"No dataset with organism '{site_fit_organism}' found in adata_list")
data = matching_adata[-1]  # last match, as the original loop kept
t = data.var.age.values

# Plot the first sites of the selected species for each joint model
for trace in [single_ns, scaled_ns, free_ns]:
    # restrict the trace to the selected species
    trace = trace.sel(species=site_fit_species)
    # i is the row of `data` plotted alongside site_i
    # NOTE(review): assumes row i of `data` is the same site as the i-th
    # entry of the trace's `sites` coordinate - confirm ordering.
    for i, site_i in enumerate(trace.posterior.sites.values[0:n_sites_to_plot]):
        #new plot
        plt.figure(figsize=(6, 4))
        plot_color = 'black'
        # posterior predictive m-values for this site, averaged over axis 0
        # (presumably the chain axis)
        site_predictive = trace.posterior_predictive['m-values'].sel(sites=site_i).mean(axis=0)
        y_vals = site_predictive.mean(axis=0)
        # 0.5th and 99.5th percentiles along the remaining leading axis (a 99% band)
        upper = np.percentile(site_predictive, 99.5, axis=0)
        lower = np.percentile(site_predictive, 0.5, axis=0)
        # LOWESS-smooth each curve against age; results come back sorted by age
        ys_mean = lowess(y_vals, t, return_sorted=True, frac=0.6)
        ys_upper = lowess(upper, t, return_sorted=True, frac=0.6)
        ys_lower = lowess(lower, t, return_sorted=True, frac=0.6)
        plt.plot(ys_mean[:,0], ys_mean[:,1], color=plot_color, alpha=0.9, linewidth=2)
        # observed values for this site
        plt.scatter(t, data[i].X.T, color='gray', alpha=0.3, s=30, edgecolor='none')
        # dashed lines for the smoothed predictive band; the lower bound now
        # uses its own x values instead of borrowing the upper bound's
        plt.plot(ys_upper[:,0], ys_upper[:,1], color=plot_color, alpha=1, linestyle='--', linewidth=2)
        plt.plot(ys_lower[:,0], ys_lower[:,1], color=plot_color, alpha=1, linestyle='--', linewidth=2)
        plt.title(f'Site {site_i}')