# 25-09-ews-assessment / code / 02_empirical_analysis / 03_true_negative_selection.py
"""
Script to identify control points (true negatives) for the Hammond dataset based on true positive locations.

For each entry, the candidate pixels are taken from a 10 km buffer box around the true positive. They are reduced to
the pixels with the minimum number of outliers in the annual mean NDVI time series up to and including the disturbance
year (outliers defined as absolute z-score > 2). The true positive itself is excluded. Three true negatives are then
selected from these candidates:
1. Minimum absolute resistance (difference between the mean NDVI in the disturbance year and the mean NDVI of the
   three previous years)
2. Minimum mean absolute error (MAE) compared to the true positive NDVI time series before the disturbance year
3. Maximum ratio |1 / resistance| / MAE, i.e. a compromise between criteria 1 and 2

The script extracts the time series (NDVI, EVI, SummaryQA) for the true positive and the three true negative points and
saves them to feather files. It also saves the locations of the true positive and true negative points to feather
files. Finally, all entries are combined into all_points.feather and all_timeseries.feather.

NOTE: This script runs for all locations that were previously processed. For example, if you have run the other
scripts in test mode, it will only run for those locations without throwing an error message.

Usage: python3 ./code/02_empirical_analysis/03_true_negative_selection.py hammond_path zarr_input_folder output_folder

hammond_path: Path to the Hammond CSV file (default: "./data/intermediary/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017_with_true_pos_adjusted.csv")
zarr_input_folder: Directory where the preprocessed zarr files are stored (default: "./data/intermediary/MOD13Q1_preprocessed_entry")
output_folder: Directory to save the extracted time series and points (default: "./data/intermediary/MOD13Q1_extracted_timeseries/")
"""

import xarray as xr
import numpy as np
import pandas as pd
import os
import rioxarray
from pyproj import Transformer
from shapely import Point
from dask.diagnostics import ProgressBar
from tqdm.autonotebook import tqdm
import sys

# Opt in to xarray's new default keyword values for its combine functions (e.g. xr.merge, used when stacking
# the true-negative criteria), which deactivates the warnings about the upcoming change of these defaults.
# NOTE(review): this is a global setting made at import time and it also changes how merges behave,
# not only which warnings are shown.
xr.set_options(use_new_combine_kwarg_defaults=True)

#Function to count the outlier years in the annual time series up to and including the collapse year
def timeseries_outliers(df, collapse_year):
    """Count, per pixel, the outlier years in the annual mean NDVI up to and including collapse_year.

    The NDVI series is resampled to annual means. Each annual mean is z-scored against the mean
    and standard deviation of all annual means of that pixel. These statistics use the full
    record, including years after the collapse. A year is an outlier when |z| > 2.

    Parameters
    ----------
    df : dataset with an ``NDVI`` variable that has a ``time`` dimension.
    collapse_year : disturbance year (integer-valued).

    Returns
    -------
    DataArray named ``n_outliers`` with the outlier count for every pixel (the time dimension is reduced).
    """
    print('Compute outliers')
    year = int(collapse_year)

    # Only NDVI is used downstream, so resample just that variable instead of the whole dataset
    ndvi_annual = df.NDVI.resample(time='YE').mean()

    # z-score against the per-pixel mean/std of the annual means
    z_score = (ndvi_annual - ndvi_annual.mean(dim='time')) / ndvi_annual.std(dim='time')

    # Flag outliers; NaN z-scores (missing data, or 0/0 for a constant series) compare False and count as 0
    outliers = xr.where(np.abs(z_score) > 2, 1, 0)

    # 'YE' labels every year by its last day, so "<= Dec 31" keeps all years up to and including the collapse year
    collapse_cutoff = pd.Timestamp(f'{year}-12-31')
    n_outliers = outliers.where(outliers.time <= collapse_cutoff, 0).sum(dim='time')
    return n_outliers.rename('n_outliers')


#Function to compute resistance to collapse
def resistance(df, collapse_year):
    """Compute the per-pixel resistance to the collapse.

    Resistance is the mean NDVI in the collapse year minus the mean NDVI of the three calendar years
    before it. xarray label slices include both endpoints. Each window therefore runs from Jan 1 to
    Dec 31 of its years. Ending on Jan 1 of the following year would pull that year's first composite
    into the window and make the two windows overlap.

    The input dataset is not modified.

    Parameters
    ----------
    df : dataset with an ``NDVI`` variable that has a ``time`` dimension.
    collapse_year : disturbance year (integer-valued).

    Returns
    -------
    DataArray named ``resistance`` with the time dimension reduced.
    """
    print("Computing resistance")
    year = int(collapse_year)
    ndvi = df.NDVI

    # Mean over the collapse year itself
    collapse_mean = ndvi.sel(time=slice(f'{year}-01-01', f'{year}-12-31')).mean(dim='time')
    # Mean over the three years preceding the collapse year
    pre_collapse_mean = ndvi.sel(time=slice(f'{year - 3}-01-01', f'{year - 1}-12-31')).mean(dim='time')

    return (collapse_mean - pre_collapse_mean).rename('resistance')

