# src/nsclc_bimodal/merge_cv_test.py
# Ensemble evaluation of the bimodal (expression + clinical) NSCLC classifier.
from keras.models import Sequential
import scipy.io
import numpy as np
from keras.regularizers import l2#, activity_l2
from keras.layers import Dense, Activation
from keras.layers import Dropout, Flatten, Dense
from keras.utils.np_utils import to_categorical
from keras.optimizers import SGD, Nadam
import random
import numpy as np
import matplotlib.pyplot as plt
#from keras.utils.visualize_util import plot
#from keras.engine.topology import merge
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Activation
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers import Dropout, Flatten, Dense
from keras.utils.np_utils import to_categorical
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.models import model_from_json
#from keras.layers import Merge,merge
from keras.layers import concatenate
from keras.models import Input,Model

from sklearn.metrics import roc_auc_score,auc,roc_curve
from keras.callbacks.callbacks import Callback

import argparse
import tensorflow as tf
from sklearn import preprocessing
from sklearn.metrics import roc_auc_score,auc,roc_curve,f1_score,accuracy_score
from lifelines.utils import concordance_index
from sklearn.metrics import accuracy_score,average_precision_score,precision_recall_curve

def calc_c_index_benchmark(o_test, y_prob):
    """Concordance index (c-index) of predicted risk against observed times.

    The raw scores are squashed with tanh and mean-centered before being
    negated; both transforms are strictly monotone, so the resulting
    ranking — and hence the c-index — is that of ``-y_prob`` itself.
    """
    event_times = np.asarray(o_test, dtype=float)[:, None]
    squashed = np.tanh(y_prob)
    centered = squashed - squashed.mean()
    # Higher probability is taken to mean higher risk, i.e. shorter time,
    # hence the sign flip on the score passed to concordance_index.
    return concordance_index(event_times, -centered)
