# ARMED-MixedEffectsDL / melanoma_aec / main.py
'''
Main script for training and evaluating an autoencoder-classifier (AEC). 

Required arguments include the model type and output directory. 

To train a model:
python main.py --model_type conventional --output_dir /path/to/output/location

To evaluate the model after training:
python main.py --model_type conventional --output_dir /path/to/output/location --do_test 
    --load_weights_epoch <epoch index to load weights from>

See python main.py --help for all arguments. 
'''

import os
import argparse
import json
import glob
import numpy as np
import pandas as pd

from armed.misc import expand_data_path, expand_results_path, make_random_onehot

from sklearn.metrics import davies_bouldin_score, calinski_harabasz_score

def _shuffle_data(data_dict):
    """Return a new dict with every array reordered by one shared permutation.

    The permutation covers the first axis of ``data_dict['images']`` and is
    drawn right after reseeding NumPy's global RNG with 64, so the resulting
    order is reproducible. Side effect: the global NumPy random state is
    reset by this call.
    """
    n_samples = data_dict['images'].shape[0]
    order = np.arange(n_samples)
    np.random.seed(64)
    np.random.shuffle(order)

    shuffled = {}
    for key, values in data_dict.items():
        # Same index permutation for every entry keeps samples aligned
        shuffled[key] = values[order,]
    return shuffled

def _random_samples(data_dict, metadata=None, n=100):
    """Pick ``n`` random samples from every array in ``data_dict``.

    Indices are drawn with ``np.random.choice`` (which samples with
    replacement by default) over the first axis of ``data_dict['images']``,
    right after reseeding NumPy's global RNG with 64. The same indices are
    applied to every entry and, if given, to the rows of ``metadata``.

    Returns the subsampled dict alone when ``metadata`` is None, otherwise a
    ``(dict, metadata_subset)`` tuple.
    """
    n_total = data_dict['images'].shape[0]
    candidates = np.arange(n_total)
    np.random.seed(64)
    picked = np.random.choice(candidates, size=n)

    subset = {}
    for key, values in data_dict.items():
        subset[key] = values[picked,]

    if metadata is None:
        return subset
    return subset, metadata.iloc[picked,]

def _get_model(model_type, n_clusters=10):
    """Build and compile a model with some preset hyperparameters.

    Args:
        model_type (str): one of 'conventional', 'adversarial',
            'mixedeffects' or 'randomeffects'.
        n_clusters (int, optional): number of clusters, forwarded to every
            model except 'conventional'. Defaults to 10.

    Returns:
        The compiled model.

    Raises:
        ValueError: if ``model_type`` is not one of the supported names.
    """
    # Validate before the heavy imports so a typo fails fast with a clear
    # message instead of an UnboundLocalError at the return statement.
    tupModelTypes = ('conventional', 'adversarial', 'mixedeffects', 'randomeffects')
    if model_type not in tupModelTypes:
        raise ValueError(f'Unknown model_type {model_type!r}; expected one of {tupModelTypes}')

    import tensorflow as tf
    from armed.models import autoencoder_classifier

    if model_type == 'conventional':
        model = autoencoder_classifier.BaseAutoencoderClassifier(n_latent_dims=56)

        # Two outputs: reconstruction (MSE) and classification (BCE + AUROC).
        # `learning_rate` replaces the deprecated `lr` keyword.
        model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                      loss=[tf.keras.losses.MeanSquaredError(name='mse'),
                            tf.keras.losses.BinaryCrossentropy(name='bce')],
                      loss_weights=[1.0, 0.1],
                      metrics=[[],
                               [tf.keras.metrics.AUC(name='auroc')]])

    elif model_type == 'adversarial':
        model = autoencoder_classifier.DomainAdversarialAEC(n_clusters=n_clusters, n_latent_dims=56)

        model.compile(loss_recon=tf.keras.losses.MeanSquaredError(),
                      loss_class=tf.keras.losses.BinaryCrossentropy(),
                      loss_adv=tf.keras.losses.BinaryCrossentropy(),
                      metric_class=tf.keras.metrics.AUC(name='auroc'),
                      metric_adv=tf.keras.metrics.CategoricalAccuracy(name='acc'),
                      opt_autoencoder=tf.keras.optimizers.Adam(learning_rate=0.0001),
                      opt_adversary=tf.keras.optimizers.Adam(learning_rate=0.0001),
                      loss_recon_weight=1.0,
                      loss_class_weight=0.01,
                      loss_gen_weight=0.1)

    elif model_type == 'mixedeffects':
        model = autoencoder_classifier.MixedEffectsAEC(n_clusters=n_clusters, n_latent_dims=56)
        model.compile()

    else:
        # model_type == 'randomeffects' (guaranteed by the validation above)
        model = autoencoder_classifier.DomainEnhancingAutoencoderClassifier(n_clusters=n_clusters, n_latent_dims=56, kl_weight=1e-7)
        model.compile()

    return model

