# 25-09-ews-assessment / code / 02_empirical_analysis / 05_explainer_model.py
"""
Explanatory model for EWS performance based on climate and disturbance characteristics.

This script loads the necessary data, preprocesses it, trains an XGBoost model for each combination of vegetation index (VI) and early warning signal (EWS),
and evaluates model performance and feature importance. It also computes SHAP values for interpretability.

For the full dataset, the model training takes around 30 minutes for all VI and EWS combinations.
The SHAP analysis runs after the training step and adds to the total runtime.

Usage: python3 ./code/02_empirical_analysis/05_explainer_model.py hammond_data_path ews_folder climate_data_path model_out_folder date_check force_recompute

hammond_data_path: Path to the Hammond CSV file (default: "./data/intermediary/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017_with_true_pos_adjusted.csv")
ews_folder: Directory containing EWS, resistance, and recovery data (default: "./data/final/MOD13Q1_ews_resistance_recovery")
climate_data_path: Path to the TERRACLIMATE data CSV file (default: "./data/raw/terra_climate_data/hammond_climate_data_terraclimate.csv")
model_out_folder: Directory to save model outputs (default: "./data/final/explainer_model_output")
date_check: Date string to append to output files (default: today's date in 'yy_mm_dd' format)
force_recompute: Integer flag; pass "1" to force recomputation of all models and "0" to reuse cached outputs (default: "0")

"""

import os
import sys
import pandas as pd
import numpy as np
import xarray as xr
from collections import OrderedDict
import statsmodels.api as sm
import pickle
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
import xgboost as xgb
from sklearn.inspection import permutation_importance
import itertools
from datetime import datetime
import shap
from tqdm.autonotebook import tqdm

#Keep a handle on the built-in print so the wrapper below can delegate to it
_print = print


def print(*args, **kwargs):
    """
    Drop-in replacement for the built-in print that prefixes every call with a timestamp.

    The prefix has the form `[YYYY-MM-DD HH:MM:SS]` and is passed as the first positional
    argument, so `sep`, `end`, `file` and `flush` keep working as for the built-in.
    """
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    _print(f'[{stamp}]', *args, **kwargs)


### LOAD MORTALITY EVENT DATA ###
def load_hammond_data(hammond_data_path):
    """
    Read the Hammond mortality-event CSV.

    hammond_data_path: Path to the CSV file.
    Returns the file content as a pandas DataFrame, without any further processing.
    """
    print('Loading Hammond data...')
    return pd.read_csv(hammond_data_path)

def load_resistance_data(ews_folder):
    """
    Load the newest joint resistance table found in `ews_folder`.

    Candidate files contain 'all_resistance' in their name and end in '.feather'. The newest one
    is picked by sorting on the first three '_'-separated parts of the file name, read as
    integers (the file names are expected to start with a yy_mm_dd date).

    ews_folder: Directory to search.
    Returns the table as a DataFrame in which 'true_pos_adj' entries of the 'true_pos_neg'
    column have been relabelled as 'true_pos'.
    Raises FileNotFoundError if the folder holds no matching file.
    """
    #Get the path of the newest all joint resistance file
    all_files = [f for f in os.listdir(ews_folder) if 'all_resistance' in f and f.endswith('.feather')]
    #Fail with a clear message instead of an IndexError further down
    if not all_files:
        raise FileNotFoundError(f"No '*all_resistance*.feather' file found in {ews_folder}")
    #Sort by date (assuming the date is in the format yy_mm_dd at the start of the file name)
    all_files = sorted(all_files, key=lambda x: [int(i) for i in x.split('_')[0:3]])
    #Get the last one
    latest_file = all_files[-1]
    print(f'Loading resistance data from {latest_file}...')
    df_resistance = pd.read_feather(os.path.join(ews_folder, latest_file))
    #Replace any true_pos_adj with true_pos
    df_resistance['true_pos_neg'] = df_resistance['true_pos_neg'].replace('true_pos_adj', 'true_pos')
    return df_resistance

