# ARMED-MixedEffectsDL / melanoma_aec / infer_design_matrix.py
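'''
Infer the cluster membership design matrix for data from unseen batches.
Train an image classifier (ImageClassifier) to predict the design matrix (the
'cluster' array) from the training images, validating on the validation images.
Then predict on the images of the unseen batch data and save the predictions as
the unseen set's design matrix in a new .npz file.

Note: the `load_latents` helper (for reading latent representations produced by
the random effects / domain enhancing autoencoder) is defined but not used here;
the classifier operates directly on the images.

'''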

import numpy as np
import pandas as pd
from armed.misc import expand_data_path, expand_results_path

from armed.tfutils import set_gpu
from armed.models.autoencoder_classifier import ImageClassifier
import tensorflow as tf
import tensorflow.keras.layers as tkl

def load_metadata(path):
    """Read a metadata CSV whose location is given relative to the data directory.

    The path is resolved with ``expand_data_path``; the CSV's first column
    becomes the returned DataFrame's index.
    """
    return pd.read_csv(expand_data_path(path), index_col=0)

def load_latents(path):
    """Load a pickled object whose location is given relative to the results directory.

    The path is resolved with ``expand_results_path`` and read with
    ``pd.read_pickle``, which unpickles the file; only use it on trusted files.
    """
    return pd.read_pickle(expand_results_path(path))

# Load data
# Per-split metadata tables (first CSV column is the index). Only the training
# metadata is used below. NOTE(review): the val/test/unseen metadata are loaded
# but not referenced in this script; confirm whether they are still needed.
dfMetadataTrain = load_metadata('melanoma/allpdx_selecteddates/data_train.csv')
dfMetadataVal = load_metadata('melanoma/allpdx_selecteddates/data_val.csv')
dfMetadataTest = load_metadata('melanoma/allpdx_selecteddates/data_test.csv')
dfMetadataUnseen = load_metadata('melanoma/allpdx_selecteddates/data_unseen.csv')

# Read every array of the unseen archive into a plain (mutable) dict, then
# close the archive. Previously the NpzFile returned by np.load was never
# closed, which leaked its file handle.
strUnseenDataPath = expand_data_path('melanoma/allpdx_selecteddates/data_unseen.npz')
with np.load(strUnseenDataPath) as npzUnseen:
    dictUnseen = dict(npzUnseen)

# Save new dataset with inferred Z here
strUnseenDataOutPath = expand_data_path('melanoma/allpdx_selecteddates/data_unseen_inferred_z.npz')

# Mapping of cluster design matrix columns to dates. Series.unique() returns
# dates in order of first appearance, so column i matches arrClasses[i] only if
# the stored 'cluster' arrays were built with the same ordering (TODO confirm).
arrClasses = dfMetadataTrain['date'].unique()
dictClassToInt = {k: v for v, k in enumerate(arrClasses)}

strTrainDataPath = expand_data_path('melanoma/allpdx_selecteddates/data_train.npz')
strValDataPath = expand_data_path('melanoma/allpdx_selecteddates/data_val.npz')
strTestDataPath = expand_data_path('melanoma/allpdx_selecteddates/data_test.npz')

# These stay as lazily-read NpzFile objects (arrays are read on key access)
# because the training code further down indexes them by key.
dictTrain = np.load(strTrainDataPath)
dictVal = np.load(strValDataPath)
dictTest = np.load(strTestDataPath)

# Classifier from a single-channel 256x256 image to one output per training
# cluster (date).
layer_in = tkl.Input((256, 256, 1))
layer_out = ImageClassifier(n_clusters=arrClasses.shape[0])(layer_in)
model = tf.keras.Model(layer_in, layer_out)
# Use `learning_rate`, not the deprecated `lr` alias. Newer TF 2.x optimizers
# drop `lr` with only a warning, so training would silently run at the default
# rate instead of 1e-4. Keras 3 rejects `lr` outright. Keras 3 also requires
# `metrics` to be a list (or dict), not a bare string.
model.compile(optimizer=tf.keras.optimizers.Nadam(learning_rate=0.0001),
              loss='categorical_crossentropy', metrics=['accuracy'])
# Stop after 3 epochs without improvement in validation accuracy, and restore
# the weights from the best epoch.
lsCallbacks = [tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=3, verbose=1,
                                                restore_best_weights=True)]

model.fit(dictTrain['images'], dictTrain['cluster'],
          validation_data=(dictVal['images'], dictVal['cluster']),
          batch_size=32,
          epochs=10,
          callbacks=lsCallbacks)

# Use the raw model outputs as the unseen set's design matrix. These are not
# thresholded to one-hot; given the categorical_crossentropy loss they are
# presumably per-cluster probabilities (NOTE(review): confirm downstream code
# expects soft assignments). All other unseen arrays are written back unchanged.
arrZUnseen = model.predict(dictUnseen['images'])
dictUnseen['cluster'] = arrZUnseen
np.savez(strUnseenDataOutPath, **dictUnseen)