"""
Stack the downloaded MODIS tiles over time, clipped to a 25 km radius around the true-positive locations from Hammond (2022).

This script checks if the files have been downloaded correctly, and if not, re-downloads them. 
It converts the Hammond coordinates to MODIS projections and clips the rasters accordingly.
It saves the clipped and stacked rasters as zarr files for each entry in the Hammond dataset.

This script requires a local copy of the downloaded MODIS files in hdf format, as obtained from 00_data_download_earthaccess.py.
The setup uses multiprocessing to speed the process up, with 30 parallel processes by default.
It can be run in test mode for only 5 locations, or in full mode for all locations.

NOTE: the hdf4 files are an old format and can be tricky to open. This script uses rioxarray, which relies on GDAL; if your GDAL build lacks HDF4 support, opening the files will fail.

This script takes around 20-30 seconds per location.
NOTE: if it seems to take much longer than this, kill the process, delete the created empty zarr files in the output folder, and run again.

Usage: python3 ./code/02_empirical_analysis/01_data_preprocessing.py hammond_path modis_hdf_dir output_dir_preprocessed test_mode

hammond_path: Path to the Hammond CSV file (default: "./data/raw/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017.csv")
modis_hdf_dir: Directory where the downloaded MODIS hdf files are stored (default: "./data/raw/MOD13Q1")
output_dir_preprocessed: Directory to save preprocessed zarr files (default: "./data/intermediary/MOD13Q1_preprocessed_entry")
test_mode: If set to "test" (vs "full"), only process 5 locations for testing purposes (default: "test")
"""

import xarray as xr
import numpy as np
import os 
import geopandas as gpd 
import rioxarray as rxr
import earthpy
import earthaccess
from shapely import Point
import cartopy.crs as ccrs
from pyproj import Transformer
from datetime import datetime, timedelta
from tqdm.autonotebook import tqdm
from pathos.pools import ProcessPool
import pandas as pd
from dask.diagnostics import ProgressBar
from collections import defaultdict
import sys


#Replace print with a time-stamped print
_print = print
def print(*args, **kwargs):
    _print(datetime.now(), *args, **kwargs)

#Reprojection and buffering function
def reproject_bounds(lon, lat, target_crs):
    #Build a transformer from WGS84 (EPSG:4326) to the target CRS
    transformer = Transformer.from_crs('EPSG:4326', target_crs, always_xy=True)
    # Transform the coordinates
    x, y = transformer.transform(lon, lat)
    # Create a 25 km buffer
    buffer_geom = Point(x, y).buffer(25_000)
    # Get bounding box (minx, miny, maxx, maxy)
    minx, miny, maxx, maxy = buffer_geom.bounds
    return x, y, minx, miny, maxx, maxy
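
# Illustrative use of reproject_bounds (values are made up; in this script target_crs is read from a
# MODIS file at runtime, so any projected CRS in metres behaves the same way):
#   x, y, minx, miny, maxx, maxy = reproject_bounds(lon=-105.0, lat=40.0, target_crs=target_crs)
#   # (minx, miny, maxx, maxy) is the roughly 50 km wide bounding box of the 25 km buffer around (x, y)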

#Function to open and preprocess one file in parallel
def open_preprocess_parallelized(entry_id, file_path, minx, miny, maxx, maxy, output_path):
    #Path of the clipped output zarr for this entry/tile/date combination
    out_zarr = os.path.join(output_path, f"entry_{entry_id}_{os.path.basename(file_path).replace('.hdf', '')}_clipped.zarr")
    #Check if this exists, if so, skip
    if os.path.exists(out_zarr):
        return out_zarr
    else:
        #Load file
        #Test if this works, if not, download it again
        try:
            file = rxr.open_rasterio(file_path, chunks = {'x': 100, 'y': 100},
                                variable = ['250m 16 days NDVI', '250m 16 days EVI', '250m 16 days pixel reliability'])
        except Exception:
            #Delete the corrupt file and re-download it
            print('File could not be opened, re-downloading:', file_path)
            os.remove(file_path)
            earthaccess.download([f'https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/MOD13Q1.061/{os.path.basename(file_path).replace(".hdf", "")}/{os.path.basename(file_path)}'], os.path.dirname(file_path), threads = 16)
            file = rxr.open_rasterio(file_path, chunks = {'x': 100, 'y': 100},
                                variable = ['250m 16 days NDVI', '250m 16 days EVI', '250m 16 days pixel reliability'])
        
        #Get direction of y-axis
        y_dir = file.y.values[1] - file.y.values[0]
        if y_dir > 0:
            #Clip raster to this area
            file_clip = file.sel(x = slice(minx, maxx), y = slice(miny, maxy))
        else:
            file_clip = file.sel(x = slice(minx, maxx), y = slice(maxy, miny))

        #Rename NDVI, EVI and SummaryQA
        file_clip = file_clip.rename({'250m 16 days NDVI' : 'NDVI', 
                                        '250m 16 days pixel reliability' : 'SummaryQA', 
                                        '250m 16 days EVI' : 'EVI'})[['NDVI', 'EVI', 'SummaryQA']].squeeze(dim = 'band').drop_vars('band')
        #Rescale NDVI and EVI
        file_clip['NDVI'] = file_clip.NDVI*0.0001
        file_clip['EVI'] = file_clip.EVI*0.0001
        #Mask out low-quality values
        file_clip = file_clip.where(file_clip.SummaryQA.isin([0, 1]))

        #Add time dimension
        # Define a specific time value
        time_string = file_clip.attrs['LOCALGRANULEID'].split('.')[1]
        year = time_string[1:5]
        doy = time_string[5:8]
        time_value = datetime(int(year), 1, 1) + timedelta(int(doy) - 1)
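        # Illustrative example (made-up granule ID): 'MOD13Q1.A2002033.h09v05.061.2020044150001'
        # -> time_string 'A2002033' -> year 2002, day-of-year 033 -> time_value 2002-02-02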

        # Add the time coordinate
        file_clip = file_clip.assign_coords(time=time_value)

        #Save to temporary file
        file_clip.chunk({'x': 100, 'y': 100}).to_zarr(out_zarr, mode = 'w', consolidated = False)
        return out_zarr
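
# Illustrative output naming (made-up file name): for entry_id 12 and an input tile
# ".../MOD13Q1.A2002033.h09v05.061.2020044150001.hdf", the clipped result is written to
# "<output_path>/entry_12_MOD13Q1.A2002033.h09v05.061.2020044150001_clipped.zarr".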


def wrapper_open_preprocess_parallelized(entry_id, lon, lat, minx, miny, maxx, maxy, out_dir, out_dir_preprocessed):
    #Create temp folder for clipped files if this does not exist yet
    out_dir_preprocessed_parallelized = os.path.join(out_dir_preprocessed, 'tmp_entry')
    if not os.path.exists(out_dir_preprocessed_parallelized):
        os.makedirs(out_dir_preprocessed_parallelized)

    #Get subfiles for one location
    #Check if this entry_id already exists
    if os.path.exists(f"{out_dir_preprocessed}/entry_{entry_id}.zarr"):
        print(f"Entry {entry_id} already exists, skipping")
        return
    else:
        print(f"Processing entry {entry_id}")
        #Create an empty placeholder file (named like the final zarr) to block other processes from working on this entry
        with open(f"{out_dir_preprocessed}/entry_{entry_id}.zarr", 'w') as f:
            pass
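        #NOTE: this plain file is the "empty zarr" mentioned in the module docstring; if a run is
        #killed before finishing, delete it so the entry is picked up again on the next run.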

    #Get list of existing hdf files
    print("Getting list of existing files")
    files = os.listdir(out_dir)
    files = [f for f in files if f.endswith('.hdf')]

    #Get relevant locations
    print("Searching relevant MODIS tiles")
    results = earthaccess.search_data(
        short_name='MOD13Q1',
        bounding_box=(lon - 0.25, lat - 0.25, lon + 0.25, lat + 0.25),
        temporal=('2002-01-01', '2002-03-02'),
        cloud_hosted=True
    )

    #Get unique location strings
    print("Getting unique locations")
    locs = [r.data_links()[0].split('/')[-1].split('.')[2] for r in results]
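    # Illustrative example (made-up link): a data link ending in "MOD13Q1.A2002033.h09v05.061.2020044150001.hdf"
    # yields the MODIS tile id "h09v05" (the third dot-separated field of the file name).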
    locs = np.unique(locs)
    print(f"This point has {len(locs)} unique locations")
    print(locs)

    #Get list of all those files in the folder that include this location
    loc_files = [f for f in files if f.split('.')[2] in locs]
    print(f"Found {len(loc_files)} files for this location")
    #Sort by date
    loc_files = sorted(loc_files, key = lambda x: x.split('/')[-1].split('.')[1])

    #Check the dimensions of all tiles and decide which ones to keep - some may have been returned only because of
    #their latitudinal overlap and are not actually relevant once we look at the coordinates in metres
    keep_locs = []
    for l in locs:
        print('testing location ', l)
        #Get the file paths for this tile
        first_file = [os.path.join(out_dir, f) for f in loc_files if l in f]
        #Clip one of them and check whether the clipped raster has any pixels
        test_file = xr.open_dataset(open_preprocess_parallelized(entry_id, first_file[0], minx, miny, maxx, maxy, out_dir_preprocessed_parallelized), chunks = {'x': 100, 'y': 100}, consolidated = False)
        #Check if any dimension is 0
        if ((len(test_file.x.values) == 0) or (len(test_file.y.values) == 0)):
            print(f"Location {l} has 0 dimensions, skipping")
        else:
            keep_locs.append(l)

    print('After testing, we keep the following locations:', keep_locs)
    #In the end, keep only those files that are in the relevant location
    loc_files = [f for f in loc_files if f.split('.')[2] in keep_locs]

    #Add full path
    loc_files_path = [os.path.join(out_dir, f) for f in loc_files]
    #If there are fewer than 550 files, skip this entry - clearly not everything has been downloaded yet
    if len(loc_files_path) < 550:
        print("Fewer than 550 files found, skipping this entry")
        return

    #Set up multiprocessing
    print("running multiprocessing...")
    p = ProcessPool(nodes=30)
    #Iterate over the entries
    prep_files = p.map(open_preprocess_parallelized, 
                        [entry_id]*len(loc_files_path),
                        loc_files_path, 
                        [minx]*len(loc_files_path), 
                        [miny]*len(loc_files_path), 
                        [maxx]*len(loc_files_path), 
                        [maxy]*len(loc_files_path), 
                        [out_dir_preprocessed_parallelized]*len(loc_files_path))
    
    #Load, stack in time, and resave
    print("loading clipped files...")
    if len(keep_locs) == 1:
        df = xr.open_mfdataset(prep_files, combine = 'nested', concat_dim='time', chunks = {'x':100, 'y': 100}, consolidated = False)
    else: 
        print('still more than one location, complex concatenation')
        # Step 1: group Zarr paths by their time value
        groups = defaultdict(list)

        for path in prep_files:
            ds = xr.open_zarr(path, consolidated=False, chunks = {'x': 1000, 'y': 1000})
            
            # Extract time (assumes each dataset has only one time step)
            time_val = ds.time.values

            #Check if the length of any dimension here is 0
            if ((len(ds.x.values) == 0) or (len(ds.y.values) == 0)):
                print('skipping', path)
            else:
                groups[time_val].append(path)
            
        # Step 2: spatially merge per time step
        mosaics = []

        for time_val, group_paths in sorted(groups.items()):
            tiles = [xr.open_zarr(p, consolidated=False, chunks = {'x': 1000, 'y': 1000}) for p in group_paths]
            
            # Spatial merge (along x/y), keeping time as-is
            merged = xr.combine_by_coords(tiles, combine_attrs="drop")
            
            mosaics.append(merged)

        # Step 3: concatenate all time slices
        df = xr.concat(mosaics, dim="time")


    #Limit to the relevant time period
    df = df.sel(time = slice('2001-01-01', '2025-01-01'))
    #Save to zarr
    print("saving to zarr...")
    #Delete the empty placeholder file created at the start
    if os.path.exists(f"{out_dir_preprocessed}/entry_{int(entry_id)}.zarr"):
        os.remove(f"{out_dir_preprocessed}/entry_{int(entry_id)}.zarr")
    with ProgressBar():
        df.chunk({'x': 100, 'y': 100}).to_zarr(f"{out_dir_preprocessed}/entry_{int(entry_id)}.zarr", mode='w', consolidated = False)
    print("done :)")



def main(hammond_path, out_dir, out_dir_preprocessed, test_mode):
    #Login to earthaccess
    print('Logging in to Earthdata...')
    earthaccess.login(strategy = 'interactive', persist = True)

    print('Loading Hammond dataset...')
    #Load forest dieback dataset (Hammond, 2022)
    df_hammond = pd.read_csv(hammond_path).drop_duplicates()
    #Drop collapse events pre 1984
    df_hammond = df_hammond.loc[df_hammond.time > 1984].reset_index(drop=True)
    #Add an entry_id to this
    df_hammond['entry_id'] = df_hammond.index
    #For compatibility, also add a paired_id column
    df_hammond['paired_id'] = df_hammond.entry_id
    #Rename for consistency
    df_hammond = df_hammond.rename(columns={'time':'year_disturbance', 'long':'lon'})
    df_hammond['true_pos_neg'] = 'true_pos'
    #Remove all those with collapse pre-2001
    df_hammond = df_hammond.loc[df_hammond.year_disturbance > 2001].reset_index(drop=True)

    #If we run this in test mode, only keep a random sample of 5 locations
    if test_mode == "test":
        df_hammond = df_hammond.sample(5, random_state = 42).reset_index(drop=True)
        print("Running in test mode, only processing 5 locations.")

    #Check if the output folder exists, if not, create it
    if not os.path.exists(out_dir_preprocessed):
        print("Creating output folder:", out_dir_preprocessed)
        os.makedirs(out_dir_preprocessed)
    
    #Project to the correct crs - load one MODIS hdf file to get the target crs
    with os.scandir(out_dir) as entries:
        for entry in entries:
            if entry.is_file() and entry.name.endswith('.hdf'):
                first_file_path = entry.path
                break
    
    #Apply to all rows in df_hammond
    print('Reprojecting Hammond coordinates to MODIS projection...')
    r1 = rxr.open_rasterio(first_file_path)
    target_crs = r1.rio.crs
    print('Target crs is:', target_crs)
    #Apply
    df_hammond['x'], df_hammond['y'], df_hammond['minx'], df_hammond['miny'], df_hammond['maxx'], df_hammond['maxy'] = zip(*df_hammond.apply(lambda x: reproject_bounds(x.lon, x.lat, target_crs), axis = 1))
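
    # A sketch of the result (column names only; values depend on the data): each row now carries the
    # reprojected point (x, y) in the MODIS projection (metres) plus the bounding box
    # (minx, miny, maxx, maxy) of its 25 km buffer.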


    #Loop through the rows of the Hammond dataset and apply
    for i in tqdm(range(len(df_hammond)), desc = 'Processing Hammond locations', total = len(df_hammond)):
        wrapper_open_preprocess_parallelized(df_hammond.loc[i, 'entry_id'], df_hammond.loc[i, 'lon'], df_hammond.loc[i, 'lat'],
                                             df_hammond.loc[i, 'minx'], df_hammond.loc[i, 'miny'],
                                             df_hammond.loc[i, 'maxx'], df_hammond.loc[i, 'maxy'],
                                             out_dir, out_dir_preprocessed)
    print("done :)")

if __name__ == '__main__':

    #Extract arguments
    hammond_path = sys.argv[1] if len(sys.argv) > 1 else "./data/raw/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017.csv"
    modis_hdf_dir = sys.argv[2] if len(sys.argv) > 2 else "./data/raw/MOD13Q1"
    output_dir_preprocessed = sys.argv[3] if len(sys.argv) > 3 else "./data/intermediary/MOD13Q1_preprocessed_entry"
    test_mode = sys.argv[4] if len(sys.argv) > 4 else "test"  # "test" or "full"

    #Local run
    #hammond_path = "/home/nielja/USBdisc/data/on_the_ground/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017.csv"
    #modis_hdf_dir = '/home/nielja/USBdisc/data/25_03_27_MOD13Q1_groundtruthing'
    #output_dir_preprocessed = '/home/nielja/USBdisc/data/25_03_27_MOD13Q1_groundtruthing_preprocessed_entry'

    #Run everything
    main(hammond_path = hammond_path,
         out_dir = modis_hdf_dir,
         out_dir_preprocessed = output_dir_preprocessed,
         test_mode = test_mode)