# nsclc_bimodal/merge.py
import scipy.io
import numpy as np

from keras.models import Model
from keras.layers import Input, Dense, Activation, concatenate
from keras.regularizers import l2
from keras.utils import to_categorical
from keras.optimizers import Nadam
from keras.callbacks import ModelCheckpoint, EarlyStopping, Callback

from sklearn.metrics import auc, roc_curve

# use the identical train/test data as in the Scientific Reports paper
# link: https://www.nature.com/articles/s41598-020-61588-w
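# Data_all.mat is expected to provide the feature matrices (train_x, test_x),
# the clinical feature matrices (clinical_train_x, clinical_test_x) and the
# class labels (train_y, test_y).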
data = scipy.io.loadmat('Data_all.mat')
train_x = data['train_x'].astype('float32')
train_c = data['clinical_train_x'].astype('float32')
train_y = data['train_y'].astype(int)
test_x  = data['test_x'].astype('float32')
test_c  = data['clinical_test_x'].astype('float32')
test_y  = data['test_y'].astype(int)

n_classes = 2
# shift the 1-based MATLAB labels to 0/1, then one-hot encode
train_y = train_y - 1
test_y = test_y - 1
y_train = to_categorical(train_y, n_classes)
y_test = to_categorical(test_y, n_classes)
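
# Bimodal architecture: two frozen, pre-trained branches -- one over the
# train_x features and one over the clinical (train_c) features -- whose
# softmax outputs are concatenated and fed to a trainable fusion head.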

# left branch: frozen sub-network over the train_x features
x_input = Input(shape=(np.shape(train_x)[1],), dtype='float32')
left_branch = Dense(40, kernel_regularizer=l2(0.001), name='pre_x_input', trainable=False)(x_input)
left_branch = Activation('relu')(left_branch)
left_branch = Dense(40, kernel_regularizer=l2(0.001), name='pre_x_h1', trainable=False)(left_branch)
left_branch = Activation('relu')(left_branch)
left_branch = Dense(40, kernel_regularizer=l2(0.001), name='pre_x_h2', trainable=False)(left_branch)
left_branch = Activation('relu')(left_branch)
left_branch = Dense(40, kernel_regularizer=l2(0.001), name='pre_x_h3', trainable=False)(left_branch)
left_branch = Activation('relu')(left_branch)
left_branch = Dense(n_classes, name='pre_x_out', trainable=False)(left_branch)
left_branch = Activation('softmax')(left_branch)

# right branch: frozen sub-network over the clinical features
c_input = Input(shape=(np.shape(train_c)[1],), dtype='float32')
right_branch = Dense(18, kernel_regularizer=l2(0.001), name='pre_c_input', trainable=False)(c_input)
right_branch = Activation('relu')(right_branch)
right_branch = Dense(18, kernel_regularizer=l2(0.001), name='pre_c_h1', trainable=False)(right_branch)
right_branch = Activation('relu')(right_branch)
right_branch = Dense(18, kernel_regularizer=l2(0.001), name='pre_c_h2', trainable=False)(right_branch)
right_branch = Activation('relu')(right_branch)
right_branch = Dense(18, kernel_regularizer=l2(0.001), name='pre_c_h3', trainable=False)(right_branch)
right_branch = Activation('relu')(right_branch)
right_branch = Dense(n_classes, name='pre_c_out', trainable=False)(right_branch)
right_branch = Activation('softmax')(right_branch)

# trainable fusion head over the concatenated branch outputs
merged = concatenate([left_branch, right_branch], axis=-1)
merged = Dense(58, kernel_regularizer=l2(0.001), name='merge_h1')(merged)
merged = Activation('relu')(merged)
merged = Dense(32, kernel_regularizer=l2(0.001), name='merge_h2')(merged)
merged = Activation('relu')(merged)
merged = Dense(32, kernel_regularizer=l2(0.001), name='merge_h3')(merged)
merged = Activation('relu')(merged)
merged = Dense(n_classes, activation='softmax')(merged)
merged_model = Model(inputs=[x_input, c_input], outputs=merged)
merged_model.summary()

# training-time callbacks and optimiser; the best merged model (lowest
# validation loss) is checkpointed to disk and reloaded further down
ES = EarlyStopping(monitor='val_loss', patience=30, verbose=0, mode='min')
nadam = Nadam(lr=0.006, beta_1=0.9, beta_2=0.999, epsilon=1e-08, schedule_decay=0.004)
checkpointer = ModelCheckpoint(
    filepath="../../model/bimodal/nsclc_weights_merge.hdf5",
    verbose=1,
    save_best_only=True,
    monitor="val_loss",
    mode="min")
merged_model.compile(
    loss='categorical_crossentropy',
    optimizer=nadam,
    metrics=['accuracy']
)

# initialise the frozen branches from the pre-trained unimodal checkpoints;
# by_name=True matches layers on the 'pre_x_*' / 'pre_c_*' names set above
merged_model.load_weights('../../model/bimodal/nsclc_weights_x.hdf5', by_name=True)
merged_model.load_weights('../../model/bimodal/nsclc_weights_c.hdf5', by_name=True)

loss_test = []

# callback that records the test-set loss after every epoch
# (only active if it is passed to a fit call)
class TestCallback(Callback):
    def __init__(self, test_data):
        super(TestCallback, self).__init__()
        self.test_data = test_data

    def on_epoch_end(self, epoch, logs=None):
        x, c, y = self.test_data
        loss, acc = self.model.evaluate([x, c], y, verbose=0)
        loss_test.append(loss)
        # print('\nTesting loss: {}, acc: {}\n'.format(loss, acc))


batch_size = 20
n_epochs = 100
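
# No fit call appears in this file; a minimal sketch of how the callbacks and
# hyper-parameters above could be wired together (the 25% validation split
# mirrors the split used for threshold selection below) would be:
#
# merged_model.fit(
#     [train_x, train_c], y_train,
#     batch_size=batch_size,
#     epochs=n_epochs,
#     validation_split=0.25,
#     callbacks=[ES, checkpointer, TestCallback((test_x, test_c, y_test))],
#     verbose=1)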
merged_model.load_weights("../../model/bimodal/nsclc_weights_merge.hdf5")
# loss,accuracy = merged_model.evaluate([test_x,test_c],y_test)
# print('loss:', loss)
# print('accuracy:', accuracy)

# load the independent cohort; note that the prediction below still runs on
# test_x / test_c loaded from Data_all.mat
data = np.load('../../data/nsclc/indep.npz', allow_pickle=True)
y_pred = merged_model.predict([test_x, test_c])
np.savez_compressed('../../model/bimodal/nsclc_bimodal_indep_logits.npz', pred=y_pred)
fpr, tpr, thr = roc_curve(y_test[:, 0], y_pred[:, 0])
auc_test = auc(fpr, tpr)
print(auc_test)

data = np.load('../../data/nsclc/nsclc_3.npz', allow_pickle=True)  # same test set
y_pred = merged_model.predict([test_x, test_c])

# hold out the last 25% of the training set to pick an operating threshold
valid_num = int(np.shape(train_x)[0] * 0.25)
x_valid = train_x[-valid_num:, :]
c_valid = train_c[-valid_num:, :]
y_valid = y_train[-valid_num:, :]  # one-hot labels, same ROC convention as the test-set curve above
y_pred_valid = merged_model.predict([x_valid, c_valid])
print(np.shape(x_valid))
print(np.shape(y_valid))
fpr, tpr, thr = roc_curve(y_valid[:, 0], y_pred_valid[:, 0])
# operating point chosen by maximising Youden's J statistic (tpr - fpr)
thr_best = thr[np.argmax(np.subtract(tpr, fpr))]
np.savez_compressed('../../model/bimodal/nsclc_bimodal_logits.npz', pred=y_pred, thr=thr_best)
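
# Hypothetical downstream usage (illustrative only): load the saved file and
# flag samples whose column-0 score reaches the stored threshold.
# saved = np.load('../../model/bimodal/nsclc_bimodal_logits.npz')
# calls = (saved['pred'][:, 0] >= saved['thr']).astype(int)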