# ARMED-MixedEffectsDL / adni_t1w / infer_z_unseen_sites.py
'''
Use a CNN, trained to predict site from an image, to infer the cluster
membership design matrix Z for each image from an unseen site.
'''

import os
import numpy as np
from armed.tfutils import set_gpu

set_gpu(1, 0.5)

import tensorflow as tf
import tensorflow.keras.layers as tkl
from armed.misc import expand_data_path

# Directory holding the pre-split .npz archives for the 12-site dataset.
strDataDir = expand_data_path('ADNI23_sMRI/right_hippocampus_slices_2pctnorm/coronal_MNI-6_numpy/12sites')

# Load the train, validation and unseen-site splits (in that order). The test
# split ('data_test.npz') is intentionally not loaded by this script.
dictDataTrain, dictDataVal, dictDataUnseen = (
    np.load(os.path.join(strDataDir, strFile))
    for strFile in ('data_train.npz', 'data_val.npz', 'data_unseen.npz')
)

# One cluster per entry of the training archive's 'siteorder' array.
nClusters = dictDataTrain['siteorder'].shape[0]

# Simple CNN classifier: six 'same'-padded convolution stages (each followed by
# batch norm, PReLU and max pooling), one final 'valid' convolution stage, and
# a dense head ending in a softmax over the nClusters sites.
x = tkl.Input(dictDataTrain['images'].shape[1:])  # per-image shape (sample axis dropped)

# (number of filters, whether dropout follows the convolution) per pooled stage
lsConvStages = [(64, False), (64, True),
                (128, False), (128, True),
                (256, False), (256, True)]

h = x
for nFilters, bDropout in lsConvStages:
    h = tkl.Conv2D(nFilters, 3, padding='same')(h)
    if bDropout:
        h = tkl.Dropout(0.5)(h)
    h = tkl.BatchNormalization()(h)
    h = tkl.PReLU()(h)
    h = tkl.MaxPool2D()(h)

# Last convolution stage: unpadded, and not followed by pooling.
h = tkl.Conv2D(512, 3, padding='valid')(h)
h = tkl.BatchNormalization()(h)
h = tkl.PReLU()(h)

# Dense head producing per-site probabilities.
h = tkl.Flatten()(h)
h = tkl.Dense(512)(h)
h = tkl.PReLU()(h)
y = tkl.Dense(nClusters, activation='softmax')(h)

model = tf.keras.Model(x, y)

# Compile the site classifier. `learning_rate` replaces the deprecated `lr`
# alias, which newer Keras optimizers no longer accept.
# NOTE(review): CategoricalCrossentropy expects one-hot (or probability)
# targets; this assumes the 'cluster' arrays are one-hot with nClusters
# columns -- confirm against the data preparation script.
model.compile(optimizer=tf.keras.optimizers.Nadam(learning_rate=0.0001),
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=['accuracy'])

# Stop once validation accuracy has not improved for 5 epochs and roll the
# model back to the weights of the best epoch.
lsCallbacks = [tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', mode='max',
                                                patience=5,
                                                restore_best_weights=True)]

# Train to predict site (cluster) membership from the image.
model.fit(x=dictDataTrain['images'],
          y=dictDataTrain['cluster'],
          batch_size=32,
          epochs=20,
          verbose=1,
          callbacks=lsCallbacks,
          validation_data=(dictDataVal['images'], dictDataVal['cluster']))

# Softmax outputs for the unseen-site images: one row per image, one column
# per known site. These serve as the inferred cluster membership matrix Z.
arrZInferred = model.predict(dictDataUnseen['images'])