def train_model(model_type: str, 
                data_train: dict, 
                data_val: dict, 
                train_metadata: pd.DataFrame,
                val_metadata: pd.DataFrame,
                output_dir: str,
                epochs: int=10,
                verbose: bool=True,
                ):
    """Train a model.

    Besides fitting the model, this writes a CSV training log and per-epoch
    weight checkpoints into the output directory, and runs per-epoch
    callbacks for example reconstructions, image metrics and latent
    representations.

    Args:
        model_type (str): name of model type
        data_train (dict): training data, should have keys 'images', 'label',
            and 'cluster'. Any mapping works, including the lazily loading
            object returned by np.load on an .npz file.
        data_val (dict): validation data, should have keys 'images', 'label',
            and 'cluster'
        train_metadata (pd.DataFrame): training metadata; its 'date' column
            is used as the grouping for the clustering scores
        val_metadata (pd.DataFrame): validation metadata
        output_dir (str): path to output location
        epochs (int, optional): epochs to train. Defaults to 10.
        verbose (bool, optional): training verbosity. Defaults to True.

    Returns:
        dict: final-epoch values from the training history, plus 'db'
            (Davies-Bouldin score) and 'ch' (Calinski-Harabasz score)
            computed on the standardized training latents.
    """    

    # Imports done inside function so that memory is allocated properly when
    # used with Ray Tune    
    import tensorflow as tf
    import tensorflow.keras.layers as tkl
    from armed.callbacks import aec_callbacks
        
    strOutputDir = expand_results_path(output_dir)

    # Read every array exactly once. An .npz archive opened with np.load
    # loads an entry again on each key access, so repeated lookups would
    # hold several copies of the image array in memory.
    arrTrainX = data_train['images']
    arrTrainY = data_train['label']
    arrTrainZ = data_train['cluster']
    arrValX = data_val['images']
    arrValY = data_val['label']
    arrValZ = data_val['cluster']
    nClusters = arrTrainZ.shape[1]

    model = _get_model(model_type, n_clusters=nClusters)
    if model_type == 'conventional':
        # Conventional model takes images only
        train_in = arrTrainX
        val_in = arrValX
    else:
        # All other models take (images, cluster membership)
        train_in = (arrTrainX, arrTrainZ)
        val_in = (arrValX, arrValZ)
    # Targets are always (reconstruction target, classification label)
    train_out = (arrTrainX, arrTrainY)
    val_out = (arrValX, arrValY)

    # Get a few samples to generate example reconstructions every epoch
    data_sample = _random_samples({'images': arrValX,
                                   'label': arrValY,
                                   'cluster': arrValZ}, n=8)
    arrBatchX = data_sample['images']
    arrBatchZ = data_sample['cluster']
    
    # Callbacks:
    # NOTE(review): these callback factories receive output_dir as given,
    # while the logger and checkpoints below use the expanded path - confirm
    # that the factories expand the path themselves.
    # Create figure with example reconstructions
    recon_images = aec_callbacks.make_recon_figure_callback(arrBatchX, model, output_dir,
                                            clusters=None if model_type == 'conventional' else arrBatchZ,
                                            mixedeffects=model_type == 'mixedeffects')
    # Compute image metrics 
    compute_image_metrics = aec_callbacks.make_image_metrics_callback(model, val_in, val_metadata, output_dir,
                                            output_idx=1 if model_type == 'mixedeffects' else 0)
    
    lsCallbacks = [tf.keras.callbacks.CSVLogger(os.path.join(strOutputDir, 'training_log.csv')),
                   tf.keras.callbacks.LambdaCallback(on_epoch_end=recon_images),
                   tf.keras.callbacks.LambdaCallback(on_epoch_end=compute_image_metrics),
                   tf.keras.callbacks.ModelCheckpoint(os.path.join(strOutputDir, 'epoch{epoch:03d}_weights.h5'),
                                                      save_weights_only=True)]
    
    # Isolate the encoder
    if model_type == 'randomeffects':
        # RE model takes both image and cluster as inputs
        encoder_in = (tkl.Input((256, 256, 1), name='encoder_in_x'),
                      tkl.Input((nClusters,), name='encoder_in_z'))
        encoder_out = model.encoder(encoder_in)
        encoder_data = train_in
    else:
        encoder_in = tkl.Input((256, 256, 1), name='encoder_in')
        encoder_out = model.encoder(encoder_in)
        if isinstance(encoder_out, tuple):
            # If the encoder outputs all layer activations, keep only the latent rep output
            encoder_out = encoder_out[-1]
        encoder_data = arrTrainX
        
    encoder = tf.keras.models.Model(encoder_in, encoder_out, name='standalone_encoder')
    
    # Create callback to save latent representations for training data every epoch
    compute_latents = aec_callbacks.make_compute_latents_callback(encoder, encoder_data,
                                                                  train_metadata, output_dir)
    lsCallbacks += [tf.keras.callbacks.LambdaCallback(on_epoch_end=compute_latents)]
    
    # Train
    history = model.fit(train_in, train_out,
                        epochs=epochs,
                        verbose=verbose,
                        batch_size=16, # changed from 32 to fit on P40
                        validation_data=(val_in, val_out),
                        shuffle=True,
                        callbacks=lsCallbacks)
    # Get final metrics
    dfHistory = pd.DataFrame(history.history)
    dictResults = dfHistory.iloc[-1].to_dict()
    
    # Compute clustering metrics on latents, standardized per dimension
    arrLatents = encoder.predict(encoder_data)
    arrLatents -= arrLatents.mean(axis=0)
    arrStd = arrLatents.std(axis=0)
    # A constant latent dimension has zero std; dividing by it would inject
    # NaN/inf, which the sklearn scores reject. Such a dimension is already
    # all zeros after centering, so dividing by 1 leaves it unchanged.
    arrStd[arrStd == 0] = 1.0
    arrLatents /= arrStd
    db = davies_bouldin_score(arrLatents, train_metadata['date'])
    ch = calinski_harabasz_score(arrLatents, train_metadata['date'])
    
    dictResults.update(db=db, ch=ch)
    
    return dictResults


