"""
Explanatory model for EWS performance based on climate and disturbance characteristics.
This script loads the necessary data, preprocesses it, trains an XGBoost model for each combination of vegetation index (VI) and early warning signal (EWS),
and evaluates model performance and feature importance. It also computes SHAP values for interpretability.
For the full dataset, the model training takes around 30 minutes for all VI and EWS combinations.
The SHAP analysis is then run on each trained model, and the explainers and SHAP values are saved to disk.
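Outputs (written to model_out_folder): per-date feather files with model performance and feature/permutation importance, pickled model and X_train lists in a 'models' subfolder, and pickled SHAP explainers and values in a 'shap' subfolder.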
Usage: python3 ./code/02_empirical_analysis/05_explainer_model.py hammond_data_path ews_folder climate_data_path model_out_folder date_check force_recompute
hammond_data_path: Path to the Hammond CSV file (default: "./data/intermediary/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017_with_true_pos_adjusted.csv")
ews_folder: Directory containing EWS, resistance, and recovery data (default: "./data/final/MOD13Q1_ews_resistance_recovery")
climate_data_path: Path to the TERRACLIMATE data CSV file (default: "./data/raw/terra_climate_data/hammond_climate_data_terraclimate.csv")
model_out_folder: Directory to save model outputs (default: "./data/final/explainer_model_output")
date_check: Date string to append to output files (default: today's date in 'yy_mm_dd' format)
force_recompute: If set to "True" or "1", forces recomputation of all models (default: False)
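Example call (illustrative date tag, otherwise the default paths listed above):
python3 ./code/02_empirical_analysis/05_explainer_model.py \
    ./data/intermediary/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017_with_true_pos_adjusted.csv \
    ./data/final/MOD13Q1_ews_resistance_recovery \
    ./data/raw/terra_climate_data/hammond_climate_data_terraclimate.csv \
    ./data/final/explainer_model_output 25_01_01 False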
"""
import os
import sys
import pandas as pd
import numpy as np
from collections import OrderedDict
import pickle
from sklearn.model_selection import train_test_split
import xgboost as xgb
from sklearn.inspection import permutation_importance
from datetime import datetime
import shap
from tqdm.autonotebook import tqdm
#Replace print with time-stamped print
_print = print
def print(*args, **kwargs):
_print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]', *args, **kwargs)
### LOAD MORTALITY EVENT DATA ###
def load_hammond_data(hammond_data_path):
print('Loading Hammond data...')
df_hammond = pd.read_csv(hammond_data_path)
return df_hammond
def load_resistance_data(ews_folder):
#Get the path of the newest all joint resistance file
all_files = [f for f in os.listdir(ews_folder) if 'all_resistance' in f and f.endswith('.feather')]
#Sort by date (assuming the date is in the format yy_mm_dd at the start of the file name)
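#e.g. an assumed file name like '24_05_01_all_resistance.feather' sorts as [24, 5, 1]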
all_files = sorted(all_files, key=lambda x: [int(i) for i in x.split('_')[0:3]])
#Get the last one
latest_file = all_files[-1]
print(f'Loading resistance data from {latest_file}...')
df_resistance = pd.read_feather(os.path.join(ews_folder, latest_file))
#Replace any true_pos_adj with true_pos
df_resistance['true_pos_neg'] = df_resistance['true_pos_neg'].replace('true_pos_adj', 'true_pos')
return df_resistance
def load_recovery_data(ews_folder):
#Get the path of the newest all joint recovery file
all_files = [f for f in os.listdir(ews_folder) if 'all_exp' in f and f.endswith('.feather')]
#Sort by date (assuming the date is in the format yy_mm_dd at the start of the file name)
all_files = sorted(all_files, key=lambda x: [int(i) for i in x.split('_')[0:3]])
#Get the last one
latest_file = all_files[-1]
print(f'Loading recovery data from {latest_file}...')
df_recovery = pd.read_feather(os.path.join(ews_folder, latest_file))
#Replace any true_pos_adj with true_pos
df_recovery['true_pos_neg'] = df_recovery['true_pos_neg'].replace('true_pos_adj', 'true_pos')
#Flip the sign of the recovery rate so that faster recovery corresponds to larger positive values
df_recovery['r'] = -df_recovery['r']
return df_recovery
def load_ews_data(ews_folder):
#Get the path of the newest all joint ews file
all_files = [f for f in os.listdir(ews_folder) if 'all_ews' in f and f.endswith('.feather')]
#Sort by date (assuming the date is in the format yy_mm_dd at the start of the file name)
all_files = sorted(all_files, key=lambda x: [int(i) for i in x.split('_')[0:3]])
#Get the last one
latest_file = all_files[-1]
print(f'Loading EWS data from {latest_file}...')
df_ews = pd.read_feather(os.path.join(ews_folder, latest_file))
#Drop EWS rows without a Kendall tau statistic
df_ews = df_ews.dropna(subset = ['kt_stat']).reset_index(drop=True)
#Also remove the full setup because we won't use that for the models
df_ews = df_ews[df_ews['Setup'] != 'full'].reset_index(drop=True)
#Remove any potential duplicates
df_ews = df_ews.drop_duplicates(['true_pos_neg', 'VI', 'Setup', 'EWS', 'n', 'kt_stat', 'kt_pval', 'delta', 'pval_sig'], keep = 'first')
#Replace true_pos_adj with true_pos
df_ews['true_pos_neg'] = df_ews['true_pos_neg'].replace('true_pos_adj', 'true_pos')
#For each setup, count how many members of each pair (true_pos / true_neg_ratio) are present, so complete pairs can be identified downstream
ews_counts = df_ews.loc[df_ews.true_pos_neg.isin(['true_pos', 'true_neg_ratio'])].groupby(['paired_id', 'VI','rolling_window', 'Setup', 'EWS']).size().reset_index(name = 'counts')
#Add this to the main table
df_ews = df_ews.merge(ews_counts, on=['paired_id', 'VI', 'rolling_window', 'Setup', 'EWS'], how='left')
#Create kt_trend column
df_ews['kt_trend'] = 'Not significant'
df_ews.loc[((df_ews.kt_stat > 0) & (df_ews.kt_pval < 0.05)), 'kt_trend'] = 'Significant positive trend'
df_ews.loc[((df_ews.kt_stat < 0) & (df_ews.kt_pval < 0.05)), 'kt_trend'] = 'Significant negative trend'
return df_ews
### LOAD DRIVER DATA ###
#Function to load and preprocess climate data from TERRACLIMATE
def climate_data_preprocessing(climate_data_path):
#Load the data
df_climate = pd.read_csv(climate_data_path)
#Make one long version of it to actually understand the data
df_long = pd.melt(df_climate, id_vars = ['Ref_ID', 'Longitude', 'Latitude', 'Year'], var_name = 'variable', value_name = 'value')
#Now extract other things: year and month
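#(the TERRACLIMATE columns are assumed to be named like 'aet_year-1_month3': variable, year offset, month)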
df_long['Month'] = [x.split('_')[2] for x in df_long['variable'].values]
#extract the numbers
df_long['Month'] = df_long['Month'].str.replace('month', '').astype(int)
#Get year
df_long['Year'] = [x.split('_')[1] for x in df_long['variable'].values]
df_long['Year'] = df_long['Year'].str.replace('year', '').astype(int)#.abs()
#Get variable
df_long['Variable'] = [x.split('_')[0] for x in df_long['variable'].values]
#Drop the 'variable' column
df_long = df_long.drop('variable', axis = 1)
#For the long version, also compute annual means per variable (currently not used further below)
annual_means = df_long.groupby(['Ref_ID', 'Year',
'Variable', 'Longitude',
'Latitude']).value.mean().reset_index()
#Take those columns that contain a year term, drop the month value and get unique values
years = [y for y in df_climate.columns if 'year' in y]
years = ['_'.join(y.split('_')[0:2]) for y in years] #.flatten().unique()
years = list(OrderedDict((x, True) for x in years).keys())
#Replace the - with _
years_new_names = [y.replace('-', '_') for y in years]
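#e.g. 'aet_year-1' becomes 'aet_year_1'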
#Now compute the mean per each unique set of columns
#make list of new cols
new_cols = [df_climate['Ref_ID'], df_climate['Longitude'], df_climate['Latitude'], df_climate['Year']]
for y in years:
cols = [c for c in df_climate.columns if y in c]
new_cols.append(df_climate[cols].mean(axis=1))
#Drop the original ones
df_climate.drop(cols, axis=1, inplace=True)
#Combine the id columns and the per-year means into a new dataframe
df_new = pd.concat(new_cols, axis=1)
df_new.columns = ['Ref_ID', 'Longitude', 'Latitude', 'Year'] + years_new_names
#Now keep only the climate columns with year_1 (i.e. year before disturbance)
cols_keepclimate = df_new.columns
#Keep only those with _year_1
cols_keepclimate = [c for c in cols_keepclimate if '_year_1' in c and not '_year_10' in c and not '_year_11' in c]
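#e.g. this keeps 'aet_year_1' but drops 'aet_year_10' and 'aet_year_11'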
#Also keep the id and coordinate columns
cols_keepclimate += ['Ref_ID', 'Longitude', 'Latitude']
#Keep only those columns and rename lon and lat
df_climate_clean = df_new[cols_keepclimate].copy()
df_climate_clean.rename(columns = {'Longitude': 'lon', 'Latitude': 'lat'}, inplace=True)
return df_climate_clean
### WRAPPER FUNCTION TO LOAD ALL INPUT DATA ###
def load_input_data(hammond_data_path, ews_folder, climate_data_path):
#Load Hammond data
df_hammond = load_hammond_data(hammond_data_path)
#Load EWS data
df_ews = load_ews_data(ews_folder)
#Load resistance data
df_resistance = load_resistance_data(ews_folder)
#Load recovery data
df_recovery = load_recovery_data(ews_folder)
#Load climate data
df_climate = climate_data_preprocessing(climate_data_path)
#Join all of these together
print('Joining all data together...')
df_joint = df_ews.merge(df_hammond[['entry_id', 'Ref_ID', 'biome', 'year_disturbance', 'species', 'lon', 'lat']], how = 'left', on ='entry_id')
df_joint = df_joint.merge(df_climate, how = 'left', on = ['Ref_ID', 'lon', 'lat'])
df_joint = df_joint.merge(df_recovery, how = 'left', on = ['entry_id', 'true_pos_neg', 'VI'])
df_joint = df_joint.merge(df_resistance, how = 'left')
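#(no merge keys are given for the resistance merge, so pandas joins on all columns shared between the two tables)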
#Make kt_trend categorical
df_joint['kt_trend'] = df_joint['kt_trend'].astype('category')
#Make an extra column with absolute latitude
df_joint['abs_lat'] = df_joint['lat'].abs()
#Drop duplicated values
df_joint = df_joint.drop_duplicates().reset_index(drop=True)
#Keep only true_pos and true_neg_ratio
df_joint = df_joint[df_joint['true_pos_neg'].isin(['true_pos', 'true_neg_ratio'])].reset_index(drop=True)
return df_joint
### XGBOOST MODEL TOOLS ###
def model_performance(model, X_train, y_train, X_test, y_test, vi, ews):
# Get R2 and RMSE on train data
y_train_pred = model.predict(X_train)
r2_train = model.score(X_train, y_train)
rmse_train = np.sqrt(np.mean((y_train - y_train_pred) ** 2))
#Same on test data
y_test_pred = model.predict(X_test)
r2_test = model.score(X_test, y_test)
rmse_test = np.sqrt(np.mean((y_test - y_test_pred) ** 2))
#Return this as dataframe
return pd.DataFrame({
'VI': [vi],
'EWS': [ews],
'r2_train': [r2_train],
'rmse_train': [rmse_train],
'r2_test': [r2_test],
'rmse_test': [rmse_test]
})
#Function to compute feature importance and permutation importance
def importance(model, X_test, y_test, vi, ews):
# Get feature importance
feat_importance = model.feature_importances_
# Get permutation importance
perm_importance = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=42)
#Assemble feature and permutation importance into one dataframe
return pd.DataFrame({
'VI' : vi,
'EWS' : ews,
'feature': X_test.columns,
'importance': feat_importance,
'permutation_importance_mean': perm_importance.importances_mean,
'permutation_importance_std': perm_importance.importances_std
})
#Function to train the model and get performance and feature importance
def train_model(df, vi, ews, drivers, cat_drivers):
#Function to train xgboost model
#Get the data
df_model = df.loc[(df.VI == vi) & (df.EWS == ews) & (df.true_pos_neg.isin(['true_pos', 'true_neg_ratio']))].dropna().reset_index(drop = True)
#Make sure the drivers are in the DataFrame and in the correct format
for col in drivers:
if col not in cat_drivers:
df_model[col] = df_model[col].astype(float)
else:
df_model[col] = df_model[col].astype('category')
#Integer-encode the categorical columns via their category codes
df_model_encoded = df_model.copy()
for col in df_model_encoded.select_dtypes(['category']).columns:
df_model_encoded[col] = df_model_encoded[col].cat.codes
# Build the feature matrix from the drivers list and the target (the Kendall tau statistic)
X = df_model_encoded[drivers]
y = df_model_encoded['kt_stat']
#Split into train and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
#Fit the model
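#Categorical drivers were already integer-encoded via cat.codes above, so enable_categorical is left off here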
model = xgb.XGBRegressor(random_state=42, enable_categorical=False)
model.fit(X_train, y_train)
#Compute model performance
performance_df = model_performance(model, X_train, y_train, X_test, y_test, vi, ews)
#Compute feature importance
importance_df = importance(model, X_test, y_test, vi, ews)
return performance_df, importance_df, model, X_train, y_train, X_test, y_test, X, y
#Wrapper function to apply this per VI and EWS
def model_wrapper(df_joint, drivers, cat_drivers, model_out_folder, date_check, force_recompute):
#If the output folder does not exist, create it
if not os.path.exists(model_out_folder):
os.makedirs(model_out_folder)
#Make a subfolder in there that is called models
model_subfolder = os.path.join(model_out_folder, 'models')
if not os.path.exists(model_subfolder):
os.makedirs(model_subfolder)
#If we don't already have the model list in pickle, continue with computation
if os.path.exists(os.path.join(model_subfolder, f'{date_check}_model_list.pkl')) and not force_recompute:
print('Model list already computed, not running anything...')
#Load the model list and X_train_list from pickle
with open(os.path.join(model_subfolder, f'{date_check}_model_list.pkl'), 'rb') as f:
model_list = pickle.load(f)
with open(os.path.join(model_subfolder, f'{date_check}_X_train_list.pkl'), 'rb') as f:
X_train_list = pickle.load(f)
return model_list, X_train_list
else:
print('Training models for each VI and EWS combination...')
#Make one per EWS and VI
model_list = []
importance_list = []
performance_list = []
X_train_list = []
y_train_list = []
X_test_list = []
y_test_list = []
X_list = []
y_list = []
for vi in tqdm(df_joint.VI.unique(), desc='VI loop', total = len(df_joint.VI.unique())):
for ews in df_joint.EWS.unique():
print(vi, ews)
#If this VI/EWS model was already trained for this date_check, load it so it still ends up in the saved model list
if os.path.exists(os.path.join(model_subfolder, f'{date_check}_model_{vi}_{ews}.pkl')) and not force_recompute:
print(f'Model for {vi} and {ews} already exists, loading it instead of retraining...')
with open(os.path.join(model_subfolder, f'{date_check}_model_{vi}_{ews}.pkl'), 'rb') as f:
model_list.append(pickle.load(f))
with open(os.path.join(model_subfolder, f'{date_check}_X_train_{vi}_{ews}.pkl'), 'rb') as f:
X_train_list.append(pickle.load(f))
continue
else:
performance_df, importance_df, model, X_train, y_train, X_test, y_test, X, y = train_model(df_joint, vi, ews, drivers, cat_drivers)
#Append the results to the lists
performance_list.append(performance_df)
importance_list.append(importance_df)
model_list.append(model)
X_train_list.append(X_train)
y_train_list.append(y_train)
X_test_list.append(X_test)
y_test_list.append(y_test)
X_list.append(X)
y_list.append(y)
#Save the model
with open(os.path.join(model_subfolder, f'{date_check}_model_{vi}_{ews}.pkl'), 'wb') as f:
pickle.dump(model, f)
#Save X_train as well
with open(os.path.join(model_subfolder, f'{date_check}_X_train_{vi}_{ews}.pkl'), 'wb') as f:
pickle.dump(X_train, f)
#Make importance list into one df and save to feather
importance_df = pd.concat(importance_list, axis=0).reset_index(drop=True)
importance_df.to_feather(os.path.join(model_out_folder, f'{date_check}_importance.feather'))
performance_df = pd.concat(performance_list, axis=0).reset_index(drop=True)
performance_df.to_feather(os.path.join(model_out_folder, f'{date_check}_performance.feather'))
#Dump the model list and X_train_list to pickle
with open(os.path.join(model_subfolder, f'{date_check}_model_list.pkl'), 'wb') as f:
pickle.dump(model_list, f)
with open(os.path.join(model_subfolder, f'{date_check}_X_train_list.pkl'), 'wb') as f:
pickle.dump(X_train_list, f)
#Delete the per-model pickle files from disk to save space (everything is kept in the combined lists)
for f in os.listdir(model_subfolder):
if f.endswith('.pkl') and 'VI' in f:
os.remove(os.path.join(model_subfolder, f))
return model_list, X_train_list
#SHAP analysis
def shap_analysis(model_list, X_train_list, model_out_folder, date_check):
print('Running SHAP analysis...')
#Check if the SHAP subfolder exists, if not create it
shap_subfolder = os.path.join(model_out_folder, 'shap')
if not os.path.exists(shap_subfolder):
print('Creating SHAP subfolder...')
os.makedirs(shap_subfolder)
#For each model, compute the SHAP values
explainer_list = []
shap_values_list = []
for i, model in enumerate(model_list):
print(f'Computing SHAP values for model {i + 1}/{len(model_list)}...')
#Create a TreeExplainer for this model, using its training data as the background dataset
explainer = shap.TreeExplainer(model, X_train_list[i]) #, feature_perturbation='tree_path_dependent')
#Compute the SHAP values
shap_values = explainer(X_train_list[i])
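#shap_values is a shap.Explanation with one row per training sample and one column per driver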
#Append to the list
explainer_list.append(explainer)
shap_values_list.append(shap_values)
#Save these to file
with open(os.path.join(shap_subfolder, f'{date_check}_explainer_list.pkl'), 'wb') as f:
pickle.dump(explainer_list, f)
with open(os.path.join(shap_subfolder, f'{date_check}_shap_values_list.pkl'), 'wb') as f:
pickle.dump(shap_values_list, f)
#Save X_train_list to pickle
with open(os.path.join(shap_subfolder, f'{date_check}_X_train_list.pkl'), 'wb') as f:
pickle.dump(X_train_list, f)
def main(hammond_data_path, ews_folder, climate_data_path, model_out_folder, date_check, force_recompute=False):
#Check if we already have the final product, ie the shap values
shap_subfolder = os.path.join(model_out_folder, 'shap')
if os.path.exists(os.path.join(shap_subfolder, f'{date_check}_shap_values_list.pkl')) and not force_recompute:
print('SHAP values already computed, nothing to recompute...')
return
else:
print('SHAP values not found, proceeding with computation...')
#Load all input data
df_joint = load_input_data(hammond_data_path, ews_folder, climate_data_path)
print(df_joint.head())
#Define drivers and categorical drivers
drivers = ['rolling_window', 'Setup', 'aet_year_1', 'def_year_1', 'soil_year_1', 'srad_year_1','tmax_year_1',
'vpd_year_1', 'PDSI_year_1','r', 'resistance_sd_ratio', 'abs_lat'] #, 'true_pos_neg']
#List of categorical drivers
cat_drivers = ['rolling_window', 'Setup', 'true_pos_neg']
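#Note: 'true_pos_neg' is listed in cat_drivers but not in drivers, so it is not used as a model feature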
#Run the model wrapper
model_list, X_train_list = model_wrapper(df_joint, drivers, cat_drivers, model_out_folder, date_check, force_recompute)
#Run SHAP analysis
shap_analysis(model_list, X_train_list, model_out_folder, date_check)
if __name__ == "__main__":
#Get command line arguments
hammond_data_path = sys.argv[1] if len(sys.argv) > 1 else "./data/intermediary/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017_with_true_pos_adjusted.csv"
ews_folder = sys.argv[2] if len(sys.argv) > 2 else "./data/final/MOD13Q1_ews_resistance_recovery"
climate_data_path = sys.argv[3] if len(sys.argv) > 3 else "./data/raw/terra_climate_data/hammond_climate_data_terraclimate.csv"
model_out_folder = sys.argv[4] if len(sys.argv) > 4 else "./data/final/explainer_model_output"
date_check = sys.argv[5] if len(sys.argv) > 5 else datetime.today().strftime('%y_%m_%d')
force_recompute = sys.argv[6].lower() in ('true', '1') if len(sys.argv) > 6 else False
#Run the full pipeline
main(hammond_data_path, ews_folder, climate_data_path, model_out_folder, date_check, force_recompute)