"""
Stack the downloaded MODIS tiles over time, clipped to a 25km radius around the true positive locations from Hammond (2022).
This script checks if the files have been downloaded correctly, and if not, re-downloads them.
It converts the Hammond coordinates to MODIS projections and clips the rasters accordingly.
It saves the clipped and stacked rasters as zarr files for each entry in the Hammond dataset.
This script requires a local copy of the downloaded MODIS files in hdf format, as obtained from 00_data_download_earthaccess.py.
The script uses multiprocessing to speed things up, running 30 parallel processes by default.
It can be run in test mode for only 5 locations, or in full mode for all locations.
NOTE: the hdf4 files are an old format and can be tricky to open. This script uses rioxarray, which relies on GDAL, so opening them may fail depending on your GDAL build.
This script takes around 20-30 seconds per location.
NOTE: if it seems to take much longer than this, kill the process, delete the created empty zarr files in the output folder, and run again.
Usage: python3 ./code/02_empirical_analysis/01_data_preprocessing.py hammond_path modis_hdf_dir output_dir_preprocessed test_mode
hammond_path: Path to the Hammond CSV file (default: "./data/raw/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017.csv")
modis_hdf_dir: Directory where the downloaded MODIS hdf files are stored (default: "./data/raw/MOD13Q1")
output_dir_preprocessed: Directory to save preprocessed zarr files (default: "./data/intermediary/MOD13Q1_preprocessed_entry")
test_mode: If set to "test" (vs "full"), only process 5 locations for testing purposes (default: "test")
"""
import xarray as xr
import numpy as np
import os
import geopandas as gpd
import rioxarray as rxr
import earthpy
import earthaccess
from shapely import Point
import cartopy.crs as ccrs
from pyproj import Transformer
from datetime import datetime, timedelta
from tqdm.autonotebook import tqdm
from pathos.pools import ProcessPool
import pandas as pd
from dask.diagnostics import ProgressBar
from collections import defaultdict
import sys
#Replace print with a version that prepends a timestamp
_print = print
def print(*args, **kwargs):
import datetime
_print(datetime.datetime.now(), *args, **kwargs)
#Reprojection and buffering function
def reproject_bounds(lon, lat, target_crs):
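"""Reproject a WGS84 (lon, lat) point to target_crs and return the projected point together with the bounds of a 25 km buffer around it: (x, y, minx, miny, maxx, maxy)."""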
#Build a transformer from WGS84 (EPSG:4326) to the target CRS
transformer = Transformer.from_crs('EPSG:4326',target_crs, always_xy=True)
# Transform the coordinates
x, y = transformer.transform(lon, lat)
# Create a 25 km buffer
buffer_geom = Point(x, y).buffer(25_000)
# Get bounding box (minx, miny, maxx, maxy)
minx, miny, maxx, maxy = buffer_geom.bounds
return x, y, minx, miny, maxx, maxy
#Function to open and preprocess one file in parallel
def open_preprocess_parallelized(entry_id, file_path, minx, miny, maxx, maxy, output_path):
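"""Clip one MODIS hdf file to the bounding box (minx, miny, maxx, maxy), keep NDVI, EVI and pixel reliability, mask low-quality pixels, attach the acquisition date as a time coordinate, and write the result to output_path as a zarr store. If the hdf file cannot be opened, it is deleted and re-downloaded first. Returns the path to the clipped zarr."""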
#Build the output path; if the clipped file already exists, skip the work
out_file = os.path.join(output_path, f"entry_{entry_id}_{os.path.basename(file_path).replace('.hdf', '')}_clipped.zarr")
if os.path.exists(out_file):
return out_file
else:
#Load file
#Test if this works, if not, download it again
try:
file = rxr.open_rasterio(os.path.join(file_path), chunks = {'x':100, 'y': 100},
variable = ['250m 16 days NDVI', '250m 16 days EVI', '250m 16 days pixel reliability'])
except Exception:
#The file could not be opened: delete it and re-download it from the LP DAAC archive
print('File could not be opened, re-downloading:', file_path)
os.remove(file_path)
file = earthaccess.download([f'https://data.lpdaac.earthdatacloud.nasa.gov/lp-prod-protected/MOD13Q1.061/{file_path.split("/")[-1].replace(".hdf", "")}/{file_path.split("/")[-1]}'], os.path.dirname(file_path), threads = 16)
file = rxr.open_rasterio(os.path.join(file_path), chunks = {'x':100, 'y': 100},
variable = ['250m 16 days NDVI', '250m 16 days EVI', '250m 16 days pixel reliability'])
#Get direction of y-axis
y_dir = file.y.values[1] - file.y.values[0]
if y_dir > 0:
#Clip raster to this area
file_clip = file.sel(x = slice(minx, maxx), y = slice(miny, maxy))
else:
file_clip = file.sel(x = slice(minx, maxx), y = slice(maxy, miny))
#Rename the bands and keep only NDVI, EVI, and the pixel reliability layer (SummaryQA)
file_clip = file_clip.rename({'250m 16 days NDVI' : 'NDVI',
'250m 16 days pixel reliability' : 'SummaryQA',
'250m 16 days EVI' : 'EVI'})[['NDVI', 'EVI', 'SummaryQA']].squeeze(dim = 'band').drop_vars('band')
#Rescale NDVI and EVI with the MOD13Q1 scale factor of 0.0001
file_clip['NDVI'] = file_clip.NDVI*0.0001
file_clip['EVI'] = file_clip.EVI*0.0001
#Mask out low-quality pixels (keep pixel reliability 0 = good and 1 = marginal)
file_clip = file_clip.where(file_clip.SummaryQA.isin([0, 1]))
#Add a time coordinate parsed from the acquisition date (AYYYYDDD) in the granule ID
time_string = file_clip.attrs['LOCALGRANULEID'].split('.')[1]
year = time_string[1:5]
doy = time_string[5:8]
time_value = datetime(int(year), 1, 1) + timedelta(int(doy) - 1)
# Add the time coordinate
file_clip = file_clip.assign_coords(time=time_value)
#Save to temporary file
file_clip.chunk({'x': 100, 'y': 100}).to_zarr(out_file, mode = 'w', consolidated = False)
return out_file
def wrapper_open_preprocess_parallelized(entry_id, lon, lat, minx, miny, maxx, maxy, out_dir, out_dir_preprocessed):
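"""Preprocess all MODIS files for one Hammond entry: find the tiles covering the location, clip every file in parallel, stack the clipped rasters over time (merging tiles spatially where needed), and save the stack as entry_<entry_id>.zarr in out_dir_preprocessed."""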
#Create temp folder for clipped files if this does not exist yet
out_dir_preprocessed_parallelized = os.path.join(out_dir_preprocessed, 'tmp_entry')
if not os.path.exists(out_dir_preprocessed_parallelized):
os.makedirs(out_dir_preprocessed_parallelized)
#Get subfiles for one location
#Check if this entry_id already exists
if os.path.exists(f"{out_dir_preprocessed}/entry_{entry_id}.zarr"):
print(f"Entry {entry_id} already exists, skipping")
return
else:
print(f"Processing entry {entry_id}")
#Create an empty placeholder file under the zarr name so that parallel runs skip this entry
with open(f"{out_dir_preprocessed}/entry_{entry_id}.zarr", 'w') as f:
pass
#Get list of existing hdf files
print("Getting list of existing files")
files = os.listdir(out_dir)
files = [f for f in files if f.endswith('.hdf')]
#Get relevant locations
print("Searching relevant MODIS tiles")
results = earthaccess.search_data(
short_name='MOD13Q1',
bounding_box=(lon - 0.25, lat - 0.25, lon + 0.25, lat + 0.25),
temporal=('2002-01-01', '2002-03-02'),
cloud_hosted=True
)
#Get the unique MODIS tile identifiers (hXXvYY)
print("Getting unique locations")
locs = [r.data_links()[0].split('/')[-1].split('.')[2] for r in results]
locs = np.unique(locs)
print(f"This point has {len(locs)} unique locations")
print(locs)
#Get list of all those files in the folder that include this location
loc_files = [f for f in files if f.split('.')[2] in locs]
print(f"Found {len(loc_files)} files for this location")
#Sort by date
loc_files = sorted(loc_files, key = lambda x: x.split('/')[-1].split('.')[1])
#Check the clipped dimensions of each tile and decide which ones to keep - some tiles only overlap the search box
#in latitude/longitude but contain no pixels within the buffer once we work in projected (meter) coordinates
keep_locs = []
for l in locs:
print('testing location ', l)
#Get the paths of all files for this tile
first_file = [os.path.join(out_dir, f) for f in loc_files if l in f]
#Clip the first file to the bounding box and open the result to check its dimensions
test_file = xr.open_dataset(open_preprocess_parallelized(entry_id, first_file[0], minx, miny, maxx, maxy, out_dir_preprocessed_parallelized), chunks = {'x': 100, 'y': 100}, consolidated = False)
#Check if any dimension is 0
if ((len(test_file.x.values) == 0) or (len(test_file.y.values) == 0)):
print(f"Location {l} has 0 dimensions, skipping")
else:
keep_locs.append(l)
print('After testing, we keep the following locations:', keep_locs)
#In the end, keep only those files that are in the relevant location
loc_files = [f for f in loc_files if f.split('.')[2] in keep_locs]
#Add full path
loc_files_path = [os.path.join(out_dir, f) for f in loc_files]
#Roughly 556 files are expected per location; if there are fewer than 550, the download is clearly incomplete, so skip this entry
if len(loc_files_path) < 550:
print("Fewer than 550 files found, skipping this entry")
return
#Set up multiprocessing
print("running multiprocessing...")
p = ProcessPool(nodes=30)
#Iterate over the entries
prep_files = p.map(open_preprocess_parallelized,
[entry_id]*len(loc_files_path),
loc_files_path,
[minx]*len(loc_files_path),
[miny]*len(loc_files_path),
[maxx]*len(loc_files_path),
[maxy]*len(loc_files_path),
[out_dir_preprocessed_parallelized]*len(loc_files_path))
#Load, stack in time, and resave
print("loading clipped files...")
if len(keep_locs) == 1:
df = xr.open_mfdataset(prep_files, combine = 'nested', concat_dim='time', chunks = {'x':100, 'y': 100}, consolidated = False)
else:
print('Still more than one tile: merging tiles spatially per time step before concatenating in time')
# Step 1: group Zarr paths by their time value
groups = defaultdict(list)
for path in prep_files:
ds = xr.open_zarr(path, consolidated=False, chunks = {'x': 1000, 'y': 1000})
# Extract time (assumes each dataset has only one time step)
time_val = ds.time.values
#Check if the length of any dimension here is 0
if ((len(ds.x.values) == 0) or (len(ds.y.values) == 0)):
print('skipping', path)
else:
groups[time_val].append(path)
# Step 2: spatially merge per time step
mosaics = []
for time_val, group_paths in sorted(groups.items()):
tiles = [xr.open_zarr(p, consolidated=False, chunks = {'x': 1000, 'y': 1000}) for p in group_paths]
# Spatial merge (along x/y), keeping time as-is
merged = xr.combine_by_coords(tiles, combine_attrs="drop")
mosaics.append(merged)
# Step 3: concatenate all time slices
df = xr.concat(mosaics, dim="time")
#Limit to the relevant time period
df = df.sel(time = slice('2001-01-01', '2025-01-01'))
#Save to zarr
print("saving to zarr...")
#Delete the empty placeholder file created at the start of this entry
if os.path.exists(f"{out_dir_preprocessed}/entry_{int(entry_id)}.zarr"):
os.remove(f"{out_dir_preprocessed}/entry_{int(entry_id)}.zarr")
with ProgressBar():
df.chunk({'x': 100, 'y': 100}).to_zarr(f"{out_dir_preprocessed}/entry_{int(entry_id)}.zarr", mode='w', consolidated = False)
print("done :)")
def main(hammond_path, out_dir, out_dir_preprocessed, test_mode):
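"""Load and filter the Hammond dataset, reproject its coordinates to the MODIS projection read from one downloaded file, and run the preprocessing for every entry."""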
#Login to earthaccess
print('Logging in to Earthdata...')
earthaccess.login(strategy = 'interactive', persist = True)
print('Loading Hammond dataset...')
#Load forest dieback dataset (Hammond, 2022)
df_hammond = pd.read_csv(hammond_path).drop_duplicates()
#Drop collapse events pre 1984
df_hammond = df_hammond.loc[df_hammond.time > 1984].reset_index(drop=True)
#Add an entry_id to this
df_hammond['entry_id'] = df_hammond.index
#For compatibility, also add a paired_id column
df_hammond['paired_id'] = df_hammond.entry_id
#Rename for consistency
df_hammond = df_hammond.rename(columns={'time':'year_disturbance', 'long':'lon'})
df_hammond['true_pos_neg'] = 'true_pos'
#Remove all events with collapse in 2001 or earlier
df_hammond = df_hammond.loc[df_hammond.year_disturbance > 2001].reset_index(drop=True)
#If running in test mode, only keep 5 randomly sampled locations
if test_mode == "test":
df_hammond = df_hammond.sample(5, random_state = 42).reset_index(drop=True)
print("Running in test mode, only processing 5 locations.")
#Check if the output folder exists, if not, create it
if not os.path.exists(out_dir_preprocessed):
print("Creating output folder:", out_dir_preprocessed)
os.makedirs(out_dir_preprocessed)
#Reproject to the correct CRS - load one MODIS file to get the target CRS
with os.scandir(out_dir) as entries:
for entry in entries:
if entry.is_file() and entry.name.endswith('.hdf'):
first_file_path = entry.path
break
#Read the target CRS from the first MODIS file
print('Reprojecting Hammond coordinates to MODIS projection...')
r1 = rxr.open_rasterio(first_file_path)
target_crs = r1.rio.crs
print('Target crs is:', target_crs)
#Apply
df_hammond['x'], df_hammond['y'], df_hammond['minx'], df_hammond['miny'], df_hammond['maxx'], df_hammond['maxy'] = zip(*df_hammond.apply(lambda x: reproject_bounds(x.lon, x.lat, target_crs), axis = 1))
#Loop through the rows of the Hammond dataset and apply
for i in tqdm(range(len(df_hammond)), desc = 'Processing Hammond locations', total = len(df_hammond)):
wrapper_open_preprocess_parallelized(df_hammond.loc[i, 'entry_id'], df_hammond.loc[i, 'lon'], df_hammond.loc[i, 'lat'], df_hammond.loc[i, 'minx'], df_hammond.loc[i, 'miny'], df_hammond.loc[i, 'maxx'], df_hammond.loc[i, 'maxy'], out_dir, out_dir_preprocessed)
print("done :)")
if __name__ == '__main__':
#Extract arguments
hammond_path = sys.argv[1] if len(sys.argv) > 1 else "./data/raw/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017.csv"
modis_hdf_dir = sys.argv[2] if len(sys.argv) > 2 else "./data/raw/MOD13Q1"
output_dir_preprocessed = sys.argv[3] if len(sys.argv) > 3 else "./data/intermediary/MOD13Q1_preprocessed_entry"
test_mode = sys.argv[4] if len(sys.argv) > 4 else "test" # "test" or "full"
#Local run
#hammond_path = "/home/nielja/USBdisc/data/on_the_ground/global_tree_mortality_database/GTM_full_database_resolve_biomes_2017.csv"
#modis_hdf_dir = '/home/nielja/USBdisc/data/25_03_27_MOD13Q1_groundtruthing'
#output_dir_preprocessed = '/home/nielja/USBdisc/data/25_03_27_MOD13Q1_groundtruthing_preprocessed_entry'
#Run everything
main(hammond_path = hammond_path,
out_dir = modis_hdf_dir,
out_dir_preprocessed = output_dir_preprocessed,
test_mode = test_mode)