# scan/src/nsclc_bimodal/merge_cv.py
from keras.models import Sequential
import scipy.io
import numpy as np
from keras.regularizers import l2#, activity_l2
from keras.layers import Dense, Activation
from keras.layers import Dropout, Flatten, Dense
from keras.utils.np_utils import to_categorical
from keras.optimizers import SGD, Nadam
import random
import numpy as np
import matplotlib.pyplot as plt
#from keras.utils.visualize_util import plot
#from keras.engine.topology import merge
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Activation
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers import Dropout, Flatten, Dense
from keras.utils.np_utils import to_categorical
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.models import model_from_json
#from keras.layers import Merge,merge
from keras.layers import concatenate
from keras.models import Input,Model

from sklearn.metrics import roc_auc_score,auc,roc_curve
from keras.callbacks.callbacks import Callback

import argparse
import tensorflow as tf
from sklearn import preprocessing

def parse_args(argv=None):
    """Parse the command-line options for one training run.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script entry point relies on.

    Returns:
        argparse.Namespace with two integer attributes:
        ``seed`` (random seed, default 0) and ``i`` (index used in the
        weight-file names, default 0).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--i', default=0, type=int, help='num_of_ensemble')
    return parser.parse_args(argv)

def _build_branch(inputs, units, prefix, trainable):
    """Build one single-modality sub-network on top of ``inputs``.

    The stack is four ``Dense(units)`` layers with L2(0.001) kernel
    regularisation, each followed by a ReLU activation, then an
    unregularised ``Dense(2)`` with a softmax activation. The Dense layers
    are named ``<prefix>input``, ``<prefix>h1``, ``<prefix>h2``,
    ``<prefix>h3`` and ``<prefix>out``.

    Args:
        inputs: Keras tensor to build on.
        units: Width of each hidden layer.
        prefix: Name prefix for the Dense layers (e.g. ``'pre_x_'``).
        trainable: Whether the Dense layers are trainable.

    Returns:
        The softmax output tensor of the branch.
    """
    net = inputs
    for suffix in ('input', 'h1', 'h2', 'h3'):
        # `kernel_regularizer` is the Keras 2 name of the old `W_regularizer`.
        net = Dense(units, kernel_regularizer=l2(0.001),
                    name=prefix + suffix, trainable=trainable)(net)
        net = Activation('relu')(net)
    net = Dense(2, name=prefix + 'out', trainable=trainable)(net)
    return Activation('softmax')(net)


def create_model(is_x, is_c, is_b, i, weights_dir='breast_bimodal_ensemble'):
    """Build one stage of the bimodal network and return it uncompiled.

    The flags are checked in order, so exactly one model is returned:

    * ``is_x`` true: the 15-input branch alone (trainable).
    * else ``is_c`` true: the 7-input branch alone (trainable). The 15-input
      branch is still built and its saved weights are still loaded first,
      so that weight file must already exist.
    * otherwise: both branches are built frozen, their saved weights are
      loaded, and their softmax outputs are concatenated and fed to a
      three-hidden-layer head whose layers are trainable according to
      ``is_b``.

    Args:
        is_x: Train (and return) the 15-feature branch.
        is_c: Train (and return) the 7-feature branch
            (presumably the clinical features -- confirm against the caller).
        is_b: Whether the merge head layers are trainable.
        i: Index used in the weight-file names
            ``<weights_dir>/weights_x_<i>.hdf5`` and
            ``<weights_dir>/weights_c_<i>.hdf5``.
        weights_dir: Directory holding the pre-trained branch weights.

    Returns:
        A ``keras.models.Model``: single-input for the branch stages,
        two-input (``[x_input, c_input]``) for the merged stage.
    """
    x_input = Input(shape=(15,), dtype='float32')
    left_branch = _build_branch(x_input, 40, 'pre_x_', is_x)
    model_x = Model(inputs=x_input, outputs=left_branch)
    if is_x:
        return model_x
    # Branch is frozen in this stage: restore the weights saved when it was trained.
    model_x.load_weights(weights_dir + '/weights_x_' + str(i) + '.hdf5')

    c_input = Input(shape=(7,), dtype='float32')
    right_branch = _build_branch(c_input, 18, 'pre_c_', is_c)
    model_c = Model(inputs=c_input, outputs=right_branch)
    if is_c:
        return model_c
    model_c.load_weights(weights_dir + '/weights_c_' + str(i) + '.hdf5')

    # The merge head consumes the two branches' softmax outputs (2 + 2 values).
    merged = concatenate([left_branch, right_branch], axis=-1)
    for units, layer_name in ((58, 'merge_h1'), (32, 'merge_h2'), (32, 'merge_h3')):
        merged = Dense(units, kernel_regularizer=l2(0.001),
                       name=layer_name, trainable=is_b)(merged)
        merged = Activation('relu')(merged)
    merged = Dense(2, activation='softmax', trainable=is_b)(merged)
    return Model(inputs=[x_input, c_input], outputs=merged)

def _report_test_metrics(model, test_inputs, y_test):
    """Print test loss, test accuracy and ROC AUC of ``model``.

    The AUC is computed from column 0 of the one-hot labels and of the
    predicted probabilities.
    """
    loss, accuracy = model.evaluate(test_inputs, y_test)
    print('loss:', loss)
    print('accuracy:', accuracy)
    y_pred = model.predict(test_inputs)
    fpr, tpr, _ = roc_curve(y_test[:, 0], y_pred[:, 0])
    print(auc(fpr, tpr))


def _fit_best_and_report(model, weights_path, fit_inputs, fit_y,
                         valid_inputs, valid_y, test_inputs, y_test):
    """Compile and train ``model``, reload its best checkpoint, print test metrics.

    Each call builds its own optimizer and callbacks so that no optimizer
    state is shared between separately compiled models.
    """
    optimizer = Nadam(lr=0.006, beta_1=0.9, beta_2=0.999, epsilon=1e-08,
                      schedule_decay=0.004)
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=30, verbose=0, mode='min'),
        ModelCheckpoint(filepath=weights_path, verbose=1, save_best_only=True,
                        monitor='val_loss', mode='min'),
    ]
    model.compile(loss='categorical_crossentropy', optimizer=optimizer,
                  metrics=['accuracy'])
    model.fit(fit_inputs, fit_y, batch_size=20, epochs=100,
              validation_data=(valid_inputs, valid_y), callbacks=callbacks)
    # Evaluate the epoch with the lowest validation loss, not the last epoch.
    model.load_weights(weights_path)
    _report_test_metrics(model, test_inputs, y_test)


def main(args):
    """Train and evaluate the two single-modality models and the merged model.

    Reads ``Data_all.mat`` from the working directory (keys ``train_x``,
    ``clinical_train_x``, ``train_y``, ``test_x``, ``clinical_test_x``,
    ``test_y``), then runs three stages in order: the ``x`` branch, the
    clinical branch, and the merged model built on the two saved branches.
    Each stage writes its best weights (lowest validation loss) to
    ``breast_bimodal_ensemble/weights_{x,c,merge}_<args.i>.hdf5`` and prints
    test loss, accuracy and AUC.

    The last 25% of the training rows are held out as the validation set
    and are NOT used for fitting, so early stopping and checkpoint selection
    are driven by data the model has not trained on.

    Args:
        args: Namespace with integer attributes ``seed`` and ``i``.

    Raises:
        ValueError: If there are too few training rows to hold out a
            validation sample.
    """
    import os  # local import: only needed to create the checkpoint directory

    tf.compat.v1.set_random_seed(args.seed)
    np.random.seed(args.seed)

    data = scipy.io.loadmat('Data_all.mat')
    train_x = data['train_x'].astype('float32')
    train_c = data['clinical_train_x'].astype('float32')
    train_y = data['train_y'].astype(int)
    test_x = data['test_x'].astype('float32')
    test_c = data['clinical_test_x'].astype('float32')
    test_y = data['test_y'].astype(int)

    n_classes = 2
    # Labels are shifted down by one before one-hot encoding
    # (presumably stored as 1/2 in the .mat file -- confirm).
    y_train = to_categorical(train_y - 1, n_classes)
    y_test = to_categorical(test_y - 1, n_classes)

    n_train = np.shape(train_x)[0]
    valid_num = int(n_train * 0.25)
    if valid_num < 1:
        raise ValueError(
            'need at least 4 training samples to hold out a validation set, '
            'got %d' % n_train)
    split = n_train - valid_num
    fit_x, valid_x = train_x[:split], train_x[split:]
    fit_c, valid_c = train_c[:split], train_c[split:]
    fit_y, valid_y = y_train[:split], y_train[split:]

    weights_dir = 'breast_bimodal_ensemble'
    # ModelCheckpoint does not create missing directories.
    os.makedirs(weights_dir, exist_ok=True)

    def weights_path(tag):
        return weights_dir + '/weights_' + tag + '_' + str(args.i) + '.hdf5'

    # Stage 1: x branch alone.
    model_x = create_model(True, False, False, args.i)
    _fit_best_and_report(model_x, weights_path('x'),
                         fit_x, fit_y, valid_x, valid_y, test_x, y_test)

    # Stage 2: clinical branch alone.
    model_c = create_model(False, True, False, args.i)
    _fit_best_and_report(model_c, weights_path('c'),
                         fit_c, fit_y, valid_c, valid_y, test_c, y_test)

    # Stage 3: merged model on top of the two saved (frozen) branches.
    model_merge = create_model(False, False, True, args.i)
    _fit_best_and_report(model_merge, weights_path('merge'),
                         [fit_x, fit_c], fit_y,
                         [valid_x, valid_c], valid_y,
                         [test_x, test_c], y_test)

# Script entry point: parse the command-line flags and run the full pipeline.
if __name__ == "__main__":
    main(parse_args())