###################
### Description ###
###################

# This script is for running our model on data from different mammals.

# Relevant figures:
# - Plotting N/s vs lifespan across mammals (Fig. 3b, Supp. Fig. 3c)
# - Plotting pm and pu vs lifespan across mammals (Fig. 3c)
# - Joint model comparison (Fig. 3d)
# - Examples of site fits from joint model (Supp. Figs 3d-i)

# Inputs:
# - Model outputs (traces) from 'run_mammals_joint_models.py' and 'run_mammals_separate_models.py'

##############
### Author ###
##############

# Sam Crofts (sam.crofts@ed.ac.uk)

##################################
##################################

#imports
import sys
sys.path.append("../..") # fix to import modules from root
# NOTE(review): pickle, pd, np, az and sns are used below without an explicit
# import, so they are presumably provided by this star import - confirm.
from src.general_imports import *
import matplotlib.pyplot as plt
import scipy.stats as stats
import os
from statsmodels.nonparametric.smoothers_lowess import lowess

################################################################
###Running each mammal separately (Fig. 3b-c, Supp. Fig. 2e) ###
################################################################

separate_model_dir = '../../exports/model_outputs/all_mammals/'
separate_model_prefix = "spearman_separate_model_min_sample_size_50_nsites_200_"

#load every per-species trace; the species name is whatever follows the prefix
#in the file name. Files are visited in sorted order so that the order of
#trace_list (and of every table built from it) is reproducible.
trace_list = []
for file in sorted(os.listdir(separate_model_dir)):
    if not (file.startswith("spearman_separate") and file.endswith(".pk")):
        continue
    name = file.replace('.pk', '')
    this_trace = name.split(separate_model_prefix)[1]
    #SECURITY NOTE: pickle can execute arbitrary code on load; only use with
    #model outputs produced by this project.
    with open(os.path.join(separate_model_dir, file), 'rb') as handle:
        loaded_trace = pickle.load(handle)
    loaded_trace.animal = this_trace
    #keep a module-level variable per trace (spaces replaced with underscores),
    #as before, but without building source code strings for exec(): that broke
    #on names containing quotes or other non-identifier characters.
    globals()[name.replace(' ', '_')] = loaded_trace
    trace_list.append(loaded_trace)

#check: plot posterior N/s
for trace in trace_list:
    az.plot_posterior(trace, var_names = 'N/s')
    #title: species name plus the sample size taken from the observed data
    sample_size = trace.observed_data.participants.shape[0]
    plt.title(trace.animal + ' (n=' + str(sample_size) + ')')

#bring in lifespan data
with open('../../data/pan_mammal_blood_with_site_details.pk', 'rb') as handle:
    data_list = pickle.load(handle)

#make dataframe of lifespan and animal (built in one go rather than appended
#row by row, so the columns get proper dtypes)
lifespan_df = pd.DataFrame(
    [
        {
            'animal': data.uns['organism'],
            'lifespan': data.uns['lifespan'],
            'common_name': data.uns['common_name'],
        }
        for data in data_list
    ],
    columns=['animal', 'lifespan', 'common_name'],
)

#add lifespan data to each trace in trace_list based on animal
#(first matching row wins; raises IndexError if the species is missing)
for trace in trace_list:
    species_rows = lifespan_df[lifespan_df['animal'] == trace.animal]
    trace.lifespan = species_rows['lifespan'].values[0]
    trace.common_name = species_rows['common_name'].values[0]

#make a dataframe of N/s and lifespan
#remove sheep (ovis aries) as it didn't converge due to small range of ages
non_converged_orgs = ['Ovis aries']
#species with fewer samples than this are excluded from the scaling analyses
min_sample_size = 100

n_s_rows = []
for trace in trace_list:
    #skip species that did not converge
    if trace.animal in non_converged_orgs:
        continue
    #summarise once and reuse for the mean and both HDI bounds
    n_s_summary = az.summary(trace, var_names='N/s')
    n_s_rows.append({
        'organism': trace.common_name,
        'N_on_s': n_s_summary['mean'].values[0],
        'lifespan': trace.lifespan,
        'sample_size': trace.observed_data.participants.shape[0],
        'lower': n_s_summary['hdi_3%'].values[0],
        'upper': n_s_summary['hdi_97%'].values[0],
    })
result_df = pd.DataFrame(
    n_s_rows,
    columns=['organism', 'N_on_s', 'lifespan', 'sample_size', 'lower', 'upper'],
)

#make columns for log lifespan and log N/s
result_df['log_lifespan'] = np.log10(result_df['lifespan'])
result_df['log_n_s'] = np.log10(result_df['N_on_s'])
result_df['log_lower'] = np.log10(result_df['lower'])
result_df['log_upper'] = np.log10(result_df['upper'])

###Plotting log lifespan vs log N/s with a regression line using seaborn (Fig. 3b)###
yval = 'log_n_s'

#remove those that haven't converged
#NOTE(review): 'organism' holds common names while non_converged_orgs holds
#scientific names, so this is only a safety net; the real exclusion happens in
#the loop above.
result_df = result_df[~result_df['organism'].isin(non_converged_orgs)]
#remove if sample size is less than min_sample_size (100)
result_df = result_df[result_df['sample_size'] >= min_sample_size]
#reset index
result_df = result_df.reset_index(drop=True)

