# ARMED-MixedEffectsDL / adni_t1w / main.py
'''
Main script for training and evaluating classifiers. Performs cross-validation 
using pregenerated data splits and prints the mean and 95% CI of model performance.

Example:
python main.py --model_type mixedeffects --data_dir /path/to/data/splits

See python main.py --help for arguments. 
'''
import os
import glob
import argparse
import numpy as np
import pandas as pd

import tensorflow as tf
from armed.models.cnn_classifier import ImageClassifier, ClusterInputImageClassifier, \
    DomainAdversarialImageClassifier, RandomEffectsClassifier, MixedEffectsClassifier
from armed.misc import expand_data_path, expand_results_path, make_random_onehot
from armed.metrics import classification_metrics

def get_model(model_type: str, n_clusters: int = None):
    '''Construct and compile the requested classifier.

    Args:
        model_type (str): one of 'conventional', 'clusterinput',
            'adversarial', 'randomeffects', or 'mixedeffects'
        n_clusters (int): number of clusters (sites). Used by the
            'adversarial' and 'mixedeffects' models; ignored otherwise.

    Returns:
        compiled tf.keras.Model

    Raises:
        ValueError: if model_type is not recognized
    '''
    # Fix the seed so weight initialization is reproducible across runs
    tf.random.set_seed(2343)

    if model_type == 'conventional':
        model = ImageClassifier()
        # NOTE: `lr=` is a deprecated alias removed in newer Keras releases;
        # use `learning_rate=` (supported since TF 2.0)
        model.compile(optimizer=tf.keras.optimizers.Nadam(learning_rate=0.0001),
                      loss=tf.keras.losses.BinaryCrossentropy(),
                      metrics=[tf.keras.metrics.AUC(name='auroc')])

    elif model_type == 'clusterinput':
        model = ClusterInputImageClassifier()
        model.compile(optimizer=tf.keras.optimizers.Nadam(learning_rate=0.0001),
                      loss=tf.keras.losses.BinaryCrossentropy(),
                      metrics=[tf.keras.metrics.AUC(name='auroc')])

    elif model_type == 'adversarial':
        # Adversarial model trains two sub-networks with separate optimizers
        model = DomainAdversarialImageClassifier(n_clusters=n_clusters)
        model.compile(opt_adversary=tf.keras.optimizers.Nadam(learning_rate=0.0001),
                      opt_classifier=tf.keras.optimizers.Nadam(learning_rate=0.0001),
                      metric_classifier=tf.keras.metrics.AUC(name='auroc'))

    elif model_type == 'randomeffects':
        model = RandomEffectsClassifier(intercept_post_init_scale=0.1,
                                        intercept_prior_scale=0.25)
        model.compile(optimizer=tf.keras.optimizers.Nadam(learning_rate=0.0001),
                      loss=tf.keras.losses.BinaryCrossentropy(),
                      metrics=[tf.keras.metrics.AUC(name='auroc')])

    elif model_type == 'mixedeffects':
        model = MixedEffectsClassifier(n_clusters=n_clusters,
                                       intercept_post_init_scale=0.1,
                                       intercept_prior_scale=0.25)
        model.compile(opt_adversary=tf.keras.optimizers.Nadam(learning_rate=0.0001),
                      opt_classifier=tf.keras.optimizers.Nadam(learning_rate=0.0001),
                      metric_classifier=tf.keras.metrics.AUC(name='auroc'))
    else:
        raise ValueError(model_type, 'not recognized')
    return model