def load_recovery_data(ews_folder):
    """
    Load the newest joint recovery table found in `ews_folder`.

    Candidate files contain 'all_exp' in their name and end in '.feather'. The newest one is
    picked by sorting on the first three '_'-separated parts of the file name, read as integers
    (the file names are expected to start with a yy_mm_dd date).

    ews_folder: Directory to search.
    Returns the table as a DataFrame in which 'true_pos_adj' entries of the 'true_pos_neg'
    column have been relabelled as 'true_pos' and the sign of column 'r' has been flipped.
    Raises FileNotFoundError if the folder holds no matching file.
    """
    #Get the path of the newest all joint recovery file
    all_files = [f for f in os.listdir(ews_folder) if 'all_exp' in f and f.endswith('.feather')]
    #Fail with a clear message instead of an IndexError further down
    if not all_files:
        raise FileNotFoundError(f"No '*all_exp*.feather' file found in {ews_folder}")
    #Sort by date (assuming the date is in the format yy_mm_dd at the start of the file name)
    all_files = sorted(all_files, key=lambda x: [int(i) for i in x.split('_')[0:3]])
    #Get the last one
    latest_file = all_files[-1]
    print(f'Loading recovery data from {latest_file}...')
    df_recovery = pd.read_feather(os.path.join(ews_folder, latest_file))
    #Replace any true_pos_adj with true_pos
    df_recovery['true_pos_neg'] = df_recovery['true_pos_neg'].replace('true_pos_adj', 'true_pos')
    #Invert recovery rate to be larger and positive for faster recovery
    df_recovery['r'] = -df_recovery['r']
    return df_recovery

def load_ews_data(ews_folder):
    """
    Load the newest joint EWS table found in `ews_folder` and prepare it for modelling.

    Candidate files contain 'all_ews' in their name and end in '.feather'. The newest one is
    picked by sorting on the first three '_'-separated parts of the file name, read as integers
    (the file names are expected to start with a yy_mm_dd date).

    Processing steps:
      - drop rows without a 'kt_stat' value and rows whose 'Setup' is 'full'
      - drop duplicated rows (judged on the label, VI, Setup, EWS and statistic columns)
      - relabel 'true_pos_adj' as 'true_pos'
      - add 'counts': number of 'true_pos'/'true_neg_ratio' rows per
        (paired_id, VI, rolling_window, Setup, EWS) group; NaN for groups without such rows
      - add 'kt_trend': sign of 'kt_stat' where 'kt_pval' < 0.05, else 'Not significant'

    ews_folder: Directory to search.
    Returns the processed DataFrame.
    Raises FileNotFoundError if the folder holds no matching file.
    """
    #Get the path of the newest all joint ews file
    all_files = [f for f in os.listdir(ews_folder) if 'all_ews' in f and f.endswith('.feather')]
    #Fail with a clear message instead of an IndexError further down
    if not all_files:
        raise FileNotFoundError(f"No '*all_ews*.feather' file found in {ews_folder}")
    #Sort by date (assuming the date is in the format yy_mm_dd at the start of the file name)
    all_files = sorted(all_files, key=lambda x: [int(i) for i in x.split('_')[0:3]])
    #Get the last one
    latest_file = all_files[-1]
    print(f'Loading EWS data from {latest_file}...')
    df_ews = pd.read_feather(os.path.join(ews_folder, latest_file))

    #Drop irrelevant EWS
    df_ews = df_ews.dropna(subset = ['kt_stat']).reset_index(drop=True)
    #Also remove the full setup because we won't use that for the models
    df_ews = df_ews[df_ews['Setup'] != 'full'].reset_index(drop=True)
    #Remove any potential duplicates
    df_ews = df_ews.drop_duplicates(['true_pos_neg', 'VI', 'Setup', 'EWS', 'n', 'kt_stat', 'kt_pval', 'delta', 'pval_sig'], keep = 'first')

    #Replace true_pos_adj with true_pos
    df_ews['true_pos_neg'] = df_ews['true_pos_neg'].replace('true_pos_adj', 'true_pos')

    #Count the true_pos / true_neg_ratio rows per pair and setup
    #NOTE(review): the counts are only attached as a column; no rows are filtered on them here,
    #so keeping "only complete pairs" has to happen downstream - confirm this is intended
    ews_counts = df_ews.loc[df_ews.true_pos_neg.isin(['true_pos', 'true_neg_ratio'])].groupby(['paired_id', 'VI','rolling_window', 'Setup', 'EWS']).size().reset_index(name = 'counts')
    #Add this to the main table
    df_ews = df_ews.merge(ews_counts, on=['paired_id', 'VI', 'rolling_window', 'Setup', 'EWS'], how='left')

    #Create kt_trend column
    df_ews['kt_trend'] = 'Not significant'
    df_ews.loc[((df_ews.kt_stat > 0) & (df_ews.kt_pval < 0.05)), 'kt_trend'] = 'Significant positive trend'
    df_ews.loc[((df_ews.kt_stat < 0) & (df_ews.kt_pval < 0.05)), 'kt_trend'] = 'Significant negative trend'

    return df_ews


