""" Explanatory model for EWS performance based on climate and disturbance characteristics. This script loads the necessary data, preprocesses it, trains an XGBoost model for each combination of vegetation index (VI) and early warning signal (EWS), and evaluates model performance and feature importance. It also computes SHAP values for interpretability. For the full dataset, the model training takes around 30 minutes for all VI and EWS combinations. The SHAP Usage: python3 ./code/02_empirical_analysis/05_explainer_model.py hammond_data_path ews_folder climate_data_path model_out_folder date_check force_recompute hammond_data_path: Path to the Hammond CSV file (default: "./data/intermediate/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017_with_true_pos_adj.csv") ews_folder: Directory containing EWS, resistance, and recovery data (default: "./data/final/MOD13Q1_ews_resistance_recovery") climate_data_path: Path to the TERRACLIMATE data CSV file (default: "./data/raw/terra_climate_data/hammond_climate_data_terraclimate.csv") model_out_folder: Directory to save model outputs (default: "./data/final/explainer_model_output") date_check: Date string to append to output files (default: today's date in 'yy_mm_dd' format) force_recompute: If set to "True", forces recomputation of all models (default: "False") """ import os import sys import pandas as pd import numpy as np import xarray as xr from collections import OrderedDict import statsmodels.api as sm import pickle from sklearn.model_selection import train_test_split from sklearn.ensemble import GradientBoostingRegressor import xgboost as xgb from sklearn.inspection import permutation_importance import itertools from datetime import datetime import shap from tqdm.autonotebook import tqdm #Replace print with time-stamped print _print = print def print(*args, **kwargs): _print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]', *args, **kwargs) ### LOAD MORTALITY EVENT DATA ### def load_hammond_data(hammond_data_path): print('Loading Hammond data...') df_hammond = pd.read_csv(hammond_data_path) return df_hammond def load_resistance_data(ews_folder): #Get the path of the newest all joint resistance file all_files = [f for f in os.listdir(ews_folder) if 'all_resistance' in f and f.endswith('.feather')] #Sort by date (assuming the date is in the format yy_mm_dd at the start of the file name) all_files = sorted(all_files, key=lambda x: [int(i) for i in x.split('_')[0:3]]) #Get the last one latest_file = all_files[-1] print(f'Loading resistance data from {latest_file}...') df_resistance = pd.read_feather(os.path.join(ews_folder, latest_file)) #Replace any true_pos_adj with true_pos df_resistance['true_pos_neg'] = df_resistance['true_pos_neg'].replace('true_pos_adj', 'true_pos') return df_resistance def load_recovery_data(ews_folder): #Get the path of the newest all joint recovery file all_files = [f for f in os.listdir(ews_folder) if 'all_exp' in f and f.endswith('.feather')] #Sort by date (assuming the date is in the format yy_mm_dd at the start of the file name) all_files = sorted(all_files, key=lambda x: [int(i) for i in x.split('_')[0:3]]) #Get the last one latest_file = all_files[-1] print(f'Loading recovery data from {latest_file}...') df_recovery = pd.read_feather(os.path.join(ews_folder, latest_file)) #Replace any true_pos_adj with true_pos df_recovery['true_pos_neg'] = df_recovery['true_pos_neg'].replace('true_pos_adj', 'true_pos') #Invert recovery rate to be larger and positive for faster recovery df_recovery['r'] = -df_recovery['r'] return df_recovery def load_ews_data(ews_folder): #print('Loading EWS data...') #Get the path of the newest all joint ews file all_files = [f for f in os.listdir(ews_folder) if 'all_ews' in f and f.endswith('.feather')] #Sort by date (assuming the date is in the format yy_mm_dd at the start of the file name) all_files = sorted(all_files, key=lambda x: [int(i) for i in x.split('_')[0:3]]) #Get the last one latest_file = all_files[-1] print(f'Loading EWS data from {latest_file}...') df_ews = pd.read_feather(os.path.join(ews_folder, latest_file)) #Drop irrelevant EWS df_ews = df_ews.dropna(subset = ['kt_stat']).reset_index(drop=True) #Also remove the full setup because we won't use that for the models df_ews = df_ews[df_ews['Setup'] != 'full'].reset_index(drop=True) #Remove any potential duplicates df_ews = df_ews.drop_duplicates(['true_pos_neg', 'VI', 'Setup', 'EWS', 'n', 'kt_stat', 'kt_pval', 'delta', 'pval_sig'], keep = 'first') #Replace true_pos_adj with true_pos df_ews['true_pos_neg'] = df_ews['true_pos_neg'].replace('true_pos_adj', 'true_pos') #For each setup, keep only complete pairs ews_counts = df_ews.loc[df_ews.true_pos_neg.isin(['true_pos', 'true_neg_ratio'])].groupby(['paired_id', 'VI','rolling_window', 'Setup', 'EWS']).size().reset_index(name = 'counts') #Add this to the main table df_ews = df_ews.merge(ews_counts, on=['paired_id', 'VI', 'rolling_window', 'Setup', 'EWS'], how='left') #Create kt_trend column df_ews['kt_trend'] = 'Not significant' df_ews.loc[((df_ews.kt_stat > 0) & (df_ews.kt_pval < 0.05)), 'kt_trend'] = 'Significant positive trend' df_ews.loc[((df_ews.kt_stat < 0) & (df_ews.kt_pval < 0.05)), 'kt_trend'] = 'Significant negative trend' return df_ews ### LOAD DRIVER DATA ### #Function to load and preprocess climate data from TERRACLIMATE def climate_data_preprocessing(climate_data_path): #Load the data df_climate = pd.read_csv(climate_data_path) #Make one long version of it to actually understand the data df_long = pd.melt(df_climate, id_vars = ['Ref_ID', 'Longitude', 'Latitude', 'Year'], var_name = 'variable', value_name = 'value') #Now extract other things: year and month df_long['Month'] = [x.split('_')[2] for x in df_long['variable'].values] #extract the numbers df_long['Month'] = df_long['Month'].str.replace('month', '').astype(int) #Get year df_long['Year'] = [x.split('_')[1] for x in df_long['variable'].values] df_long['Year'] = df_long['Year'].str.replace('year', '').astype(int)#.abs() #Get variable df_long['Variable'] = [x.split('_')[0] for x in df_long['variable'].values] #Drop the 'variable' column df_long = df_long.drop('variable', axis = 1) #For the long version, also compute annual means annual_means = df_long.groupby(['Ref_ID', 'Year', 'Variable', 'Longitude', 'Latitude']).value.mean().reset_index() #Take those columns that contain a year term, drop the month value and get unique values years = [y for y in df_climate.columns if 'year' in y] years = ['_'.join(y.split('_')[0:2]) for y in years] #.flatten().unique() years = list(OrderedDict((x, True) for x in years).keys()) #Replace the - with _ years_new_names = [y.replace('-', '_') for y in years] #Now compute the mean per each unique set of columns #make list of new cols new_cols = [df_climate['Ref_ID'], df_climate['Longitude'], df_climate['Latitude'], df_climate['Year']] for y in years: cols = [c for c in df_climate.columns if y in c] new_cols.append(df_climate[cols].mean(axis=1)) #Drop the original ones df_climate.drop(cols, axis=1, inplace=True) #Also add the annual means df_new = pd.concat(new_cols, axis=1) df_new.columns = ['Ref_ID', 'Longitude', 'Latitude', 'Year'] + years_new_names #Now keep only the climate columns with year_1 (i.e. year before disturbance) cols_keepclimate = df_new.columns #Keep only those with _year_1 cols_keepclimate = [c for c in cols_keepclimate if '_year_1' in c and not '_year_10' in c and not '_year_11' in c] [cols_keepclimate.append(i) for i in ['Ref_ID', 'Longitude', 'Latitude']] #KEep only those columns and rename lon and lat df_climate_clean = df_new[cols_keepclimate] df_climate_clean.rename(columns = {'Longitude': 'lon', 'Latitude': 'lat'}, inplace=True) return df_climate_clean ### WRAPPER FUNCTION TO LOAD ALL INPUT DATA ### def load_input_data(hammond_data_path, ews_folder, climate_data_path): #Load Hammond data df_hammond = load_hammond_data(hammond_data_path) #Load EWS data df_ews = load_ews_data(ews_folder) #Load resistance data df_resistance = load_resistance_data(ews_folder) #Load recovery data df_recovery = load_recovery_data(ews_folder) #Load climate data df_climate = climate_data_preprocessing(climate_data_path) #Join all of these together print('Joining all data together...') df_joint = df_ews.merge(df_hammond[['entry_id', 'Ref_ID', 'biome', 'year_disturbance', 'species', 'lon', 'lat']], how = 'left', on ='entry_id') df_joint = df_joint.merge(df_climate, how = 'left', on = ['Ref_ID', 'lon', 'lat']) df_joint = df_joint.merge(df_recovery, how = 'left', on = ['entry_id', 'true_pos_neg', 'VI']) df_joint = df_joint.merge(df_resistance, how = 'left') #Make kt_trend categorical df_joint['kt_trend'] = df_joint['kt_trend'].astype('category') #Make an extra column with absolute latitude df_joint['abs_lat'] = df_joint['lat'].abs() #Drop duplicated values df_joint = df_joint.drop_duplicates().reset_index(drop=True) #Keep only true_pos and true_neg_ratio df_joint = df_joint[df_joint['true_pos_neg'].isin(['true_pos', 'true_neg_ratio'])].reset_index(drop=True) return df_joint ### XGBOOST MODEL TOOLS ### def model_performance(model, X_train, y_train, X_test, y_test, vi, ews): # Get R2 and RMSE on train data y_train_pred = model.predict(X_train) r2_train = model.score(X_train, y_train) rmse_train = np.sqrt(np.mean((y_train - y_train_pred) ** 2)) #Same on test data y_test_pred = model.predict(X_test) r2_test = model.score(X_test, y_test) rmse_test = np.sqrt(np.mean((y_test - y_test_pred) ** 2)) #Return this as dataframe return pd.DataFrame({ 'VI': [vi], 'EWS': [ews], 'r2_train': [r2_train], 'rmse_train': [rmse_train], 'r2_test': [r2_test], 'rmse_test': [rmse_test] }) #Function to compute feature importance and permutation importance def importance(model, X_test, y_test, vi, ews): # Get feature importance importance = model.feature_importances_ #print('importance done') # Get permutation importance perm_importance = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=42) #print('permutation importance done') #Check lengths of different things return pd.DataFrame({ 'VI' : vi, 'EWS' : ews, 'feature': X_test.columns, 'importance': importance, 'permutation_importance_mean': perm_importance.importances_mean, 'permutation_importance_std': perm_importance.importances_std }) #Function to train the model and get performance and feature importance def train_model(df, vi, ews, drivers, cat_drivers): #Function to train xgboost model #Get the data df_model = df.loc[(df.VI == vi) & (df.EWS == ews) & (df.true_pos_neg.isin(['true_pos', 'true_neg_ratio']))].dropna().reset_index(drop = True) #Make sure the drivers are in the DataFrame and in the correct format for col in drivers: if col not in cat_drivers: df_model[f'{col}'] = df_model[col].astype(float) else: df_model[f'{col}'] = df_model[col].astype('category') #Turn everything into categorical counts df_model_encoded = df_model.copy() for col in df_model_encoded.select_dtypes(['category']).columns: df_model_encoded[col] = df_model_encoded[col].cat.codes # Build formula from drivers list and cat_drivers X = df_model_encoded[[f'{col}' for col in drivers]] y = df_model_encoded['kt_stat'] #print(X.shape, y.shape) #Split into train and test set X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) #Fit the model model = xgb.XGBRegressor(random_state=42, enable_categorical=False) model.fit(X_train, y_train) #Compute model performance performance_df = model_performance(model, X_train, y_train, X_test, y_test, vi, ews) #Compute feature importance importance_df = importance(model, X_test, y_test, vi, ews) return performance_df, importance_df, model, X_train, y_train, X_test, y_test, X, y #Wrapper function to apply this per VI and EWS def model_wrapper(df_joint, drivers, cat_drivers, model_out_folder, date_check, force_recompute): #If the output folder does not exist, create it if not os.path.exists(model_out_folder): os.makedirs(model_out_folder) #Make a subfolder in there that is called models model_subfolder = os.path.join(model_out_folder, 'models') if not os.path.exists(model_subfolder): os.makedirs(model_subfolder) #If we don't already have the model list in pickle, continue with computation if os.path.exists(os.path.join(model_subfolder, f'{date_check}_model_list.pkl')) and not force_recompute: print('Model list alrady computed, not running anything...') #Load the model list and X_train_list from pickle with open(os.path.join(model_subfolder, f'{date_check}_model_list.pkl'), 'rb') as f: model_list = pickle.load(f) with open(os.path.join(model_subfolder, f'{date_check}_X_train_list.pkl'), 'rb') as f: X_train_list = pickle.load(f) return model_list, X_train_list else: print('Training models for each VI and EWS combination...') #Make one per EWS and VI model_list = [] importance_list = [] performance_list = [] X_train_list = [] y_train_list = [] X_test_list = [] y_test_list = [] X_list = [] y_list = [] for vi in tqdm(df_joint.VI.unique(), desc='VI loop', total = len(df_joint.VI.unique())): for ews in df_joint.EWS.unique(): print(vi, ews) #Check if the model already exists, if so skip if os.path.exists(os.path.join(model_subfolder, f'{date_check}_model_{vi}_{ews}.pkl')) and not force_recompute: print(f'Model for {vi} and {ews} already exists, skipping...') continue else: performance_df, importance_df, model, X_train, y_train, X_test, y_test, X, y = train_model(df_joint, vi, ews, drivers, cat_drivers) #Append the results to the lists performance_list.append(performance_df) importance_list.append(importance_df) model_list.append(model) X_train_list.append(X_train) y_train_list.append(y_train) X_test_list.append(X_test) y_test_list.append(y_test) X_list.append(X) y_list.append(y) #Save the model with open(os.path.join(model_subfolder, f'{date_check}_model_{vi}_{ews}.pkl'), 'wb') as f: pickle.dump(model, f) #Save X_train as well with open(os.path.join(model_subfolder, f'{date_check}_X_train_{vi}_{ews}.pkl'), 'wb') as f: pickle.dump(X_train, f) #Make importance list into one df and save to feather importance_df = pd.concat(importance_list, axis=0).reset_index(drop=True) importance_df.to_feather(os.path.join(model_out_folder, f'{date_check}_importance.feather')) performance_df = pd.concat(performance_list, axis=0).reset_index(drop=True) performance_df.to_feather(os.path.join(model_out_folder, f'{date_check}_performance.feather')) #Dump the model list and X_train_list to pickle with open(os.path.join(model_subfolder, f'{date_check}_model_list.pkl'), 'wb') as f: pickle.dump(model_list, f) with open(os.path.join(model_subfolder, f'{date_check}_X_train_list.pkl'), 'wb') as f: pickle.dump(X_train_list, f) #Delete the single models from memory to save space for f in os.listdir(model_subfolder): if f.endswith('.pkl') and 'VI' in f: os.remove(os.path.join(model_subfolder, f)) return model_list, X_train_list #SHAP analysis def shap_analysis(model_list, X_train_list, model_out_folder, date_check): print('Running SHAP analysis...') #Check if the SHAP subfolder exists, if not create it shap_subfolder = os.path.join(model_out_folder, 'shap') if not os.path.exists(shap_subfolder): print('Creating SHAP subfolder...') os.makedirs(shap_subfolder) #For each model, compute the SHAP values explainer_list = [] shap_values_list = [] for i, model in enumerate(model_list): print(i) #Use the first model to create the explainer explainer = shap.TreeExplainer(model, X_train_list[i]) #, feature_perturbation='tree_path_dependent') #Compute the SHAP values shap_values = explainer(X_train_list[i]) #Append to the list explainer_list.append(explainer) shap_values_list.append(shap_values) #Save these to file with open(os.path.join(shap_subfolder, f'{date_check}_explainer_list.pkl'), 'wb') as f: pickle.dump(explainer_list, f) with open(os.path.join(shap_subfolder, f'{date_check}_shap_values_list.pkl'), 'wb') as f: pickle.dump(shap_values_list, f) #Save X_train_list to pickle with open(os.path.join(shap_subfolder, f'{date_check}_X_train_list.pkl'), 'wb') as f: pickle.dump(X_train_list, f) def main(hammond_data_path, ews_folder, climate_data_path, model_out_folder, date_check, force_recompute=False): #Check if we already have the final product, ie the shap values shap_subfolder = os.path.join(model_out_folder, 'shap') if os.path.exists(os.path.join(shap_subfolder, f'{date_check}_shap_values_list.pkl')) and not force_recompute: print('SHAP values already computed, nothing to recompute...') return else: print('SHAP values not found, proceeding with computation...') #Load all input data df_joint = load_input_data(hammond_data_path, ews_folder, climate_data_path) print(df_joint.head()) #Define drivers and categorical drivers drivers = ['rolling_window', 'Setup', 'aet_year_1', 'def_year_1', 'soil_year_1', 'srad_year_1','tmax_year_1', 'vpd_year_1', 'PDSI_year_1','r', 'resistance_sd_ratio', 'abs_lat'] #, 'true_pos_neg'] #List of categorical drivers cat_drivers = ['rolling_window', 'Setup', 'true_pos_neg'] #Run the model wrapper model_list, X_train_list = model_wrapper(df_joint, drivers, cat_drivers, model_out_folder, date_check, force_recompute) #Run SHAP analysis shap_analysis(model_list, X_train_list, model_out_folder, date_check) if __name__ == "__main__": #Get command line arguments hammond_data_path = sys.argv[1] if len(sys.argv) > 1 else "./data/intermediary/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017_with_true_pos_adjusted.csv" ews_folder = sys.argv[2] if len(sys.argv) > 2 else "./data/final/MOD13Q1_ews_resistance_recovery" climate_data_path = sys.argv[3] if len(sys.argv) > 3 else "./data/raw/terra_climate_data/hammond_climate_data_terraclimate.csv" model_out_folder = sys.argv[4] if len(sys.argv) > 4 else "./data/final/explainer_model_output" date_check = sys.argv[5] if len(sys.argv) > 5 else datetime.today().strftime('%y_%m_%d') force_recompute = bool(int(sys.argv[6])) if len(sys.argv) > 6 else False #Load all input data main(hammond_data_path, ews_folder, climate_data_path, model_out_folder, date_check, force_recompute)