#Function to compute mean absolute error
def mae(x, y):
    """Return the mean absolute error between ``x`` and ``y``, ignoring NaN differences.

    If every difference is NaN the result is NaN (numpy then also emits a RuntimeWarning).
    """
    absolute_errors = np.abs(x - y)
    return np.nanmean(absolute_errors)

#Function to compute mean absolute error compared to true positive time series
def compute_mae(df, x_true, y_true, collapse_year):
    """Compute, per pixel, the MAE between its NDVI series and the true positive's NDVI series before the collapse.

    The comparison window starts in 2001 and ends on Dec 31 of the year before collapse_year.
    Label slices are closed, so ending the window on Jan 1 of the collapse year would leak that
    year's first composite into the "pre-collapse" period.

    Parameters
    ----------
    df : dataset with an ``NDVI`` variable over ``x``, ``y`` and ``time``.
    x_true, y_true : coordinates of the true positive; the nearest pixel is used.
    collapse_year : disturbance year (integer-valued).

    Returns
    -------
    DataArray named ``MAE`` with the time dimension reduced. It is 0 at the true positive pixel itself.
    """
    print("Computing MAE")
    year = int(collapse_year)

    # Clip to the values strictly before the collapse year
    ndvi_pre = df.NDVI.sel(time=slice('2001-01-01', f'{year - 1}-12-31'))
    # Reference series of the true positive (computed eagerly, shared by every pixel)
    ndvi_true_pre = ndvi_pre.sel(x=x_true, y=y_true, method='nearest').values

    # Apply mae pixel by pixel along time; time must be a single dask chunk for a core dimension
    mae_xr = xr.apply_ufunc(mae, ndvi_pre,
                            input_core_dims=[['time']],
                            output_core_dims=[[]],
                            kwargs={'y': ndvi_true_pre},
                            vectorize=True,
                            dask='parallelized')
    return mae_xr.rename('MAE')

#Function to combine the different criteria for true negative selection
def stack_tn_criteria(df, x_true, y_true, collapse_year):
    """Compute every true-negative selection criterion per pixel and return them as one table.

    The table holds one row per pixel with columns x, y, n_outliers, resistance, MAE and ratio
    (plus any scalar coordinates carried along by ``to_dataframe``). Pixels with a missing value
    in any column are dropped.

    ratio = |1 / resistance| / MAE. It is large for pixels that both change little in the
    collapse year (small |resistance|) and resemble the true positive beforehand (small MAE).
    """
    # Order matters only for the progress messages: outliers, resistance, then MAE
    criteria = [
        timeseries_outliers(df, collapse_year),
        resistance(df, collapse_year),
        compute_mae(df, x_true, y_true, collapse_year),
    ]
    merged = xr.merge(criteria)

    print('Stacking and converting to df')
    table = merged.to_dataframe().reset_index().dropna()

    table['ratio'] = abs(1 / table.resistance) / table.MAE
    return table

#Function to extract the time series for a given x, y coordinate
def ts_extraction(df, x, y, df_hammond_row):
    """Extract the NDVI, EVI and SummaryQA time series of the pixel nearest to (x, y) as a table.

    Every field of ``df_hammond_row`` except ``x`` and ``y`` is added as a constant column, so
    each row carries the metadata of the entry. The ``x``/``y`` fields of the row are skipped,
    presumably so that the coordinates of the selected pixel are kept.
    """
    pixel = df.sel(x=x, y=y, method='nearest')
    table = pixel[['NDVI', 'EVI', 'SummaryQA']].to_dataframe().reset_index()

    # Drop the spatial_ref column (presumably the CRS coordinate added by rioxarray)
    table = table.drop(columns=['spatial_ref'])

    for name, value in df_hammond_row.items():
        if name in ('x', 'y'):
            continue
        table[name] = value
    return table

#Extraction pipeline for one entry
def _axis_direction(coord):
    """Return -1 if a 1-D coordinate decreases along its index, otherwise 1 (also for a single value)."""
    if coord.size > 1 and bool(coord[0] > coord[1]):
        return -1
    return 1


