#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import tensorflow as tf
from six.moves import range, zip
import numpy as np
import zhusuan as zs
import argparse
from sklearn.metrics import roc_auc_score,auc,roc_curve,f1_score,accuracy_score
from lifelines.utils import concordance_index
import pickle  # for loading neural network weights

def calc_c_index_benchmark(o_test,y_prob):
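    '''Concordance index between survival times and centered, tanh-squashed scores;
    the score is negated because higher predicted risk should pair with shorter survival.'''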
    time = np.array(np.expand_dims(o_test,axis=1),dtype=float)
    prob = np.tanh(y_prob)
    mean = np.mean(prob)
    y_score = (prob - mean)
    return concordance_index(time,-y_score)

def calc_acc_score(thr,y_label,y_prob):
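    '''Binarize y_prob at threshold thr and return the accuracy against y_label.'''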
    label = 1.0 * np.ones_like(y_prob)
    for i in range(np.shape(label)[0]):
        if y_prob[i] < thr:  label[i] = 0.0
    return accuracy_score(y_label.astype(int),label.astype(int))

class m2_aae:
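    '''Semi-supervised M2-style VAE: encoder q(z|x), decoder p(x|z,y) with multiplicative
    (variational-dropout style) weight noise, and auxiliary classifiers q(y|x) / q(y|c)
    whose predictions are ensembled as q(y|x,c).'''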
    def __init__(self,input_dims,hidden_dims,latent_dims,coeffs,learning_rates,n_hs,n_samples):
        self.dim_x,self.dim_c = input_dims[0],input_dims[1]
        self.w_gen_hidden,self.w_dis_hidden = hidden_dims[0],hidden_dims[1]
        self.w_cla_hidden_x,self.w_cla_hidden_c = hidden_dims[2],hidden_dims[3]
        self.dim_z = latent_dims
        self.beta,self.lamb = coeffs[0],coeffs[1]
        self.learning_rate_m2 = learning_rates[0]
        self.n_h,self.n_h_2 = n_hs[0],n_hs[1]
        self.w_gen_size = [self.dim_x] + self.w_gen_hidden + []  # encoder
        self.w_dis_size = [self.n_h] + self.w_dis_hidden + [self.dim_x]  # decoder
        self.w_cla_size_x = [self.dim_x] + self.w_cla_hidden_x + []  # aux classifier x
        self.w_cla_size_c = [self.dim_c] + self.w_cla_hidden_c + []  # aux classifier c
        self.num_lab_m2,self.num_ulab_m2,self.num_ulab_m2_2 = n_samples[0],n_samples[1],n_samples[2]
        self.batch_size = n_samples[3]

    @zs.reuse_variables(scope="q_model")
    def _q_model(self,x_input,c_input,dim_z,n_h_2,w_gen_size,w_dis_size,w_cla_size_x,w_cla_size_c,n_particles,
                is_subx,is_subc):
        bn_m2  = zs.BayesianNet()  # VAE latent: z
        bn_ws  = zs.BayesianNet()  # p-model decoder weight variational dropout
        with tf.variable_scope('encoder'):  # q(z|x)
            h = x_input
            for i in range(len(w_gen_size)-1):
                h = tf.layers.dense(h,w_gen_size[i+1],name='enc_' + str(i),activation=tf.nn.elu)
            z_mean = tf.layers.dense(h,dim_z,name='enc_mean',activation=tf.nn.elu)
            z_logstd = tf.layers.dense(h,dim_z,name='enc_logstd')  # log should be real numbers
            bn_m2.normal('z',z_mean,logstd=z_logstd,n_samples=n_particles,group_ndims=1)
            bn_m2.deterministic('z_mean_out',z_mean)
            bn_m2.deterministic('z_logstd_out',z_logstd)
        with tf.variable_scope('decoder'):  # p(x|z,y)
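            # variational-dropout style multiplicative noise on decoder activations:
            # q(w) = N(1, std=sqrt(sigmoid(logit_alpha))); the p-model places a N(1,1) prior
            # on the same 'w_dis_*' nodes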
            for i in range(len(w_dis_size)-1):
                with tf.variable_scope('layer_dis_' + str(i)):
                    logit_alpha = tf.get_variable('logit_alpha',[w_dis_size[i]])
                    std = tf.sqrt(tf.nn.sigmoid(logit_alpha) + 1E-10)
                    std = tf.tile(tf.expand_dims(std,0),[tf.shape(x_input)[0],1])
                    bn_ws.normal('w_dis_' + str(i),1.0,std=std,n_samples=n_particles,group_ndims=1)
        with tf.variable_scope('merge'):  # q(y|x,c)
            hx = x_input
            for i in range(len(w_cla_size_x)-1):
                hx = tf.layers.dense(hx,w_cla_size_x[i+1],name='x_mer_' + str(i),activation=tf.nn.softplus)
            hc = c_input
            for i in range(len(w_cla_size_c)-1):
                hc = tf.layers.dense(hc,w_cla_size_c[i+1],name='c_mer_' + str(i),activation=tf.nn.softplus)
            hs_x = tf.layers.dense(hx,n_h_2,name='x_to_merge')
            hs_c = tf.layers.dense(hc,n_h_2,name='c_to_merge')
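            # is_subx / is_subc gate one modality off (zeroed, no gradient) so the same merge
            # head can act as q(y|x), q(y|c) or, by ensembling the two, q(y|x,c)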
            hs_x = tf.cond(is_subx,lambda: hs_x,lambda: tf.stop_gradient(hs_x * 0.0))
            hs_c = tf.cond(is_subc,lambda: hs_c,lambda: tf.stop_gradient(hs_c * 0.0))
            y_sub = tf.layers.dense(hs_x + hs_c,2,name='merge_output',activation=tf.nn.sigmoid)  # probability
        return bn_m2,bn_ws,y_sub

    @zs.meta_bayesian_net(scope="p_model",reuse_variables=True)
    def _p_model(self,y_input,n_h,dim_z,w_dis_size,n_particles,is_vae):
        '''a single p-model; `log_joint` is assigned afterwards to pick which nodes enter each ELBO'''
        bn = zs.BayesianNet()
        with tf.variable_scope('decoder'):  # p(x|z,y)
            with tf.variable_scope('input_layers'):
                y_logits = tf.zeros([tf.shape(y_input)[0], 2])  # y prior ~ Cat(uniform)
                y = bn.onehot_categorical('y', y_logits)
                z_mean = tf.zeros([tf.shape(y_input)[0], dim_z])  # z prior ~ N(0,I)
                z = bn.normal('z', z_mean, std=1., group_ndims=1,n_samples=n_particles)
                h_from_y = tf.layers.dense(tf.cast(y,tf.float32), n_h,name='y_dec',activation=tf.nn.elu)
                h_from_c = tf.layers.dense(z, n_h,name='h_dec',activation=tf.nn.elu)
                h = tf.nn.softplus(h_from_y + h_from_c)
            for i in range(len(w_dis_size)-1):
                w_mean = tf.ones([w_dis_size[i]])
                w = bn.normal('w_dis_' + str(i),w_mean,std=1.0,n_samples=n_particles,group_ndims=1)
                if i == len(w_dis_size) - 2:
                    with tf.variable_scope('output_layer'):
                        h = tf.cond(is_vae,lambda: h, lambda: h*w)
                        h = tf.layers.dense(h,w_dis_size[i+1],name='w_dis_' + str(i))
                else:
                    with tf.variable_scope('hidden_' + str(i)):
                        h = tf.cond(is_vae,lambda: h, lambda: h*w)
                        h = tf.layers.dense(h,w_dis_size[i+1],name='w_dis_' + str(i),activation=tf.nn.elu)
            x_mean = h
            bn.normal('x_recon', x_mean, std=1., group_ndims=1)
            bn.deterministic('x_recon_out',tf.reduce_mean(x_mean,axis=0))
        return bn

    def build_graph(self):
        # complete labeled data triplets (x,c,y)
        self.x = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_x],name='labeled_expression')
        self.c = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_c],name='labeled_clinical')
        self.y = tf.compat.v1.placeholder(tf.float32,shape=[None],name='labeled_class')
        # unlabeled (most censored) samples (x,c)
        self.x_u = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_x],name='unlabeled_expression')
        self.c_u = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_c],name='unlabeled_clinical')
        # unlabeled (most censored) samples (x)
        self.x_x = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_x],name='only_expression')

        self.n_particles = tf.compat.v1.placeholder(tf.int32,shape=[],name='n_particles')
        self.is_training = tf.compat.v1.placeholder(tf.bool,shape=[],name='is_training')  # deprecated

        # ========================================= M2 VAE =============================================== #
        # labeled samples used to build q-model & train auxiliary classifier (fully-supervised fashion)
        bn_m2,bn_ws,self.y_x = self._q_model(
            self.x,self.c,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        _,_,self.y_c = self._q_model(
            self.x,self.c,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(False,dtype=tf.bool),tf.constant(True,dtype=tf.bool))
        # two p-model instantiations: p_model_m2 disables the multiplicative weight noise (is_vae=True),
        # p_model_ws keeps it so the variational-dropout weights are trained; q_ones/zeros are built accordingly
        p_model_m2 = self._p_model(
            self.y,self.n_h,
            self.dim_z,self.w_dis_size,self.n_particles,tf.constant(True,dtype=tf.bool))
        p_model_ws = self._p_model(
            self.y,self.n_h,
            self.dim_z,self.w_dis_size,self.n_particles,tf.constant(False,dtype=tf.bool))
        
        # labeled bound L(x,y) = E_{q(z|x)}[ log p(x|z,y) + log p(y) + log p(z) - log q(z|x) ]
        y_onehot = tf.one_hot(indices=tf.cast(self.y,tf.int32),depth=2,on_value=1,off_value=0)
        def log_joint_decoder(bn):
            log_px  = bn.cond_log_prob('x_recon')
            log_py  = bn.cond_log_prob('y')
            log_pz  = bn.cond_log_prob('z')
            return log_px + log_py + log_pz
        p_model_m2.log_joint = log_joint_decoder
        self.lb_l_m2 = zs.variational.elbo(
            p_model_m2,observed={'x_recon': self.x,'y':y_onehot},variational=bn_m2,axis=0)
        self.lb_l_m2_sgvb = self.lb_l_m2.sgvb()
        def log_joint_weight(bn):
            w_names = ['w_dis_' + str(i) for i in range(len(self.w_dis_size)-1)]
            log_pws = bn.cond_log_prob(w_names)
            return tf.add_n(log_pws)
        p_model_ws.log_joint = log_joint_weight
        self.lb_l2_ws = zs.variational.elbo(
            p_model_ws,observed={'x_recon': self.x,'y':y_onehot},variational=bn_ws,axis=0)
        self.lb_l2_ws_sgvb = self.lb_l2_ws.sgvb()
        self.lb_l_sgvb = self.lb_l_m2_sgvb * self.num_lab_m2 + self.lb_l2_ws_sgvb
        
        # U(x) = U'(x_u,c_u) + U'(x_x); use all unlabeled data
        # 1. U'(x_u,c_u): use both x_u & c_u for ensemble prediction - q(yu|x_u,c_u)
        y_onehot_ones = tf.one_hot(indices=tf.cast(tf.ones_like(self.x_u[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        y_onehot_zeros = tf.one_hot(indices=tf.cast(tf.zeros_like(self.x_u[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        q_model_ones_m2_xc,q_model_ones_ws_xc,yu_ones_xc_x = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        _,_,yu_ones_xc_c = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(False,dtype=tf.bool),tf.constant(True,dtype=tf.bool))
        q_model_zeros_m2_xc,q_model_zeros_ws_xc,yu_zeros_xc_x = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        _,_,yu_zeros_xc_c = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(False,dtype=tf.bool),tf.constant(True,dtype=tf.bool))
        # ensemble prediction
        yu_ones_xc = 0.5 * yu_ones_xc_x + 0.5 * yu_ones_xc_c
        yu_zeros_xc = 0.5 * yu_zeros_xc_x + 0.5 * yu_zeros_xc_c
        self.lb_z_ones_m2_xc = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_u,'y':y_onehot_ones},variational=q_model_ones_m2_xc,axis=0)
        self.lb_z_ones_ws_xc = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_u,'y':y_onehot_ones},variational=q_model_ones_ws_xc,axis=0)
        self.lb_z_zeros_m2_xc = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_u,'y':y_onehot_zeros},variational=q_model_zeros_m2_xc,axis=0)
        self.lb_z_zeros_ws_xc = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_u,'y':y_onehot_zeros},variational=q_model_zeros_ws_xc,axis=0)
        self.lb_z_m2_xc = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_m2_xc,axis=1),tf.expand_dims(self.lb_z_ones_m2_xc,axis=1)],axis=1)
        self.lb_z_ws_xc = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_ws_xc,axis=1),tf.expand_dims(self.lb_z_ones_ws_xc,axis=1)],axis=1)
        self.yu_xc = tf.concat(
            [tf.expand_dims(yu_zeros_xc[:,1],axis=1),tf.expand_dims(yu_ones_xc[:,1],axis=1)],axis=1)
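        # marginalize the per-class bounds over the ensemble posterior:
        # U'(x,c) = sum_y q(y|x,c) * (ELBO(x,y) - log q(y|x,c))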
        qy_u_xc = self.yu_xc / tf.reduce_sum(self.yu_xc, 1, keepdims=True)
        log_qy_u_xc = tf.log(qy_u_xc + 1E-10)
        self.lb_u_m2_xc = tf.reduce_sum(qy_u_xc * (self.lb_z_m2_xc - log_qy_u_xc), 1)
        self.lb_u_ws_xc = tf.reduce_sum(qy_u_xc * (self.lb_z_ws_xc - log_qy_u_xc), 1)
        self.lb_u_sgvb_xc = - (self.lb_u_m2_xc * self.num_ulab_m2 + self.lb_u_ws_xc) 

        # 2. U'(x_x): use only x_x for single prediction - q(yu|x_x)
        y_onehot_ones_2 = tf.one_hot(indices=tf.cast(tf.ones_like(self.x_x[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        y_onehot_zeros_2 = tf.one_hot(indices=tf.cast(tf.zeros_like(self.x_x[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        q_model_ones_m2_x,q_model_ones_ws_x,yu_ones_x = self._q_model(
            self.x_x,tf.ones([tf.shape(self.x_x)[0],self.dim_c]),self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        q_model_zeros_m2_x,q_model_zeros_ws_x,yu_zeros_x = self._q_model(
            self.x_x,tf.ones([tf.shape(self.x_x)[0],self.dim_c]),self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        self.lb_z_ones_m2_x = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_x,'y':y_onehot_ones_2},variational=q_model_ones_m2_x,axis=0)
        self.lb_z_ones_ws_x = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_x,'y':y_onehot_ones_2},variational=q_model_ones_ws_x,axis=0)
        self.lb_z_zeros_m2_x = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_x,'y':y_onehot_zeros_2},variational=q_model_zeros_m2_x,axis=0)
        self.lb_z_zeros_ws_x = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_x,'y':y_onehot_zeros_2},variational=q_model_zeros_ws_x,axis=0)
        self.lb_z_m2_x = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_m2_x,axis=1),tf.expand_dims(self.lb_z_ones_m2_x,axis=1)],axis=1)
        self.lb_z_ws_x = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_ws_x,axis=1),tf.expand_dims(self.lb_z_ones_ws_x,axis=1)],axis=1)
        # single prediction
        self.yu_x = tf.concat(
            [tf.expand_dims(yu_zeros_x[:,1],axis=1),tf.expand_dims(yu_ones_x[:,1],axis=1)],axis=1)
        qy_u_x = self.yu_x / tf.reduce_sum(self.yu_x, 1, keepdims=True)
        log_qy_u_x = tf.log(qy_u_x + 1E-10)
        self.lb_u_m2_x = tf.reduce_sum(qy_u_x * (self.lb_z_m2_x - log_qy_u_x), 1)
        self.lb_u_ws_x = tf.reduce_sum(qy_u_x * (self.lb_z_ws_x - log_qy_u_x), 1)
        self.lb_u_sgvb_x = - (self.lb_u_m2_x * self.num_ulab_m2_2 + self.lb_u_ws_x)

        # 3. U(x) = U'(x_u,c_u) + U'(x_x)
        self.lb_u_sgvb = self.lb_u_sgvb_xc + self.lb_u_sgvb_x
        
        # auxiliary classifiers for x and c (using labeled triplets to train)
        onehot_cat_x = zs.distributions.OnehotCategorical(self.y_x)
        self.log_qy_x = onehot_cat_x.log_prob(tf.one_hot(indices=tf.cast(self.y,tf.int32),
                        depth=2,on_value=1,off_value=0))
        onehot_cat_c = zs.distributions.OnehotCategorical(self.y_c)
        self.log_qy_c = onehot_cat_c.log_prob(tf.one_hot(indices=tf.cast(self.y,tf.int32),
                        depth=2,on_value=1,off_value=0))
        # ensemble prediction and then calculate log-probability
        self.y_xc = 0.5 * self.y_x + 0.5 * self.y_c
        onehot_cat_xc = zs.distributions.OnehotCategorical(self.y_xc)
        self.log_qy_xc = onehot_cat_xc.log_prob(tf.one_hot(indices=tf.cast(self.y,tf.int32),
                        depth=2,on_value=1,off_value=0))

        # overall loss; following original paper, L/U are added directly while `cost_m2_aux` is weighted
        cost_m2_lxy = tf.reduce_mean(self.lb_l_sgvb)
        cost_m2_ux  = tf.reduce_mean(self.lb_u_sgvb)
        cost_m2_aux = tf.reduce_mean(self.log_qy_xc)
        self.cost_m2 = cost_m2_lxy - self.beta * cost_m2_aux + cost_m2_ux
        self.lb_l = self.lb_l_m2_sgvb  # observing reconstruction

        # ======================================== reconstructions ===========================================
        self.z_mean = bn_m2.get('z_mean_out')
        self.z_logstd = bn_m2.get('z_logstd_out')
        self.x_recon_q = self.lb_l_m2.bn['x_recon_out']  # observing z ~ q(z|x)

        z_sample = tf.random.normal(
            shape=tf.shape(bn_m2.get('z')),mean=self.z_mean,stddev=tf.exp(self.z_logstd),seed=0)
        lb_n = zs.variational.elbo(p_model_m2,
            observed={'y':y_onehot,'z':z_sample},variational=bn_m2,axis=0)
        self.x_recon_n = lb_n.bn['x_recon_out']  # observing an explicit draw z ~ N(z_mean, exp(z_logstd))

        # ========================== regularization & optimizers ========================================= #
        dec_m2 = tf.trainable_variables(scope='p_model/decoder')    # variational dropout (m2)
        dec_ws = tf.trainable_variables(scope='p_model_1/decoder')  # variational dropout (ws)
        dec_var = tf.trainable_variables(scope='q_model/decoder')   # tf.dense.layers
        dec_list = dec_ws + dec_m2 + dec_var
        enc_list = tf.trainable_variables(scope='q_model/encoder')  # tf.dense.layers
        mer_list = tf.trainable_variables(scope='q_model/merge')    # tf.dense.layers
        def get_vec_rms(var):
            return tf.reduce_mean(tf.sqrt(tf.reduce_mean(tf.expand_dims(tf.square(var),0),axis=1) + 1E-10))
        def get_mat_rms(var):
            return tf.reduce_mean(tf.sqrt(tf.reduce_mean(tf.square(var),axis=1) + 1E-10))
        enc_reg,dec_reg,mer_reg = 0.0,0.0,0.0
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/encoder'):
            if len(np.shape(var)) == 1:  enc_reg += get_vec_rms(var)
            else:  enc_reg += get_mat_rms(var)
        enc_reg /= len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/encoder'))
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='p_model/decoder'):
            if len(np.shape(var)) == 1:  dec_reg += get_vec_rms(var)
            else:  dec_reg += get_mat_rms(var)
        dec_reg /= len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='p_model/decoder'))
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/merge'):
            if len(np.shape(var)) == 1:  mer_reg += get_vec_rms(var)
            else:  mer_reg += get_mat_rms(var)
        mer_reg /= len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/merge'))

        optimizer_m2 = tf.compat.v1.train.RMSPropOptimizer(learning_rate=self.learning_rate_m2)
        grads_m2,norm_m2 = tf.clip_by_global_norm(tf.gradients(
            self.cost_m2 + self.lamb * (enc_reg + dec_reg + mer_reg),
            enc_list + dec_list + mer_list),5)
        self.infer_op_m2 = optimizer_m2.apply_gradients(zip(grads_m2,enc_list + dec_list + mer_list))

def parse_args():
    '''
        python3 ../src/train_m2_fullvb_bimodal_cv.py  
        --w_gen_hidden 12  --w_dis_hidden 8  --w_cla_hidden_x 12  
        --dim_z 5 --n_h 10 --n_h_2 5 --beta 0.1 --lamb 0.1 
        --epochs 100 --split_num 3 --count 972 &
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--w_gen_hidden',default=[12],nargs='+',type=int,help='q-model hidden')
    parser.add_argument('--w_dis_hidden',default=[8],nargs='+',type=int,help='p-model hidden')
    parser.add_argument('--w_cla_hidden_x',default=[12],nargs='+',type=int,help='classifier Ax hidden')
    parser.add_argument('--w_cla_hidden_c',default=[],nargs='+',type=int,help='classifier Ac hidden')
    parser.add_argument('--dim_z',default=5,type=int,help='latent dimension')
    parser.add_argument('--n_h',default=10,type=int,help='latent input transformation dimension')
    parser.add_argument('--n_h_2',default=5,type=int,help='latent input transformation dimension')
    
    parser.add_argument('--beta',type=float,default=0.1,help='classification regularization constant in M2 U(x)')
    parser.add_argument('--lamb',type=float,default=0.1,help='weight L2-regularization constant')
    parser.add_argument('--n_particles_train',type=int,default=50)
    parser.add_argument('--n_particles_test',type=int,default=500)
    parser.add_argument('--batch_size',type=int,default=32)
    parser.add_argument('--learning_rate_m2',type=float,default=0.01)
    parser.add_argument('--std',default=1.0)
    parser.add_argument('--epochs',type=int,default=100,help='total epochs')

    parser.add_argument('--data_path',default='../data/nsclc/nsclc_')
    parser.add_argument('--udata_path',default='../data/nsclc/nsclc_unlabeled')
    parser.add_argument('--split_num',default=3)

    return parser.parse_args()

def main(args):
    with tf.Graph().as_default():
        tf.compat.v1.set_random_seed(1237)
        np.random.seed(1234)

        data = np.load(args.data_path + str(args.split_num) + '.npz',allow_pickle=True)
        data_unlabeled = np.load(args.udata_path + '.npz',allow_pickle=True)
        x_w_full = (data_unlabeled['x_w_full'] - data['x_mean']) / data['x_scale']
        c_w_full = (data_unlabeled['c_w_full'] - data['c_mean']) / data['c_scale']
        x_n_full = (data_unlabeled['x_n_full'] - data['x_mean']) / data['x_scale']
        c_n_full = (data_unlabeled['c_n_full'] - data['c_mean']) / data['c_scale']

        x_train,x_valid,x_test = data['x_train'],data['x_valid'],data['x_test']
        c_train,c_valid,c_test = data['c_train'],data['c_valid'],data['c_test']
        y_train,y_valid,y_test = data['y_train'],data['y_valid'],data['y_test']
        o_train,o_valid,o_test = data['o_train'],data['o_valid'],data['o_test']
        e_train,e_valid,e_test = data['e_train'],data['e_valid'],data['e_test']

        # after 4-CV, use (train+valid) & best CV hyper-params to retrain the model
        x_train = np.concatenate((x_train,x_valid),axis=0)
        c_train = np.concatenate((c_train,c_valid),axis=0)
        y_train = np.concatenate((y_train,y_valid))
        o_train = np.concatenate((o_train,o_valid),axis=0)

        num_lab_m2 = np.shape(x_train)[0]      # #(x,c,y)
        num_ulab_m2 = np.shape(x_w_full)[0]    # #(x,c)
        num_ulab_m2_2 = np.shape(x_n_full)[0]  # #(x)
        num_lab_batch = args.batch_size
        num_ulab_batch = args.batch_size

        model = m2_aae(
            input_dims=[np.shape(x_train)[1],np.shape(c_train)[1]],
            hidden_dims=[args.w_gen_hidden,args.w_dis_hidden,args.w_cla_hidden_x,args.w_cla_hidden_c],
            latent_dims=args.dim_z,
            coeffs=[args.beta,args.lamb],
            learning_rates=[args.learning_rate_m2],
            n_hs=[args.n_h,args.n_h_2],
            n_samples=[num_lab_m2,num_ulab_m2,num_ulab_m2_2,args.batch_size])
        model.build_graph()

        train_dict,valid_dict,test_dict = {},{},{}
        saver = tf.compat.v1.train.Saver()  # model checkpoint
        
        # gpu configuration
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            # sess.run(tf.global_variables_initializer())
            saver.restore(sess,'../model/scan/nsclc.ckpt')
            print('Model restored')

            # # independent validation
            # data_indep = np.load('../data/nsclc/indep.npz',allow_pickle=True)
            # x_test = (data_indep['x_test'] - data['x_mean']) / data['x_scale']
            # c_test = (data_indep['c_test'] - data['c_mean']) / data['c_scale']
            # y_test = data_indep['y_test']
            # test_dict_test = {
            #     model.x: x_test,model.c: c_test,model.y: y_test,
            #     model.n_particles: args.n_particles_test}
            # y_x_logit,y_c_logit,y_xc_logit = sess.run(
            #     [model.y_x,model.y_c,model.y_xc],feed_dict=test_dict_test)
            # np.savez_compressed('../model/scan/nsclc_indep_logits.npz',
            #     y_x_logit=y_x_logit,y_c_logit=y_c_logit,y_xc_logit=y_xc_logit)
            # return
              
            # get predictions
            test_dict_test = {
                model.x: x_test,model.c: c_test,model.y: y_test,
                model.n_particles: args.n_particles_test} 
            y_x_logit,y_c_logit,y_xc_logit = sess.run(
                [model.y_x,model.y_c,model.y_xc],feed_dict=test_dict_test)

            valid_dict_test = {
                model.x: x_valid,model.c: c_valid,model.y: y_valid,
                model.n_particles: args.n_particles_test} 
            y_x_logit_val,y_c_logit_val,y_xc_logit_val = sess.run(
                [model.y_x,model.y_c,model.y_xc],feed_dict=valid_dict_test)
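            # pick operating thresholds on the validation set by maximizing Youden's J (tpr - fpr)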
            fpr,tpr,thr = roc_curve(y_valid.astype(int),(np.nan_to_num(y_x_logit_val[:,1])),pos_label=1)
            thr_best_x = thr[np.argmax(np.subtract(tpr,fpr))]            
            fpr,tpr,thr = roc_curve(y_valid.astype(int),(np.nan_to_num(y_c_logit_val[:,1])),pos_label=1)
            thr_best_c = thr[np.argmax(np.subtract(tpr,fpr))]
            fpr,tpr,thr = roc_curve(y_valid.astype(int),(np.nan_to_num(y_xc_logit_val[:,1])),pos_label=1)
            thr_best_xc = thr[np.argmax(np.subtract(tpr,fpr))]

            np.savez_compressed('../model/scan/nsclc_logits.npz',
                y_x_logit=y_x_logit,y_c_logit=y_c_logit,y_xc_logit=y_xc_logit,
                thr_best_x=thr_best_x,thr_best_c=thr_best_c,thr_best_xc=thr_best_xc)
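
            # minimal sketch: score the ensemble logits with the helper functions defined above,
            # using the validation-derived threshold (illustrative only; names reuse variables above)
            acc_xc = calc_acc_score(thr_best_xc,y_test,np.nan_to_num(y_xc_logit[:,1]))
            auc_xc = roc_auc_score(y_test.astype(int),np.nan_to_num(y_xc_logit[:,1]))
            cid_xc = calc_c_index_benchmark(o_test,np.nan_to_num(y_xc_logit[:,1]))
            print('test accuracy %.4f | AUC %.4f | c-index %.4f' % (acc_xc,auc_xc,cid_xc))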

            # learned VAE latent representations
            z_mean,z_logstd = sess.run([model.z_mean,model.z_logstd],feed_dict=test_dict_test)
            np.savez_compressed('../model/scan/nsclc_z.npz',
                z_mean=z_mean,z_logstd=z_logstd)

            # get learned weights
            wp_m2,wp_ws,wq_merge = {},{},{}
            for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
                last_name  = var.name.split('/')[-1]
                layer_name = var.name.split('/')[-2]
                if last_name == 'kernel:0':  # if it's kernel (weight)
                    dom_name = var.name.split('/')[0]  # [q_model|p_model|p_model_1]
                    if dom_name == 'q_model':  wq_merge[layer_name] = sess.run(var)
                    if dom_name == 'p_model':  wp_m2[layer_name] = sess.run(var)
                    if dom_name == 'p_model_1':  wp_ws[layer_name] = sess.run(var)

            with open('../model/scan/nsclc_wp_m2.p','wb') as fp:
                pickle.dump(wp_m2,fp,protocol=pickle.HIGHEST_PROTOCOL)
            with open('../model/scan/nsclc_wp_ws.p','wb') as fp:
                pickle.dump(wp_ws,fp,protocol=pickle.HIGHEST_PROTOCOL)
            with open('../model/scan/nsclc_wq_merge.p','wb') as fp:
                pickle.dump(wq_merge,fp,protocol=pickle.HIGHEST_PROTOCOL)
            

if __name__ == "__main__":
    args = parse_args()
    main(args)