def test_model(model_type: str, 
               saved_weights: str,
               data: dict,
               randomize_z: bool = False):
    """Evaluate a trained model on the given data.

    The data are shuffled with a fixed seed, a freshly built model is run
    once so its weights exist, the saved weights are loaded, and the model
    is evaluated.

    Args:
        model_type (str): name of model type
        saved_weights (str): path to saved weights in .h5 file
        data (dict): evaluation data, should have keys 'images', 'label',
            and 'cluster'
        randomize_z (bool): replace the cluster membership input with a
            random one-hot matrix of the same shape, as an ablation test.
            Has no effect for the 'conventional' model. Defaults to False.

    Returns:
        dict: model metrics as returned by model.evaluate(return_dict=True)
    """    
    
    # Imports done inside function so that memory is allocated properly when
    # used with Ray Tune    
    from armed.models.autoencoder_classifier import load_weights_base_aec
        
    data = _shuffle_data(data)
    bConventional = model_type == 'conventional'

    model = _get_model(model_type, n_clusters=data['cluster'].shape[1])

    if bConventional:
        # Images are the only input
        data_in = data['images']
    else:
        arrClusters = data['cluster']
        if randomize_z:
            nSamples, nClusters = arrClusters.shape[0], arrClusters.shape[1]
            arrClusters = make_random_onehot(nSamples, nClusters)
        data_in = (data['images'], arrClusters)

    # Targets are the same for every model type
    data_out = (data['images'], data['label'])
    
    # Call model once to instantiate weights
    _ = model.predict(data_in, steps=1, batch_size=32)    

    if bConventional:
        # Workaround for weight loading bug
        load_weights_base_aec(model, saved_weights)
    else:
        model.load_weights(saved_weights)
            
    return model.evaluate(data_in, data_out, 
                          batch_size=32,
                          return_dict=True)