def extraction_pipeline_one_entry(id,
                                  df_hammond,
                                  zarr_in_folder,
                                  out_folder):
    """Select the true negatives of one Hammond entry and save its points and time series.

    Steps:
    1. Open ``entry_{id}.zarr`` and clip it to the entry's minx/maxx/miny/maxy box.
    2. Compute the selection criteria for every pixel.
    3. Keep the pixels with the fewest outliers, excluding the true positive itself (MAE == 0).
    4. Pick three true negatives: maximum ratio, minimum MAE and minimum absolute resistance.

    Parameters
    ----------
    id : entry_id to process (matched against ``df_hammond.entry_id``; the first matching row is used).
    df_hammond : Hammond table; this function reads its entry_id, x, y, minx, maxx, miny, maxy and
        year_disturbance columns.
    zarr_in_folder : folder containing ``entry_{id}.zarr``.
    out_folder : folder for ``entry_{id}_points.feather`` and ``entry_{id}_timeseries.feather``.

    Returns
    -------
    (points_path, timeseries_path). The files are empty placeholders in three cases: the entry had
    already been claimed before, it has more than 600 time steps, or no true negative candidate is left.
    """
    print(f"Started processing id {id} in PID {os.getpid()}")
    points_path = os.path.join(out_folder, f'entry_{id}_points.feather')
    ts_path = os.path.join(out_folder, f'entry_{id}_timeseries.feather')

    if os.path.exists(points_path):
        print(f"Entry {id} already exists, skipping")
        return points_path, ts_path

    print(f"Entry {id} does not exist, processing")
    # Create empty feather files to prevent another process (or a re-run) from working on this entry.
    # They intentionally stay in place if the entry is skipped or processing fails.
    pd.DataFrame().to_feather(points_path)
    pd.DataFrame().to_feather(ts_path)

    print('Loading zarr file')
    df = xr.open_zarr(os.path.join(zarr_in_folder, f'entry_{id}.zarr'), consolidated=False,
                      chunks={'x': 1000, 'y': 1000, 'time': -1})
    if len(df.time) > 600:
        print(f"Entry {id} has more than 600 time steps, skipping, there is something wrong here!")
        return points_path, ts_path

    # Look up the entry once; column-wise lookups keep each column's original dtype
    entry = df_hammond.loc[df_hammond.entry_id == id]
    entry_row = entry.iloc[0]

    print('Clipping to 10km box')
    y_dir = _axis_direction(df.y)
    print(f"Y axis direction: {y_dir}")
    x_dir = _axis_direction(df.x)
    print(f"X axis direction: {x_dir}")
    minx, maxx = entry['minx'].values[0], entry['maxx'].values[0]
    miny, maxy = entry['miny'].values[0], entry['maxy'].values[0]
    # Label slices must follow the coordinate order, so the bounds are swapped for descending axes
    x_slice = slice(minx, maxx) if x_dir == 1 else slice(maxx, minx)
    y_slice = slice(miny, maxy) if y_dir == 1 else slice(maxy, miny)
    df = df.sel(x=x_slice, y=y_slice).chunk({'x': 1000, 'y': 1000, 'time': -1})
    print(f"Data shape after clipping: {df.NDVI.shape}")

    print('TP coordinates')
    x_true = entry['x'].values[0]
    y_true = entry['y'].values[0]
    collapse_year = entry['year_disturbance'].values[0]

    print('Computing TN')
    df_tn = stack_tn_criteria(df, x_true, y_true, collapse_year)

    print('Extracting points')
    # Keep only points with the minimum number of outliers
    df_tn = df_tn.loc[df_tn.n_outliers == df_tn.n_outliers.min()].reset_index(drop=True)
    # Remove the point with MAE == 0 because this clearly is the original point
    df_tn = df_tn.loc[df_tn.MAE != 0].reset_index(drop=True)
    if df_tn.empty:
        print(f"Entry {id} has no true negative candidates left, skipping")
        return points_path, ts_path

    # First pixel reaching the respective optimum
    x_tn_ratio, y_tn_ratio = df_tn.loc[df_tn.ratio == df_tn.ratio.max(), ['x', 'y']].values[0]
    x_tn_mae, y_tn_mae = df_tn.loc[df_tn.MAE == df_tn.MAE.min(), ['x', 'y']].values[0]
    abs_resistance = abs(df_tn.resistance)
    x_tn_res, y_tn_res = df_tn.loc[abs_resistance == abs_resistance.min(), ['x', 'y']].values[0]

    print('Extracting TP timeseries')
    ts_frames = [ts_extraction(df, x_true, y_true, entry_row)]

    # (progress message, label stored in true_pos_neg, x, y) for every true negative
    tn_selection = [
        ('Extracting TN ratio', 'true_neg_ratio', x_tn_ratio, y_tn_ratio),
        ('Extracting TN MAE', 'true_neg_mae', x_tn_mae, y_tn_mae),
        ('Extracting TN resistance', 'true_neg_resistance', x_tn_res, y_tn_res),
    ]
    new_points = []
    for message, label, x_tn, y_tn in tn_selection:
        print(message)
        point = entry_row.copy()
        point['x'] = x_tn
        point['y'] = y_tn
        point['true_pos_neg'] = label
        new_points.append(point)
        ts_frames.append(ts_extraction(df, x_tn, y_tn, point))

    print('Concatenating time series')
    # ignore_index avoids writing the repeated per-series row index into the feather file
    d_vals = pd.concat(ts_frames, axis=0, ignore_index=True)

    print('Saving to feather')
    pd.DataFrame(new_points).reset_index(drop=True).to_feather(points_path)
    d_vals.to_feather(ts_path)
    return points_path, ts_path

