###################
### Description ###
###################
# This script analyses the results of our sensitivity analyses: the effects of the number of sites,
# the number of samples, and the length of the analysed timespan on our N/s estimates.
# Relevant figures:
# - Plotting N/s vs. number of sites (Supp. Fig. 2b)
# - Plotting N/s vs. number of samples (Supp. Fig. 2c)
# - Plotting N/s vs. analysed timespan (Supp. Figs. 3a-b)
# Inputs:
# - Model outputs (traces) from 'run_humans_sensitivity_n_sites.py', 'run_humans_sensitivity_sample_size.py',
# and 'run_humans_sensitivity_timespans.py'
# Outputs:
# - Plots of N/s against the varied parameter, with error bars showing 95% HPDI
##############
### Author ###
##############
# Sam Crofts (sam.crofts@ed.ac.uk)
##################################
##################################
#Imports
import os
import sys
sys.path.append("../..") # fix to import modules from root
from src.general_imports import *
# Which sensitivity analysis are we running? ("n_sites", "sample_size", "timespans")
x_var = 'timespans'
# Which set of model outputs to analyse ("white" or "spearman"). Only trace files
# whose names contain this string are loaded, and it is also used to label every
# saved figure so that runs on different datasets do not overwrite each other.
dataset = 'white'

# Paths are relative to the current working directory; presumably the script is
# run from its own directory (cf. the sys.path fix above) — confirm.
TRACE_DIR = os.path.join('../../exports/model_outputs/humans/sensitivity_analyses', x_var)
FIGURE_DIR = '../../exports/figures'

# Scaling lines (slope, intercept) drawn on the log-log timespans plot, per dataset:
#     y = slope * (x + log10(18)) + intercept
# NOTE(review): the meaning of the log10(18) offset is not documented here — confirm.
SCALING_LINES = {
    'white': (1.80, 0.757),
    'spearman': (1.46, 1.09),
}


def parse_x_value(name, x_var):
    """Return the value of the varied parameter encoded in a trace file name.

    Args:
        name: trace file name with '.pk' removed; underscore-separated tokens.
        x_var: which sensitivity analysis the file belongs to
            ("timespans", "n_sites" or "sample_size").

    Returns:
        The parameter value as a float.

    Raises:
        ValueError: if ``x_var`` is not a recognised analysis, so a typo cannot
            leave the value undefined or reuse one from a previous file.
    """
    parts = name.split('_')
    if x_var == 'timespans':
        # The sixth underscore-separated token holds the timespan value.
        token = parts[5]
    elif x_var == 'n_sites':
        # First token, with its "nsites" label stripped.
        token = parts[0].replace('nsites', '')
    elif x_var == 'sample_size':
        # Second token, with its "ss" label stripped.
        token = parts[1].replace('ss', '')
    else:
        raise ValueError(
            f"Unknown sensitivity analysis {x_var!r}; "
            "expected 'timespans', 'n_sites' or 'sample_size'"
        )
    return float(token)


def summarise_n_on_s(trace):
    """Return (mean, lower, upper) of the N/s posterior; bounds are the 95% HDI.

    The summary is computed once and the first row is read positionally with
    ``.iloc`` (plain ``Series[0]`` on a label-indexed Series relies on pandas'
    deprecated positional fallback).
    """
    summary = pm.summary(trace, hdi_prob=0.95, var_names=['N/s'])
    first = summary.iloc[0]
    return first['mean'], first['hdi_2.5%'], first['hdi_97.5%']


if __name__ == '__main__':
    rows = []
    for file in sorted(os.listdir(TRACE_DIR)):
        if dataset not in file:
            continue
        name = file.replace('.pk', '')
        # NOTE: pickle can execute arbitrary code on load — only use trusted,
        # locally generated traces. The `with` block closes the file itself.
        with open(os.path.join(TRACE_DIR, file), 'rb') as f:
            trace = pickle.load(f)
        val = parse_x_value(name, x_var)
        n_on_s, n_on_s_lower, n_on_s_upper = summarise_n_on_s(trace)
        # Diagnostic plot of each N/s posterior, titled with its x_var value.
        pm.plot_posterior(trace, var_names=['N/s'])
        plt.title(f'{val}')
        rows.append({'mean': n_on_s, 'upper': n_on_s_upper, 'lower': n_on_s_lower, 'x_var': val})

    if not rows:
        raise FileNotFoundError(f'No trace files containing {dataset!r} found in {TRACE_DIR}')

    # Build the table in one go: avoids the private DataFrame._append and yields
    # float columns instead of the object dtype of an empty seeded DataFrame.
    df = pd.DataFrame(rows, columns=['mean', 'upper', 'lower', 'x_var'])
    df = df.sort_values(by='x_var')
    df['log_x_var'] = np.log10(df['x_var'])
    df['log_mean'] = np.log10(df['mean'])

    # N/s against x_var on linear axes, error bars spanning the 95% HDI.
    plt.figure(figsize=(6, 4))
    plt.scatter(df['x_var'], df['mean'], color='royalblue')
    plt.errorbar(df['x_var'], df['mean'],
                 yerr=[df['mean'] - df['lower'], df['upper'] - df['mean']],
                 fmt='o', color='royalblue', ecolor='lightblue', elinewidth=3, capsize=0)
    plt.savefig(os.path.join(FIGURE_DIR, f'sensitivity_{x_var}_unlogged_{dataset}.png'), dpi=600)

    # The log-log plot with a scaling line only applies to the timespans analysis
    # (its x-axis is the maximum age); for other analyses it would overwrite the
    # timespans figure with a meaningless plot.
    if x_var == 'timespans':
        plt.figure(figsize=(6, 4))
        log_yerr = [df['log_mean'] - np.log10(df['lower']),
                    np.log10(df['upper']) - df['log_mean']]
        plt.errorbar(df['log_x_var'], df['log_mean'], yerr=log_yerr,
                     fmt='o', color='royalblue', ecolor='lightblue', elinewidth=3, capsize=0)
        plt.title('N/s with 95% HPDI')
        plt.ylabel('log10(N/s)')
        plt.xlabel('Log10(Maximum age)')
        slope, intercept = SCALING_LINES[dataset]
        log_age = np.linspace(0.7, 2, 100)
        plt.plot(log_age, slope * (log_age + np.log10(18)) + intercept, color='red')
        # Lay out only after all labels and artists exist so nothing is clipped.
        plt.tight_layout()
        plt.savefig(os.path.join(FIGURE_DIR, f'sensitivity_{dataset}_log_with_scaling_line.png'), dpi=600)