### LOAD DRIVER DATA ###

#Function to load and preprocess climate data from TERRACLIMATE
def climate_data_preprocessing(climate_data_path):
    """
    Load the TERRACLIMATE CSV and reduce it to per-variable means for the `year_1` term.

    Every column whose name contains 'year' is treated as one monthly value and is grouped by
    the first two '_'-separated parts of its name (e.g. 'aet_year-1_month3' -> 'aet_year-1').
    Each group is averaged row-wise and named after the group with '-' replaced by '_'
    (e.g. 'aet_year_1').

    Grouping uses the exact prefix. A substring test would let 'aet_year-1' also capture the
    'aet_year-10' and 'aet_year-11' columns and mix them into the year-1 mean.

    climate_data_path: Path to the CSV file; it must contain the columns 'Ref_ID', 'Longitude',
        'Latitude' and 'Year'.
    Returns a DataFrame holding the '*_year_1' mean columns followed by 'Ref_ID', 'lon' and
    'lat' ('Longitude' and 'Latitude' renamed).
    """
    #Load the data
    df_climate = pd.read_csv(climate_data_path)

    id_cols = ['Ref_ID', 'Longitude', 'Latitude', 'Year']

    #Collect the monthly columns of each '<variable>_<year term>' prefix, in column order
    monthly_groups = {}
    for col in df_climate.columns:
        if 'year' in col:
            monthly_groups.setdefault('_'.join(col.split('_')[0:2]), []).append(col)

    #One row-wise mean per group, placed after the identifier columns
    pieces = [df_climate[c] for c in id_cols]
    pieces += [df_climate[cols].mean(axis=1) for cols in monthly_groups.values()]
    df_new = pd.concat(pieces, axis=1)
    #Replace the - with _ in the new column names
    df_new.columns = id_cols + [prefix.replace('-', '_') for prefix in monthly_groups]

    #Now keep only the climate columns with year_1 (i.e. year before disturbance)
    cols_keepclimate = [c for c in df_new.columns if c.endswith('_year_1')]
    cols_keepclimate += ['Ref_ID', 'Longitude', 'Latitude']

    #Keep only those columns and rename lon and lat (rename returns a new frame, so no
    #in-place modification of a column slice is needed)
    return df_new[cols_keepclimate].rename(columns = {'Longitude': 'lon', 'Latitude': 'lat'})


