'''
Convert the QC-passed right hippocampus slice PNGs into numpy arrays and divide
them into 70% train / 10% validation / 20% test partitions, stratified by site
and diagnosis at the subject level. Images from sites outside the 12 selected
sites are saved separately as an "unseen sites" set.
'''

import os
import glob
import re
from PIL import Image
import numpy as np
import pandas as pd

from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import StratifiedShuffleSplit
from armed.misc import expand_data_path

strImageDir = expand_data_path('ADNI23_sMRI/right_hippocampus_slices_2pctnorm/coronal_MNI-6_qc/good')
strOutDir = expand_data_path('ADNI23_sMRI/right_hippocampus_slices_2pctnorm/coronal_MNI-6_numpy/12sites', 
                             make=True)
# 5 sites with GE scanners and majority AD, 7 sites with Philips/Siemens scanners and majority CN
lsSitesKept = [52, 5, 126, 57, 16, 2, 73, 100, 41, 22, 941, 20]
# Number of Monte Carlo random splits to generate
nSplits = 10 

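# Load the image metadata and index it by (subject RID, scan date) so that each
# PNG can be matched to its metadata row via identifiers parsed from the filename.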
dfImageInfo = pd.read_csv('image_list_ad_cn.csv', index_col=0)
dfImageInfo['RID'] = dfImageInfo['RID'].apply(lambda x: f'{int(x):04d}')
dfImageInfo['ScanDate'] = dfImageInfo['ScanDate'].apply(lambda x: x.replace('-', ''))
dfImageInfo.index = pd.MultiIndex.from_frame(dfImageInfo[['RID', 'ScanDate']])

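# Gather all QC-passed slice PNGs in the image directory.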
lsImages = glob.glob(os.path.join(strImageDir, '*'))
lsImages.sort()

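# Load each PNG into a (n_images, 192, 192, 1) float array scaled to [0, 1],
# parsing the subject ID and session date from the BIDS-style filename.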
nImages = len(lsImages)
arrImages = np.zeros((nImages, 192, 192, 1), dtype=np.float32)
lsKeptImages = []

for i, strImagePath in enumerate(lsImages):
    match = re.search(r'sub-(\d+)_ses-(\d+)', strImagePath)
    strSub = match[1]
    strSes = match[2]
    lsKeptImages += [(strSub, strSes)]
    
    img = np.array(Image.open(strImagePath))
    arrImages[i, :, :, 0] = img / 255.
    
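# Keep images from the 12 selected sites for train/val/test; images from all
# other sites are held out as an "unseen sites" evaluation set.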
dfImageInfoKept = dfImageInfo.loc[lsKeptImages]
arrImagesIncludedSites = arrImages[dfImageInfoKept['Site'].isin(lsSitesKept),]
dfImageInfoIncludedSites = dfImageInfoKept.loc[dfImageInfoKept['Site'].isin(lsSitesKept)]
arrImagesExcludedSites = arrImages[~dfImageInfoKept['Site'].isin(lsSitesKept),]
dfImageInfoExcludedSites = dfImageInfoKept.loc[~dfImageInfoKept['Site'].isin(lsSitesKept)]
dfImageInfoExcludedSites.to_csv(os.path.join(strOutDir, 'data_unseen.csv'))

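# Site membership and binary diagnosis label (1 = Dementia/AD, 0 = CN).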
arrSites = dfImageInfoIncludedSites['Site'].values
arrLabels = (dfImageInfoIncludedSites['DX_Scan'].values == 'Dementia').astype(np.float32)
    
# Create one-hot design matrix encoding cluster membership
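# (Note: scikit-learn >= 1.2 renames the `sparse` argument to `sparse_output`.)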
onehot = OneHotEncoder(sparse=False)
arrClusters = onehot.fit_transform(arrSites.reshape((-1, 1))).astype(np.float32)
arrSiteOrder = onehot.categories_[0]

# Split the subjects, stratifying by site and diagnosis. Ignore the number of
# images per subject, since there isn't enough data to stratify to that level of
# detail.
dfSubID = dfImageInfoIncludedSites.index.get_level_values(0)
arrSubID = dfSubID[~dfSubID.duplicated()].values
dfStrat = dfImageInfoIncludedSites.apply(lambda x: str(x['Site']) + '_' + x['DX_Scan'], axis=1)
arrStrat = dfStrat.loc[~dfSubID.duplicated()].values
arrSubID = arrSubID.astype(str)
arrStrat = arrStrat.astype(str)

# Check for subjects who switched diagnoses during the study
lsBadSubs = []
for sub in arrSubID:
    dfLabelsSub = dfImageInfoIncludedSites['DX_Scan'].loc[sub]
    if np.any(dfLabelsSub != dfLabelsSub.iloc[0]):
        print(sub, 'switches diagnosis:', dfLabelsSub.values)
        lsBadSubs += [sub]      
arrStrat = arrStrat[~np.isin(arrSubID, lsBadSubs)]
arrSubID = arrSubID[~np.isin(arrSubID, lsBadSubs)]

# Ignore any site-diagnosis combinations that have only 1 subject
lsBadStrats = []
for strat in np.unique(arrStrat):
    arrSubsStrat = arrSubID[arrStrat == strat,]
    if len(np.unique(arrSubsStrat)) < 2:
        print(strat, 'only has one subject')
        lsBadStrats += [strat]
arrSubID = arrSubID[~np.isin(arrStrat, lsBadStrats)]
arrStrat = arrStrat[~np.isin(arrStrat, lsBadStrats)]

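# Outer Monte Carlo loop: nSplits stratified subject-level splits, each holding
# out 20% of subjects for testing.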
testsplit = StratifiedShuffleSplit(n_splits=nSplits, test_size=0.2, random_state=32)
for iSplit, (arrTrainValIdx, arrTestIdx) in enumerate(testsplit.split(arrSubID, arrStrat)):
      print('===== Split', iSplit, '=====')
      arrSubIDTrainVal = arrSubID[arrTrainValIdx,]
      arrSubIDTest = arrSubID[arrTestIdx,]
      arrStratTrainVal = arrStrat[arrTrainValIdx,]

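      # Gather the images, labels, and site encodings for each partition; all
      # scans from a subject stay in the same partition.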
      arrImagesTest = arrImagesIncludedSites[dfSubID.isin(arrSubIDTest),]
      arrLabelsTest = arrLabels[dfSubID.isin(arrSubIDTest),]
      arrClustersTest = arrClusters[dfSubID.isin(arrSubIDTest),]
      arrImagesTrainVal = arrImagesIncludedSites[dfSubID.isin(arrSubIDTrainVal),]
      arrLabelsTrainVal = arrLabels[dfSubID.isin(arrSubIDTrainVal),]
      arrClustersTrainVal = arrClusters[dfSubID.isin(arrSubIDTrainVal),]

      print('Test set:', len(arrSubIDTest), 'subjects,', arrImagesTest.shape[0], 'images',
            f'{arrLabelsTest.mean()*100:.03f}% AD')

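      # Inner split: hold out 12.5% of the train+val subjects for validation
      # (~10% of all subjects), leaving ~70% for training.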
      valsplit = StratifiedShuffleSplit(n_splits=1, test_size=0.125, random_state=32)
      arrTrainIdx, arrValIdx = next(valsplit.split(arrSubIDTrainVal, arrStratTrainVal))
      arrSubIDTrain = arrSubIDTrainVal[arrTrainIdx,]
      arrSubIDVal = arrSubIDTrainVal[arrValIdx,]

      arrImagesVal = arrImagesIncludedSites[dfSubID.isin(arrSubIDVal),]
      arrLabelsVal = arrLabels[dfSubID.isin(arrSubIDVal),]
      arrClustersVal = arrClusters[dfSubID.isin(arrSubIDVal),]
      arrImagesTrain = arrImagesIncludedSites[dfSubID.isin(arrSubIDTrain),]
      arrLabelsTrain = arrLabels[dfSubID.isin(arrSubIDTrain),]
      arrClustersTrain = arrClusters[dfSubID.isin(arrSubIDTrain),]

      print('Val set:', len(arrSubIDVal), 'subjects,', arrImagesVal.shape[0], 'images',
            f'{arrLabelsVal.mean()*100:.03f}% AD')
      print('Train set:', len(arrSubIDTrain), 'subjects,', arrImagesTrain.shape[0], 'images',
            f'{arrLabelsTrain.mean()*100:.03f}% AD')

      # Unseen sites 
      arrLabelsUnseen = (dfImageInfoExcludedSites['DX_Scan'].values == 'Dementia').astype(np.float32)
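      # Unseen sites belong to none of the one-hot-encoded training sites, so
      # their cluster design matrix is all zeros.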
      arrClustersUnseen = np.zeros((arrImagesExcludedSites.shape[0], arrSiteOrder.shape[0]), dtype=np.float32)
      print('Unseen sites:', len(dfImageInfoExcludedSites['RID'].unique()), 'subjects,', 
            arrImagesExcludedSites.shape[0], 'images',
            f'{arrLabelsUnseen.mean()*100:.03f}% AD')
      
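      # Save each partition to an .npz file in a split-specific subdirectory.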
      strSplitDir = os.path.join(strOutDir, f'split{iSplit:02d}')
      os.makedirs(strSplitDir, exist_ok=True)
      
      np.savez(os.path.join(strSplitDir, 'data_test.npz'), images=arrImagesTest, label=arrLabelsTest, 
            cluster=arrClustersTest, siteorder=arrSiteOrder)
      np.savez(os.path.join(strSplitDir, 'data_val.npz'), images=arrImagesVal, label=arrLabelsVal, 
            cluster=arrClustersVal, siteorder=arrSiteOrder)
      np.savez(os.path.join(strSplitDir, 'data_train.npz'), images=arrImagesTrain, label=arrLabelsTrain, 
            cluster=arrClustersTrain, siteorder=arrSiteOrder)
      np.savez(os.path.join(strSplitDir, 'data_unseen.npz'), images=arrImagesExcludedSites,
            label=arrLabelsUnseen, cluster=arrClustersUnseen, siteorder=arrSiteOrder)