#Function to load and preprocess Hammond, and run the extraction pipeline for all entries
def main(hammond_path, zarr_in_folder, out_folder):
    """Load the Hammond table, run the true-negative selection per entry and combine the results.

    Only entries with a preprocessed ``entry_{id}.zarr`` store in zarr_in_folder are processed, for
    example after running the previous scripts in test mode. Each entry_id is processed once. The
    per-entry files are concatenated with the true positives of the processed entries into
    ``all_points.feather`` and ``all_timeseries.feather`` in out_folder.

    Parameters
    ----------
    hammond_path : path to the Hammond CSV file.
    zarr_in_folder : folder containing the preprocessed ``entry_{id}.zarr`` stores.
    out_folder : output folder (created if missing).
    """
    print('Loading dataset')
    df_hammond = pd.read_csv(hammond_path)

    os.makedirs(out_folder, exist_ok=True)

    # Add a 10 km buffer box around each point (assumes x/y are in metres, i.e. a projected CRS — TODO confirm)
    #NOTE: double check this with direction of y-axis
    print('Adding 10km buffer')
    buffer = 10_000
    df_hammond['minx'] = df_hammond['x'] - buffer
    df_hammond['miny'] = df_hammond['y'] - buffer
    df_hammond['maxx'] = df_hammond['x'] + buffer
    df_hammond['maxy'] = df_hammond['y'] + buffer

    # For entry_id values with more than one row (true_pos and true_pos_adj), drop the true_pos rows
    print('Filtering out low resolution entries')
    len_orig = len(df_hammond)
    id_counts = df_hammond.entry_id.value_counts()
    ids_adj = id_counts.index[id_counts > 1]
    df_hammond = df_hammond.loc[~((df_hammond.entry_id.isin(ids_adj)) & (df_hammond.true_pos_neg == 'true_pos'))].reset_index(drop=True)
    print(f"Filtered out {len_orig - len(df_hammond)} entries with true_pos")

    print('Application')
    # Each entry once; only entries whose preprocessed zarr store exists can be processed
    unique_ids = pd.unique(df_hammond.entry_id)
    entry_ids = [i for i in unique_ids
                 if os.path.exists(os.path.join(zarr_in_folder, f'entry_{i}.zarr'))]
    if len(entry_ids) < len(unique_ids):
        print(f"Skipping {len(unique_ids) - len(entry_ids)} entries without a preprocessed zarr file")

    point_list = []
    ts_list = []
    for i in tqdm(entry_ids, total=len(entry_ids)):
        point_path, ts_path = extraction_pipeline_one_entry(i, df_hammond, zarr_in_folder=zarr_in_folder,
                                                            out_folder=out_folder)
        point_list.append(point_path)
        ts_list.append(ts_path)

    if not point_list:
        print('No entries with a preprocessed zarr file found, nothing to concatenate')
        return

    print('Concatenating all points and time series')
    df_points = pd.concat([pd.read_feather(p) for p in point_list], ignore_index=True)
    # Add the true positives of the processed entries
    df_true_pos = df_hammond.loc[df_hammond.entry_id.isin(entry_ids)]
    df_points = pd.concat([df_points, df_true_pos], ignore_index=True)
    df_timeseries = pd.concat([pd.read_feather(p) for p in ts_list], ignore_index=True)

    df_points.to_feather(os.path.join(out_folder, 'all_points.feather'))
    df_timeseries.to_feather(os.path.join(out_folder, 'all_timeseries.feather'))
    print('done :)')

if __name__ == "__main__":
    # Optional positional arguments: hammond_path, zarr_in_folder, out_folder (extra arguments are ignored)
    default_args = (
        "./data/intermediary/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017_with_true_pos_adjusted.csv",
        "./data/intermediary/MOD13Q1_preprocessed_entry",
        "./data/intermediary/MOD13Q1_extracted_timeseries/",
    )
    given_args = sys.argv[1:]
    hammond_path, zarr_in_folder, out_folder = [
        given_args[pos] if pos < len(given_args) else fallback
        for pos, fallback in enumerate(default_args)
    ]

    # Run the whole pipeline
    main(hammond_path, zarr_in_folder, out_folder)