def parse_args(argv=None):
    """Parse command-line options for the ensemble evaluation script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            which makes argparse read ``sys.argv[1:]`` — identical to the
            original call-site behavior, but allows injection for testing.

    Returns:
        argparse.Namespace with:
            seed (int): random seed, default 0.
            i (int): ensemble-member index, default 0.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--i', default=0, type=int, help='num_of_ensemble')
    return parser.parse_args(argv)

def create_model(is_x, is_c, is_b, i):
    """Build the expression branch, the clinical branch, or the merged model.

    Exactly one unimodal flag is expected to be True at a time:
      * is_x True  -> return the trainable expression-only model.
      * is_c True  -> return the trainable clinical-only model (the frozen
        expression branch is loaded from disk first).
      * both False -> load frozen weights for both branches and return the
        merged two-input model whose fusion head is trainable iff is_b.

    Args:
        is_x: expression branch trainable / returned standalone.
        is_c: clinical branch trainable / returned standalone.
        is_b: fusion (merge) layers trainable.
        i: ensemble-member index used to locate saved branch weights.

    Returns:
        A keras Model — one of the two single-modality classifiers or the
        merged bimodal classifier, all ending in a 2-way softmax.

    Note:
        The original code used the Keras 1 kwarg ``W_regularizer``, which
        raises TypeError under the Keras 2.x API this file otherwise uses
        (functional Input/Model, ``concatenate``); it is renamed to
        ``kernel_regularizer`` here. The redundant ``input_dim`` kwargs are
        dropped — the Input layers already fix the shapes.
    """
    # --- expression (x) branch: 15 features -> 4 x Dense(40, relu) -> softmax(2)
    x_input = Input(shape=(15,), dtype='float32')
    left_branch = Dense(40, kernel_regularizer=l2(0.001), name='pre_x_input', trainable=is_x)(x_input)
    left_branch = Activation('relu')(left_branch)
    left_branch = Dense(40, kernel_regularizer=l2(0.001), name='pre_x_h1', trainable=is_x)(left_branch)
    left_branch = Activation('relu')(left_branch)
    left_branch = Dense(40, kernel_regularizer=l2(0.001), name='pre_x_h2', trainable=is_x)(left_branch)
    left_branch = Activation('relu')(left_branch)
    left_branch = Dense(40, kernel_regularizer=l2(0.001), name='pre_x_h3', trainable=is_x)(left_branch)
    left_branch = Activation('relu')(left_branch)
    left_branch = Dense(2, name='pre_x_out', trainable=is_x)(left_branch)
    left_branch = Activation('softmax')(left_branch)
    model_x = Model(inputs=x_input, outputs=left_branch)
    if is_x:
        return model_x
    # Branch is frozen: restore the pretrained weights for ensemble member i.
    model_x.load_weights('breast_bimodal_ensemble/weights_x_' + str(i) + '.hdf5')

    # --- clinical (c) branch: 7 features -> 4 x Dense(18, relu) -> softmax(2)
    c_input = Input(shape=(7,), dtype='float32')
    right_branch = Dense(18, kernel_regularizer=l2(0.001), name='pre_c_input', trainable=is_c)(c_input)
    right_branch = Activation('relu')(right_branch)
    right_branch = Dense(18, kernel_regularizer=l2(0.001), name='pre_c_h1', trainable=is_c)(right_branch)
    right_branch = Activation('relu')(right_branch)
    right_branch = Dense(18, kernel_regularizer=l2(0.001), name='pre_c_h2', trainable=is_c)(right_branch)
    right_branch = Activation('relu')(right_branch)
    right_branch = Dense(18, kernel_regularizer=l2(0.001), name='pre_c_h3', trainable=is_c)(right_branch)
    right_branch = Activation('relu')(right_branch)
    right_branch = Dense(2, name='pre_c_out', trainable=is_c)(right_branch)
    right_branch = Activation('softmax')(right_branch)
    model_c = Model(inputs=c_input, outputs=right_branch)
    if is_c:
        return model_c
    model_c.load_weights('breast_bimodal_ensemble/weights_c_' + str(i) + '.hdf5')

    # --- fusion head: concatenate the two softmax outputs and classify.
    merged = concatenate([left_branch, right_branch], axis=-1)
    merged = Dense(58, kernel_regularizer=l2(0.001), name='merge_h1', trainable=is_b)(merged)
    merged = Activation('relu')(merged)
    merged = Dense(32, kernel_regularizer=l2(0.001), name='merge_h2', trainable=is_b)(merged)
    merged = Activation('relu')(merged)
    merged = Dense(32, kernel_regularizer=l2(0.001), name='merge_h3', trainable=is_b)(merged)
    merged = Activation('relu')(merged)
    merged = Dense(2, activation='softmax', trainable=is_b)(merged)
    merged_model = Model(inputs=[x_input, c_input], outputs=merged)
    return merged_model

def main(args):
    """Evaluate the saved bimodal-model ensemble on the held-out test set.

    Loads the NSCLC data from 'Data_all.mat', rebuilds the merged model,
    restores 200 saved checkpoints, averages their class-0 probabilities
    on the test and validation splits, and derives a decision threshold
    from the validation ROC (Youden's J). Training and independent-
    validation paths are retained below as commented-out code.

    Args:
        args: argparse.Namespace with `seed` (RNG seed) and `i`
            (ensemble-member index used in checkpoint filenames).
    """
    # Fix both TF and NumPy RNGs for reproducibility.
    tf.compat.v1.set_random_seed(args.seed)
    np.random.seed(args.seed)

    # Expression (x), clinical (c) features and labels for both splits.
    data = scipy.io.loadmat('Data_all.mat')
    train_x = data['train_x'].astype('float32')
    train_c = data['clinical_train_x'].astype('float32')
    train_y = data['train_y'].astype(int)
    test_x  = data['test_x'].astype('float32')
    test_c  = data['clinical_test_x'].astype('float32')
    test_y  = data['test_y'].astype(int)

    n_classes = 2
    # Labels are stored 1-based in the .mat file; shift to 0-based before
    # one-hot encoding.
    train_y = train_y -1
    test_y = test_y -1
    y_train = to_categorical(train_y, n_classes)
    y_test = to_categorical(test_y, n_classes)
    # Last 25% of the training rows serve as the validation split
    # (assumes rows were already shuffled upstream — TODO confirm).
    valid_num = int(np.shape(train_x)[0] * 0.25)
    valid_x = train_x[-valid_num:,:]
    valid_c = train_c[-valid_num:,:]
    valid_y = y_train[-valid_num:,:]

    # Optimizer/callbacks are only needed by the (commented-out) training
    # paths below; they are harmless to construct here.
    nadam = Nadam(lr=0.006,beta_1=0.9,beta_2=0.999,epsilon=1e-08,schedule_decay=0.004)
    ES = EarlyStopping(monitor='val_loss',patience=30,verbose=0,mode='min')
    checkpointer_x = ModelCheckpoint(
        filepath='breast_bimodal_ensemble/weights_x_' + str(args.i) + '.hdf5',
        verbose=1,save_best_only=True,monitor='val_loss',mode='min')
    checkpointer_c = ModelCheckpoint(
        filepath='breast_bimodal_ensemble/weights_c_' + str(args.i) + '.hdf5',
        verbose=1,save_best_only=True,monitor='val_loss',mode='min')
    checkpointer_merge = ModelCheckpoint(
        filepath='breast_bimodal_ensemble/weights_merge_' + str(args.i) + '.hdf5',
        verbose=1,save_best_only=True,monitor='val_loss',mode='min')

    # --- (disabled) training of the expression-only branch ---
    # model_x = create_model(True,False,False,args.i)
    # model_x.compile(loss='categorical_crossentropy',optimizer=nadam,metrics=['accuracy'])
    # model_x.fit(train_x,y_train,batch_size=20,nb_epoch=100,
    #     validation_data=(valid_x,valid_y),callbacks=[ES,checkpointer_x])
    # model_x.load_weights('breast_bimodal_ensemble/weights_x_' + str(args.i) + '.hdf5')
    # loss,accuracy = model_x.evaluate(test_x,y_test)
    # print('loss:', loss)
    # print('accuracy:', accuracy)
    # y_pred = model_x.predict(test_x)
    # fpr,tpr,thr = roc_curve(y_test[:,0],y_pred[:,0])
    # auc_test = auc(fpr,tpr)
    # print(auc_test)

    # --- (disabled) training of the clinical-only branch ---
    # model_c = create_model(False,True,False,args.i)
    # model_c.compile(loss='categorical_crossentropy',optimizer=nadam,metrics=['accuracy'])
    # model_c.fit(train_c,y_train,batch_size=20,nb_epoch=100,
    #     validation_data=(valid_c,valid_y),callbacks=[ES,checkpointer_c])
    # # model_c.fit(train_c,y_train,batch_size=20,nb_epoch=100,
    # #     validation_split=0.25,callbacks=[ES,checkpointer_c])
    # model_c.load_weights('breast_bimodal_ensemble/weights_c_' + str(args.i) + '.hdf5')
    # loss,accuracy = model_c.evaluate(test_c,y_test)
    # print('loss:', loss)
    # print('accuracy:', accuracy)
    # y_pred = model_c.predict(test_c)
    # fpr,tpr,thr = roc_curve(y_test[:,0],y_pred[:,0])
    # auc_test = auc(fpr,tpr)
    # print(auc_test)

    # Merged bimodal model with both branches frozen (weights loaded inside
    # create_model for member args.i); fusion head marked trainable.
    model_merge = create_model(False,False,True,args.i)
    # model_merge.compile(loss='categorical_crossentropy',optimizer=nadam,metrics=['accuracy'])
    # model_merge.fit([train_x,train_c],y_train,batch_size=20,nb_epoch=100,
    #     validation_data=([valid_x,valid_c],valid_y),callbacks=[ES,checkpointer_merge])    
    # NOTE(review): loaded but unused in the active code path below — only
    # the commented-out independent-validation block reads from it.
    data = np.load('../m2_aae/data/nsclc/nsclc_3.npz',allow_pickle=True)

    # independent validation ################
    # data_indep = np.load('../m2_aae/data/nsclc/indep.npz',allow_pickle=True)
    # c_tmp = data_indep['c_test']
    # stage = c_tmp[:,-1]
    # Stage = np.zeros((np.shape(stage)[0],5))
    # for i in range(np.shape(Stage)[0]):
    #     Stage[i,stage[i]-1] = 1
    # c_tmp = np.concatenate((c_tmp[:,0:2],Stage),axis=1)
    # cm = data['c_mean']
    # cs = data['c_scale']
    # c_mean = [cm[0],cm[1],cm[2],cm[2],cm[2],cm[2],cm[2]]
    # c_scale = [cs[0],cs[1],cs[2],cs[2],cs[2],cs[2],cs[2]]
    # test_x = (data_indep['x_test'] - data['x_mean']) / data['x_scale']
    # test_c = (c_tmp - c_mean) / c_scale
    # test_y = data_indep['y_test']
    # test_x = data_indep['x_test']
    # test_c = c_tmp
    # test_y = data_indep['y_test']
    ###########################################
    
    # Ensemble loop: restore each of the 200 saved merge checkpoints and
    # collect the class-0 probability on the test and validation splits.
    # NOTE(review): iterates a fixed 0..199 regardless of args.i — all
    # checkpoint files must exist on disk.
    probs_test,probs_valid = [],[]
    for i in range(200):
        model_merge.load_weights('breast_bimodal_ensemble/weights_merge_' + str(i) + '.hdf5')

        prob_test = model_merge.predict([test_x,test_c])
        prob_test = prob_test[:,0]
        probs_test.append(prob_test)

        ##### per-member test AUC, printed for monitoring
        fpr,tpr,thr = roc_curve(y_test[:,0].astype(int),prob_test)
        # fpr,tpr,thr = roc_curve(test_y.astype(int),prob_test)
        print(auc(fpr,tpr))

        prob_valid = model_merge.predict([valid_x,valid_c])
        prob_valid = prob_valid[:,0]
        probs_valid.append(prob_valid)

    # Ensemble prediction = mean probability across the 200 members.
    probs_test = np.mean(np.array(probs_test),axis=0)
    probs_valid = np.mean(np.array(probs_valid),axis=0)

    # Pick the operating threshold on the validation ROC by maximizing
    # Youden's J statistic (tpr - fpr).
    fpr,tpr,thr = roc_curve(valid_y[:,0],probs_valid)
    thr_best = thr[np.argmax(np.subtract(tpr,fpr))]

    # print(np.shape(probs_test))
    # print(probs_test)
    # print(thr_best)    
    # np.savez_compressed('../m2_aae/model/ensemble/nsclc_bimodal_logits.npz',pred=probs_test,thr=thr_best)
    # np.savez_compressed('../m2_aae/model/ensemble/nsclc_bimodal_indep_logits.npz',pred=probs_test)

if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the ensemble evaluation.
    main(parse_args())