#set figure size
plt.figure(figsize=(9,6))
#plot log lifespan vs log N/s with a regression line using seaborn
sns.regplot(x='log_lifespan', y=yval, data=result_df)

#add labels for points
for x_pos, y_pos, label in zip(result_df['log_lifespan'], result_df[yval], result_df['organism']):
    plt.text(x_pos, y_pos, label, fontsize=14)

#slope with an approximate 95% CI (slope +/- 1.96 standard errors)
slope, intercept, r_value, p_value, std_err = stats.linregress(result_df['log_lifespan'], result_df[yval])
#add slope with 95%CI and r2 to title
plt.title(f'slope: {slope:.2f} 95% CI: {slope-1.96*std_err:.2f} to {slope+1.96*std_err:.2f} r2: {r_value**2:.2f} p: {p_value:.2f}')

#axis titles and tick text size
plt.ylabel('Log N/s')
plt.xlabel('Log lifespan', fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

#save
plt.savefig('../../exports/figures/scaling_white_text.png', dpi=600)

########################################################################################################################
###As above, but with pm and pu instead (and separating sites into increasing and decreasing mean with age) (Fig. 3c)###
########################################################################################################################

pm_pu_rows = []
for trace in trace_list:
    species = trace.common_name
    #if ovis aries (non-converged), skip
    if trace.animal in non_converged_orgs:
        continue
    #if sample size is below the cut-off, skip
    sample_size = trace.observed_data.participants.shape[0]
    if sample_size < min_sample_size:
        continue

    #per-site posterior means, re-indexed 0..n-1 so the three series align
    omega = az.summary(trace, var_names='omega')['mean'].reset_index(drop=True)
    eta = az.summary(trace, var_names='eta')['mean'].reset_index(drop=True)
    p = az.summary(trace, var_names='p')['mean'].reset_index(drop=True)

    #now, omega = pm+pu and eta = pu/(pm+pu), so get pm and pu
    site_df = pd.DataFrame({
        'p_m': omega * (1 - eta),
        'p_u': omega * eta,
        'p': p,
    })
    #p_u/(p_u+p_m) is algebraically eta; sites are split by whether this value
    #lies above or below p
    site_df['final_val'] = site_df['p_u'] / (site_df['p_u'] + site_df['p_m'])
    site_df['change'] = site_df['final_val'] - site_df['p']
    increasing = site_df['change'] > 0
    decreasing = site_df['change'] < 0

    pm_pu_rows.append({
        'organism': trace.common_name,
        'N_on_s': result_df.loc[result_df['organism'] == species, 'N_on_s'].values[0],
        'lifespan': trace.lifespan,
        'mean_pm_inc': site_df.loc[increasing, 'p_m'].mean(),
        'mean_pm_dec': site_df.loc[decreasing, 'p_m'].mean(),
        'mean_pu_inc': site_df.loc[increasing, 'p_u'].mean(),
        'mean_pu_dec': site_df.loc[decreasing, 'p_u'].mean(),
    })

new_result_df = pd.DataFrame(
    pm_pu_rows,
    columns=['organism', 'N_on_s', 'lifespan', 'mean_pm_inc', 'mean_pm_dec', 'mean_pu_inc', 'mean_pu_dec'],
)

new_result_df['log_lifespan'] = np.log10(new_result_df['lifespan'])
new_result_df['log_n_s'] = np.log10(new_result_df['N_on_s'])
new_result_df['log_pm_inc'] = np.log10(new_result_df['mean_pm_inc'])
new_result_df['log_pm_dec'] = np.log10(new_result_df['mean_pm_dec'])
new_result_df['log_pu_inc'] = np.log10(new_result_df['mean_pu_inc'])
new_result_df['log_pu_dec'] = np.log10(new_result_df['mean_pu_dec'])

yval = 'log_pm_inc'

#start a fresh figure: without this, a plain script run draws the regression
#on top of whatever figure is currently active
plt.figure()
#plot log lifespan vs the chosen pm/pu column with a regression line using seaborn
sns.regplot(x='log_lifespan', y=yval, data=new_result_df)
#(point labels are not drawn on this figure)

#slope with an approximate 95% CI (slope +/- 1.96 standard errors)
slope, intercept, r_value, p_value, std_err = stats.linregress(new_result_df['log_lifespan'], new_result_df[yval])
#add slope with 95%CI and r2 to title
plt.title(f'slope: {slope:.2f} 95% CI: {slope-1.96*std_err:.2f} to {slope+1.96*std_err:.2f} r2: {r_value**2:.2f} p: {p_value:.2f}')

#y-axis title
plt.ylabel(yval)
#x-axis title
plt.xlabel('Log lifespan')

#save
plt.savefig('../../exports/figures/pm_inc_notext.png', dpi=600)

########################
#3 way model comparison#
########################

#open pickled traces (context managers so the file handles are closed)
#SECURITY NOTE: pickle can execute arbitrary code on load; only use with
#model outputs produced by this project.
with open('../../exports/model_outputs/all_mammals/joint_model_single_ns.pk', 'rb') as handle:
    single_ns = pickle.load(handle)
with open('../../exports/model_outputs/all_mammals/joint_model_scaled_ns.pk', 'rb') as handle:
    scaled_ns = pickle.load(handle)
with open('../../exports/model_outputs/all_mammals/joint_model_free_ns.pk', 'rb') as handle:
    free_ns = pickle.load(handle)

df_comp_loo = az.compare({"single_ns": single_ns, "scaled_ns": scaled_ns, "free": free_ns}, ic='loo')
compare_df = df_comp_loo.reset_index()

#save compare_df for later
compare_df.to_csv('../../exports/model_outputs/all_mammals/scaling_model_comparison.csv')

#bring back in, so the figure is drawn from the saved table
compare_df = pd.read_csv('../../exports/model_outputs/all_mammals/scaling_model_comparison.csv')

# Rows are plotted in the order az.compare returned them
df_ordered = compare_df

# Create figure with reduced height
fig, ax = plt.subplots(figsize=(12, 4))

y_positions = np.arange(len(df_ordered))

# Plot each model individually: point at its ELPD, bar of +/- one 'se'
for i, (_, row) in enumerate(df_ordered.iterrows()):
    err_container = ax.errorbar(
        x=row['elpd_loo'],
        y=i,
        xerr=row['se'],
        fmt='o',
        color='black',
        markersize=12,
        capsize=0,
        capthick=0,
        linewidth=5,
        elinewidth=5
    )
    # Set alpha only for error bars, not the point
    err_container[2][0].set_alpha(0.3)  # [2] is the error bar lines

# Add black dashed vertical line at the highest ELPD value
max_elpd = df_ordered['elpd_loo'].max()
ax.axvline(max_elpd, color='black', linestyle='--', linewidth=1, alpha=0.5)

# Formatting
ax.set_yticks(y_positions)
ax.set_xlabel('ELPD (LOO)', fontsize=12)
ax.set_ylabel('Model', fontsize=12)
ax.set_title('Model Comparison', fontsize=14, fontweight='bold')
ax.invert_yaxis()  # first row at the top
ax.grid(axis='x', alpha=0.3)

plt.tight_layout()

# Save the figure at 600 dpi
plt.savefig('../../exports/figures/scaling_model_comparison.png', dpi=600)
plt.show()

###################################
#Plotting site fits for each model#
###################################

#load joint adata
with open('../../data/joint_model_scaling_adata.pk', 'rb') as handle:
    adata_list = pickle.load(handle)

#pick the cattle dataset once (it does not depend on the model); the last match
#is used, and a missing dataset is an explicit error rather than a silent reuse
#of some unrelated `data` variable
cattle_adatas = [adata for adata in adata_list if adata.uns['organism'] == 'Bos taurus']
if not cattle_adatas:
    raise LookupError("No dataset with organism 'Bos taurus' found in joint_model_scaling_adata.pk")
data = cattle_adatas[-1]
t = data.var.age.values

#plot the first 20 cattle sites for each model
for trace in [single_ns, scaled_ns, free_ns]:
    # Select only the Cattle slice of the joint model
    cattle_trace = trace.sel(species='Cattle')

    for i, site_i in enumerate(cattle_trace.posterior.sites.values[0:20]):
        #new plot
        plt.figure(figsize=(6, 4))
        plot_color = 'black'

        # Posterior predictive samples for this site, averaged over the first
        # axis (presumably chains - confirm), then summarised over the next one
        site_pred = cattle_trace.posterior_predictive['m-values'].sel(sites=site_i).mean(axis=0)
        y_vals = site_pred.mean(axis=0)
        upper = np.percentile(site_pred, 99.5, axis=0)
        lower = np.percentile(site_pred, 0.5, axis=0)

        # Smooth mean and band against age; return_sorted gives (x, y) columns
        ys_mean = lowess(y_vals, t, return_sorted=True, frac=0.6)
        ys_upper = lowess(upper, t, return_sorted=True, frac=0.6)
        ys_lower = lowess(lower, t, return_sorted=True, frac=0.6)

        plt.plot(ys_mean[:,0], ys_mean[:,1], color=plot_color, alpha=0.9, linewidth=2)
        # Observed values; row i of the dataset is assumed to correspond to the
        # i-th site of the trace - TODO confirm
        plt.scatter(t, data[i].X.T, color='gray', alpha=0.3, s=30, edgecolor='none')

        # Plot the predictive band (0.5th to 99.5th percentile, i.e. 99%)
        plt.plot(ys_upper[:,0], ys_upper[:,1], color=plot_color, alpha=1, linestyle='--', linewidth=2)
        # lower band uses its own x column (previously paired with ys_upper's x)
        plt.plot(ys_lower[:,0], ys_lower[:,1], color=plot_color, alpha=1, linestyle='--', linewidth=2)

        plt.title(f'Site {site_i}')
        #save at 600 dpi
        #plt.savefig(f'../../exports/figures/site_fit_human_example_free_ns.png', dpi=600)