if __name__ == '__main__':
    # Command-line entry point. Without --load_weights_epoch the model is
    # trained first; with --do_test it is then evaluated on the test set and
    # the metrics are written to <output_dir>/test_metrics.json.
    parser = argparse.ArgumentParser()
    parser.add_argument('--training_metadata', type=str, 
                        default='melanoma/allpdx_selecteddates/data_train.csv',
                        help='Path to CSV table containing training image metadata.')
    parser.add_argument('--training_data', type=str, 
                        default='melanoma/allpdx_selecteddates/data_train.npz',
                        help='Path to .npz file containing training images.')
    parser.add_argument('--val_metadata', type=str, 
                        default='melanoma/allpdx_selecteddates/data_val.csv',
                        help='Path to CSV table containing validation image metadata.')
    parser.add_argument('--val_data', type=str, 
                        default='melanoma/allpdx_selecteddates/data_val.npz',
                        help='Path to .npz file containing validation images.')
    parser.add_argument('--test_metadata', type=str, 
                        default='melanoma/allpdx_selecteddates/data_test.csv',
                        help='Path to CSV table containing test image metadata.')
    parser.add_argument('--test_data', type=str, 
                        default='melanoma/allpdx_selecteddates/data_test.npz',
                        help='Path to .npz file containing test images.')
    parser.add_argument('--output_dir', type=str, required=True, help='Output directory.')
    parser.add_argument('--model_type', type=str, choices=['conventional', 'adversarial', 'mixedeffects', 'randomeffects'],
                        required=True, help='Model type.')
    parser.add_argument('--epochs', type=int, default=10, help='Training duration. Defaults to 10')
    parser.add_argument('--do_test', action='store_true', help='Evaluate on test')
    parser.add_argument('--randomize_batch', action='store_true', help='Use a randomized batch membership'
                        ' input when testing (as an ablation test).')
    parser.add_argument('--load_weights_epoch', type=int, default=None, help='If evaluating on test, load weights'
                        ' from this epoch and skip training.')
    parser.add_argument('--verbose', type=int, default=1, help='Show training progress.')
    parser.add_argument('--gpu', type=int, help='GPU to use. Defaults to all.')
    parser.add_argument('--smoketest', action='store_true', 
                        help='For quick testing purposes, use a dataset of 100 samples')

    args = parser.parse_args()

    # Compare against None rather than truthiness: GPU index 0 is a valid
    # selection and must not be treated as "no GPU given".
    if args.gpu is not None:
        # Select GPU to use
        from armed.tfutils import set_gpu
        set_gpu(args.gpu)
        
    strOutputDir = expand_results_path(args.output_dir, make=True)
    
    if args.load_weights_epoch is None:
        # If no weights were selected to load, train model 
        strTrainDataPath = expand_data_path(args.training_data)
        strTrainMetaDataPath = expand_data_path(args.training_metadata)
        strValDataPath = expand_data_path(args.val_data)
        strValMetaDataPath = expand_data_path(args.val_metadata)
        dfTrainMetadata = pd.read_csv(strTrainMetaDataPath, index_col=0)
        dfValMetadata = pd.read_csv(strValMetaDataPath, index_col=0)

        dictDataTrain = np.load(strTrainDataPath)
        dictDataVal = np.load(strValDataPath)
        
        if args.smoketest:
            dictDataTrain, dfTrainMetadata = _random_samples(dictDataTrain, dfTrainMetadata, n=100)
            dictDataVal, dfValMetadata = _random_samples(dictDataVal, dfValMetadata, n=100)
    
        dictMetrics = train_model(model_type=args.model_type, 
                                  data_train=dictDataTrain, 
                                  data_val=dictDataVal,
                                  train_metadata=dfTrainMetadata, 
                                  val_metadata=dfValMetadata,
                                  output_dir=strOutputDir,
                                  epochs=args.epochs,
                                  verbose=args.verbose == 1)
        
        print(dictMetrics)
        
    if args.do_test:
        strTestDataPath = expand_data_path(args.test_data)
        strTestMetaDataPath = expand_data_path(args.test_metadata)
        dictDataTest = np.load(strTestDataPath)
        dfTestMetadata = pd.read_csv(strTestMetaDataPath, index_col=0)
        
        if args.smoketest:
            dictDataTest, dfTestMetadata = _random_samples(dictDataTest, dfTestMetadata, n=100)

        if args.load_weights_epoch is not None:
            strSavedWeightsPath = os.path.join(strOutputDir, f'epoch{args.load_weights_epoch:03d}_weights.h5')
            # Explicit check instead of assert, which is stripped under -O
            if not os.path.exists(strSavedWeightsPath):
                raise FileNotFoundError(f'Saved weights not found: {strSavedWeightsPath}')
        else:
            # Grab the last epoch weights (zero-padded epoch numbers sort
            # lexicographically in epoch order)
            lsWeights = sorted(glob.glob(os.path.join(strOutputDir, '*weights.h5')))
            if not lsWeights:
                raise FileNotFoundError(f'No saved weights (*weights.h5) found in {strOutputDir}')
            strSavedWeightsPath = lsWeights[-1]
            
        print('Loading weights from', strSavedWeightsPath, flush=True)

        dictMetrics = test_model(model_type=args.model_type,
                                 saved_weights=strSavedWeightsPath,
                                 data=dictDataTest,
                                 randomize_z=args.randomize_batch)
        
        with open(os.path.join(strOutputDir, 'test_metrics.json'), 'w') as f:
            json.dump(dictMetrics, f, indent=4)