### WRAPPER FUNCTION TO LOAD ALL INPUT DATA ###
def load_input_data(hammond_data_path, ews_folder, climate_data_path):
    """
    Load every input table and join them into one modelling table.

    The EWS table is the left side of all joins: Hammond metadata is attached via 'entry_id',
    climate means via ('Ref_ID', 'lon', 'lat'), recovery via ('entry_id', 'true_pos_neg', 'VI')
    and resistance via whatever columns it shares with the joined table (no explicit key).

    hammond_data_path: Path to the Hammond CSV file.
    ews_folder: Directory containing the EWS, resistance and recovery feather files.
    climate_data_path: Path to the TERRACLIMATE CSV file.
    Returns the joined DataFrame restricted to 'true_pos' and 'true_neg_ratio' rows, with a
    categorical 'kt_trend' column and an added 'abs_lat' column.
    """
    #Read the individual tables (same order as the log messages they emit)
    hammond = load_hammond_data(hammond_data_path)
    ews = load_ews_data(ews_folder)
    resistance = load_resistance_data(ews_folder)
    recovery = load_recovery_data(ews_folder)
    climate = climate_data_preprocessing(climate_data_path)

    print('Joining all data together...')
    hammond_cols = ['entry_id', 'Ref_ID', 'biome', 'year_disturbance', 'species', 'lon', 'lat']
    joint = (
        ews
        .merge(hammond[hammond_cols], how = 'left', on = 'entry_id')
        .merge(climate, how = 'left', on = ['Ref_ID', 'lon', 'lat'])
        .merge(recovery, how = 'left', on = ['entry_id', 'true_pos_neg', 'VI'])
        #No key given: pandas joins on all columns the two tables have in common
        .merge(resistance, how = 'left')
    )

    #Categorical trend label and absolute latitude as extra predictor
    joint['kt_trend'] = joint['kt_trend'].astype('category')
    joint['abs_lat'] = joint['lat'].abs()

    #Remove exact duplicate rows, then keep only true_pos and true_neg_ratio
    joint = joint.drop_duplicates().reset_index(drop=True)
    wanted = joint['true_pos_neg'].isin(['true_pos', 'true_neg_ratio'])
    return joint[wanted].reset_index(drop=True)


### XGBOOST MODEL TOOLS ###
def model_performance(model, X_train, y_train, X_test, y_test, vi, ews):
    """
    Summarise the fit of `model` on the train and the test split.

    model: Fitted estimator offering `predict(X)` and `score(X, y)`.
    X_train, y_train: Training features and target.
    X_test, y_test: Held-out features and target.
    vi, ews: Labels copied into the 'VI' and 'EWS' columns of the result.
    Returns a one-row DataFrame with the columns 'VI', 'EWS', 'r2_train', 'rmse_train',
    'r2_test' and 'rmse_test'; the r2 values are whatever `model.score` returns.
    """
    def _fit_stats(features, target):
        #Score and root-mean-square error of the predictions on one split
        predicted = model.predict(features)
        score = model.score(features, target)
        rmse = np.sqrt(np.mean((target - predicted) ** 2))
        return score, rmse

    r2_train, rmse_train = _fit_stats(X_train, y_train)
    r2_test, rmse_test = _fit_stats(X_test, y_test)

    summary = {
        'VI': [vi],
        'EWS': [ews],
        'r2_train': [r2_train],
        'rmse_train': [rmse_train],
        'r2_test': [r2_test],
        'rmse_test': [rmse_test]
    }
    return pd.DataFrame(summary)

#Function to compute feature importance and permutation importance
def importance(model, X_test, y_test, vi, ews):
    """
    Collect the model's own feature importances and permutation importances in one table.

    model: Fitted estimator exposing `feature_importances_`.
    X_test, y_test: Held-out features (a DataFrame, its columns name the features) and target
        on which the permutation importance is evaluated.
    vi, ews: Labels copied into the 'VI' and 'EWS' columns of the result.
    Returns a DataFrame with one row per feature and the columns 'VI', 'EWS', 'feature',
    'importance', 'permutation_importance_mean' and 'permutation_importance_std'.
    """
    #Importance scores stored on the fitted model
    builtin_scores = model.feature_importances_

    #Permutation importance on the test split: 10 shuffles per feature, fixed seed
    permuted = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=42)

    table = {
        'VI' : vi,
        'EWS' : ews,
        'feature': X_test.columns,
        'importance': builtin_scores,
        'permutation_importance_mean': permuted.importances_mean,
        'permutation_importance_std': permuted.importances_std
    }
    return pd.DataFrame(table)