def train_evaluate(split_dir: str, model_type: str, epochs: int = 20, weights_path: str = None,
                   verbose: int = 0, randomize_z: bool = False):
    '''Train a model on one pregenerated data split and evaluate on all partitions.

    Args:
        split_dir (str): directory containing data_train.npz, data_val.npz,
            data_test.npz, and data_unseen.npz
        model_type (str): model type, see get_model()
        epochs (int): maximum training epochs. Defaults to 20.
        weights_path (str): if given, save trained weights to this path
        verbose (int): Keras fit() verbosity. Defaults to 0.
        randomize_z (bool): replace the cluster (site) one-hot input with a
            random one-hot for test and unseen-site data (RE ablation test).
            Not applicable to the 'conventional' model.

    Returns:
        pd.DataFrame: one row of classification metrics per partition
            (Train, Val, Test, Unseen)

    Raises:
        ValueError: if randomize_z is requested for the 'conventional' model,
            which takes no cluster input
    '''
    # Guard: the conventional model's inputs are a bare image array, so the
    # tuple unpacking below (test_in[0]) would silently index the first image
    if randomize_z and model_type == 'conventional':
        raise ValueError('randomize_z is not applicable to the conventional model')

    dictDataTrain = np.load(os.path.join(split_dir, 'data_train.npz'))
    dictDataVal = np.load(os.path.join(split_dir, 'data_val.npz'))
    dictDataTest = np.load(os.path.join(split_dir, 'data_test.npz'))
    dictDataUnseen = np.load(os.path.join(split_dir, 'data_unseen.npz'))

    # Weight each class by 1 - class frequency (label mean == frequency of class 1)
    dictClassWeights = {0.: dictDataTrain['label'].mean(),
                        1.: 1 - dictDataTrain['label'].mean()}

    if model_type == 'conventional':
        # Conventional model takes images only
        train_in = dictDataTrain['images']
        val_in = dictDataVal['images']
        test_in = dictDataTest['images']
        unseen_in = dictDataUnseen['images']
    else:
        # All other models take (images, cluster one-hot) pairs
        train_in = (dictDataTrain['images'], dictDataTrain['cluster'])
        val_in = (dictDataVal['images'], dictDataVal['cluster'])
        test_in = (dictDataTest['images'], dictDataTest['cluster'])
        unseen_in = (dictDataUnseen['images'], dictDataUnseen['cluster'])

    nClusters = dictDataTrain['siteorder'].shape[0]
    model = get_model(model_type, n_clusters=nClusters)

    lsCallbacks = [tf.keras.callbacks.EarlyStopping(monitor='val_auroc', mode='max',
                                                    patience=5,
                                                    restore_best_weights=True)]

    model.fit(x=train_in,
              y=dictDataTrain['label'],
              batch_size=32,
              epochs=epochs,
              verbose=verbose,
              class_weight=dictClassWeights,
              callbacks=lsCallbacks,
              validation_data=(val_in, dictDataVal['label']))

    arrPredTrain = model.predict(train_in, verbose=0)
    arrPredVal = model.predict(val_in, verbose=0)

    # Make random Z inputs for test and unseen site data (RE ablation)
    if randomize_z:
        arrImagesTest = test_in[0]
        arrZTest = make_random_onehot(arrImagesTest.shape[0], nClusters)
        test_in = (arrImagesTest, arrZTest)

        arrImagesUnseen = unseen_in[0]
        arrZUnseen = make_random_onehot(arrImagesUnseen.shape[0], nClusters)
        unseen_in = (arrImagesUnseen, arrZUnseen)

    arrPredTest = model.predict(test_in, verbose=0)
    arrPredUnseen = model.predict(unseen_in, verbose=0)

    # Youden-optimal threshold is determined on Train only; fixed operating
    # points at 0.7 sensitivity/specificity are reported for all partitions
    dictMetricsTrain, youden = classification_metrics(dictDataTrain['label'], arrPredTrain,
                                                      fixed_sens=0.7, fixed_spec=0.7)
    dictMetricsVal, _ = classification_metrics(dictDataVal['label'], arrPredVal,
                                               fixed_sens=0.7, fixed_spec=0.7)
    dictMetricsTest, _ = classification_metrics(dictDataTest['label'], arrPredTest,
                                                fixed_sens=0.7, fixed_spec=0.7)
    dictMetricsUnseen, _ = classification_metrics(dictDataUnseen['label'], arrPredUnseen,
                                                  fixed_sens=0.7, fixed_spec=0.7)

    lsMetrics = [dictMetricsTrain, dictMetricsVal, dictMetricsTest, dictMetricsUnseen]

    dfMetrics = pd.DataFrame(lsMetrics)
    dfMetrics['partition'] = ['Train', 'Val', 'Test', 'Unseen']

    if weights_path:
        model.save_weights(weights_path)

    return dfMetrics


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, 
                        default='ADNI23_sMRI/right_hippocampus_slices_2pctnorm/coronal_MNI-6_numpy/12sites',
                        help='Path to directory containing data splits.')
    parser.add_argument('--out_dir', type=str, required=True, help='Output directory')
    parser.add_argument('--model_type', type=str, choices=['conventional', 'clusterinput', 'adversarial', 
                                                           'mixedeffects', 'randomeffects'],
                        required=True, help='Model type.')
    parser.add_argument('--epochs', type=int, default=20, help='Training duration. Defaults to 20')
    parser.add_argument('--randomize_sites', action='store_true', help='Use a randomized site membership'
                        ' input on test and unseen site data (RE ablation test).')
    parser.add_argument('--gpu', type=int, help='GPU to use. Defaults to all.')
    parser.add_argument('--verbose', type=int, default=0, help='Show training progress.')

    args = parser.parse_args()

    # Compare against None explicitly: `if args.gpu:` is falsy for GPU 0,
    # silently ignoring `--gpu 0`
    if args.gpu is not None:
        from armed.tfutils import set_gpu
        set_gpu(args.gpu)

    strDataDir = expand_data_path(args.data_dir)
    lsSplitDirs = glob.glob(os.path.join(strDataDir, 'split*'))
    lsSplitDirs.sort()

    strOutDir = expand_results_path(args.out_dir, make=True)

    # Train and evaluate on each pregenerated split
    lsAllMetrics = []
    for strSplitDir in lsSplitDirs:
        strSplitName = os.path.basename(strSplitDir)
        print(strSplitName)
        strOutPath = os.path.join(strOutDir, strSplitName + '_weights.h5')
        df = train_evaluate(strSplitDir, args.model_type, weights_path=strOutPath, epochs=args.epochs, 
                            verbose=args.verbose, randomize_z=args.randomize_sites)
        df['split'] = strSplitName
        lsAllMetrics += [df]

    dfAllMetrics = pd.concat(lsAllMetrics)
    dfAllMetrics.to_csv(os.path.join(strOutDir, 'metrics.csv'))

    # Mean and 95% CI (normal approximation) across splits per partition.
    # numeric_only=True: the frame also holds the string columns 'partition'
    # and 'split'; without it, pandas >= 2.0 raises on .mean()/.std()
    dfMean = dfAllMetrics.groupby('partition').mean(numeric_only=True)
    dfSE = dfAllMetrics.groupby('partition').std(numeric_only=True) / (len(lsSplitDirs) ** 0.5)
    df95CILow = dfMean - dfSE * 1.96
    df95CIHi = dfMean + dfSE * 1.96
    dfMeanCI = pd.concat({'Mean': dfMean, '95CI Low': df95CILow, '95CI Hi': df95CIHi}, axis=1)
    print(dfMeanCI.loc[['Train', 'Val', 'Test', 'Unseen']].to_string())