#Function to train the model and get performance and feature importance
def train_model(df, vi, ews, drivers, cat_drivers):
    """
    Fit an XGBoost regressor of 'kt_stat' on `drivers` for one VI / EWS combination.

    df: Joined input table.
    vi, ews: Values of the 'VI' and 'EWS' columns selecting the rows to model.
    drivers: Names of the predictor columns.
    cat_drivers: Names of the drivers to treat as categorical; these are fed to the model as
        integer category codes, all other drivers as floats.
    Returns (performance_df, importance_df, model, X_train, y_train, X_test, y_test, X, y).
    """
    #Rows of this combination, restricted to true_pos / true_neg_ratio
    #NOTE(review): dropna() removes rows with a missing value in ANY column, not only in the
    #drivers and the target - confirm that this is intended
    selected = (df.VI == vi) & (df.EWS == ews) & (df.true_pos_neg.isin(['true_pos', 'true_neg_ratio']))
    df_model = df.loc[selected].dropna().reset_index(drop = True)

    #Cast every driver to float or to category
    for col in drivers:
        target_dtype = 'category' if col in cat_drivers else float
        df_model[col] = df_model[col].astype(target_dtype)

    #Replace all categorical columns by their integer codes
    df_model_encoded = df_model.copy()
    categorical_cols = df_model_encoded.select_dtypes(['category']).columns
    for col in categorical_cols:
        df_model_encoded[col] = df_model_encoded[col].cat.codes

    #Predictor matrix and target
    X = df_model_encoded[list(drivers)]
    y = df_model_encoded['kt_stat']

    #80/20 train-test split with a fixed seed
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    #Fit the model
    model = xgb.XGBRegressor(random_state=42, enable_categorical=False)
    model.fit(X_train, y_train)

    #Fit quality and feature importance for this combination
    performance_df = model_performance(model, X_train, y_train, X_test, y_test, vi, ews)
    importance_df = importance(model, X_test, y_test, vi, ews)

    return performance_df, importance_df, model, X_train, y_train, X_test, y_test, X, y

#Wrapper function to apply this per VI and EWS
def model_wrapper(df_joint, drivers, cat_drivers, model_out_folder, date_check, force_recompute):
    """
    Train (or reload) one model per VI x EWS combination and write all results to disk.

    Outputs, all prefixed with `date_check`:
      - <model_out_folder>/<date_check>_importance.feather and ..._performance.feather
      - <model_out_folder>/models/<date_check>_model_list.pkl and ..._X_train_list.pkl

    While the loop runs, the model, training matrix, performance table and importance table of
    every finished combination are pickled separately. If a run is interrupted, the next run
    reloads complete combinations instead of retraining them, so the returned lists and the
    saved tables always cover every combination. These per-combination files are removed once
    the joint outputs have been written.

    df_joint: Joined input table with 'VI' and 'EWS' columns.
    drivers, cat_drivers: Forwarded to `train_model`.
    model_out_folder: Output directory (created if missing).
    date_check: Prefix of all output file names.
    force_recompute: If true, ignore every cached file and retrain everything.
    Returns (model_list, X_train_list), one entry per VI x EWS combination.
    """
    #Create the output folder and its 'models' subfolder if they do not exist yet
    model_subfolder = os.path.join(model_out_folder, 'models')
    os.makedirs(model_subfolder, exist_ok=True)

    model_list_path = os.path.join(model_subfolder, f'{date_check}_model_list.pkl')
    X_train_list_path = os.path.join(model_subfolder, f'{date_check}_X_train_list.pkl')

    #If both joint pickles are already there, just read them back
    if not force_recompute and os.path.exists(model_list_path) and os.path.exists(X_train_list_path):
        print('Model list alrady computed, not running anything...')
        return _read_pickle(model_list_path), _read_pickle(X_train_list_path)

    print('Training models for each VI and EWS combination...')
    model_list = []
    X_train_list = []
    importance_list = []
    performance_list = []
    combo_paths = []
    vi_values = df_joint.VI.unique()
    ews_values = df_joint.EWS.unique()
    for vi in tqdm(vi_values, desc='VI loop', total = len(vi_values)):
        for ews in ews_values:
            print(vi, ews)
            #Per-combination cache files
            paths = {
                kind: os.path.join(model_subfolder, f'{date_check}_{kind}_{vi}_{ews}.pkl')
                for kind in ('model', 'X_train', 'performance', 'importance')
            }
            combo_paths.extend(paths.values())

            #Only a complete set of cache files counts as a cached combination
            if not force_recompute and all(os.path.exists(p) for p in paths.values()):
                print(f'Model for {vi} and {ews} already exists, skipping...')
                results = {kind: _read_pickle(path) for kind, path in paths.items()}
            else:
                performance_df, importance_df, model, X_train = train_model(df_joint, vi, ews, drivers, cat_drivers)[:4]
                results = {
                    'model': model,
                    'X_train': X_train,
                    'performance': performance_df,
                    'importance': importance_df,
                }
                for kind, path in paths.items():
                    _write_pickle(results[kind], path)

            #Cached and freshly trained combinations end up in the lists alike
            model_list.append(results['model'])
            X_train_list.append(results['X_train'])
            performance_list.append(results['performance'])
            importance_list.append(results['importance'])

    #Make importance and performance lists into one df each and save to feather
    importance_df = pd.concat(importance_list, axis=0).reset_index(drop=True)
    importance_df.to_feather(os.path.join(model_out_folder, f'{date_check}_importance.feather'))
    performance_df = pd.concat(performance_list, axis=0).reset_index(drop=True)
    performance_df.to_feather(os.path.join(model_out_folder, f'{date_check}_performance.feather'))

    #Dump the model list and X_train_list to pickle
    _write_pickle(model_list, model_list_path)
    _write_pickle(X_train_list, X_train_list_path)

    #Delete the per-combination files of this run from disk to save space
    for path in combo_paths:
        if os.path.exists(path):
            os.remove(path)

    return model_list, X_train_list


def _read_pickle(path):
    #Load one pickled object; only used on files written by this script itself
    with open(path, 'rb') as f:
        return pickle.load(f)


def _write_pickle(obj, path):
    #Pickle one object to `path`, overwriting any existing file
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

#SHAP analysis
def shap_analysis(model_list, X_train_list, model_out_folder, date_check):
    """
    Compute SHAP values for every model on its own training matrix and pickle the results.

    model_list: Fitted models.
    X_train_list: Training matrices, index-aligned with `model_list`.
    model_out_folder: Output directory; results go into its 'shap' subfolder.
    date_check: Prefix of the output file names.
    Writes <date_check>_explainer_list.pkl, <date_check>_shap_values_list.pkl and
    <date_check>_X_train_list.pkl into the 'shap' subfolder. Returns nothing.
    """
    print('Running SHAP analysis...')
    shap_subfolder = os.path.join(model_out_folder, 'shap')
    #Create the SHAP subfolder on first use
    if not os.path.exists(shap_subfolder):
        print('Creating SHAP subfolder...')
        os.makedirs(shap_subfolder)

    explainer_list = []
    shap_values_list = []
    for idx, fitted_model in enumerate(model_list):
        print(idx)
        train_matrix = X_train_list[idx]
        #One explainer per model, built from and evaluated on that model's training matrix
        explainer = shap.TreeExplainer(fitted_model, train_matrix)
        explainer_list.append(explainer)
        shap_values_list.append(explainer(train_matrix))

    #Save explainers, SHAP values and the training matrices side by side
    outputs = (
        ('explainer_list', explainer_list),
        ('shap_values_list', shap_values_list),
        ('X_train_list', X_train_list),
    )
    for stem, payload in outputs:
        with open(os.path.join(shap_subfolder, f'{date_check}_{stem}.pkl'), 'wb') as f:
            pickle.dump(payload, f)

def main(hammond_data_path, ews_folder, climate_data_path, model_out_folder, date_check, force_recompute=False):
    """
    Run the whole pipeline: load the inputs, train the models, compute the SHAP values.

    Nothing is done when <model_out_folder>/shap/<date_check>_shap_values_list.pkl already
    exists, unless `force_recompute` is true.

    hammond_data_path: Path to the Hammond CSV file.
    ews_folder: Directory containing the EWS, resistance and recovery data.
    climate_data_path: Path to the TERRACLIMATE CSV file.
    model_out_folder: Directory for all outputs.
    date_check: Prefix of the output file names.
    force_recompute: If true, recompute everything even if outputs exist.
    """
    #The SHAP values are the final product; if they are there, we are done
    final_product = os.path.join(model_out_folder, 'shap', f'{date_check}_shap_values_list.pkl')
    if os.path.exists(final_product) and not force_recompute:
        print('SHAP values already computed, nothing to recompute...')
        return

    print('SHAP values not found, proceeding with computation...')

    #Load and join all input data
    df_joint = load_input_data(hammond_data_path, ews_folder, climate_data_path)
    print(df_joint.head())

    #Predictors of the models
    drivers = [
        'rolling_window', 'Setup',
        'aet_year_1', 'def_year_1', 'soil_year_1', 'srad_year_1', 'tmax_year_1',
        'vpd_year_1', 'PDSI_year_1',
        'r', 'resistance_sd_ratio', 'abs_lat',
    ]
    #Predictors to encode as categories ('true_pos_neg' is listed but is not a driver)
    cat_drivers = ['rolling_window', 'Setup', 'true_pos_neg']

    #Train or reload the models, then explain them
    model_list, X_train_list = model_wrapper(df_joint, drivers, cat_drivers, model_out_folder, date_check, force_recompute)
    shap_analysis(model_list, X_train_list, model_out_folder, date_check)


def _parse_force_recompute(raw):
    """
    Turn the force_recompute command line argument into a bool.

    Accepts the words 'true'/'yes' and 'false'/'no' (any case) as well as integers, where
    every non-zero integer means True. Raises ValueError for anything else.
    """
    text = raw.strip().lower()
    if text in ('true', 'yes'):
        return True
    if text in ('false', 'no'):
        return False
    return bool(int(text))


if __name__ == "__main__":
    #Get command line arguments, falling back to the defaults for missing ones
    hammond_data_path = sys.argv[1] if len(sys.argv) > 1 else "./data/intermediary/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017_with_true_pos_adjusted.csv"
    ews_folder = sys.argv[2] if len(sys.argv) > 2 else "./data/final/MOD13Q1_ews_resistance_recovery"
    climate_data_path = sys.argv[3] if len(sys.argv) > 3 else "./data/raw/terra_climate_data/hammond_climate_data_terraclimate.csv"
    model_out_folder = sys.argv[4] if len(sys.argv) > 4 else "./data/final/explainer_model_output"
    date_check = sys.argv[5] if len(sys.argv) > 5 else datetime.today().strftime('%y_%m_%d')
    #Accept both "True"/"False" and "1"/"0" for the flag
    force_recompute = _parse_force_recompute(sys.argv[6]) if len(sys.argv) > 6 else False

    #Run the full pipeline
    main(hammond_data_path, ews_folder, climate_data_path, model_out_folder, date_check, force_recompute)