#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import tensorflow as tf
from six.moves import range, zip
import numpy as np
import zhusuan as zs
import argparse
from sklearn.metrics import auc, roc_curve, f1_score, accuracy_score
from lifelines.utils import concordance_index

def calc_c_index_benchmark(o_test,y_prob):
    # concordance index between observed survival times and predicted scores;
    # the centred tanh score is negated so that a higher class-1 probability is
    # treated as higher risk, i.e. expected to pair with shorter survival
    time = np.array(np.expand_dims(o_test,axis=1),dtype=float)
    prob = np.tanh(y_prob)
    y_score = prob - np.mean(prob)
    return concordance_index(time,-y_score)
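# A toy, hedged illustration (numbers invented here for documentation only): with
# perfectly anti-concordant probabilities the index reaches its maximum.
# >>> calc_c_index_benchmark(np.array([5., 1., 8.]), np.array([0.2, 0.9, 0.1]))
# 1.0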

def calc_acc_score(thr,y_label,y_prob):
    label = 1.0 * np.ones_like(y_prob)
    for i in range(np.shape(label)[0]):
        if y_prob[i] < thr:  label[i] = 0.0
    return accuracy_score(y_label.astype(int),label.astype(int))
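# A minimal vectorised equivalent of the loop above (sketch, for reference only):
#     accuracy_score(y_label.astype(int), (y_prob >= thr).astype(int))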

class m2_aae:
    def __init__(self,input_dims,hidden_dims,latent_dims,coeffs,learning_rates,n_hs,n_samples):
        self.dim_x,self.dim_c = input_dims[0],input_dims[1]
        self.w_gen_hidden,self.w_dis_hidden = hidden_dims[0],hidden_dims[1]
        self.w_cla_hidden_x,self.w_cla_hidden_c = hidden_dims[2],hidden_dims[3]
        self.dim_z = latent_dims
        self.beta,self.lamb = coeffs[0],coeffs[1]
        self.learning_rate_m2 = learning_rates[0]
        self.n_h,self.n_h_2 = n_hs[0],n_hs[1]
        self.w_gen_size = [self.dim_x] + self.w_gen_hidden                # encoder
        self.w_dis_size = [self.n_h] + self.w_dis_hidden + [self.dim_x]   # decoder
        self.w_cla_size_x = [self.dim_x] + self.w_cla_hidden_x            # aux classifier on x
        self.w_cla_size_c = [self.dim_c] + self.w_cla_hidden_c            # aux classifier on c
        self.num_lab_m2,self.num_ulab_m2,self.num_ulab_m2_2 = n_samples[0],n_samples[1],n_samples[2]
        self.batch_size = n_samples[3]

    @zs.reuse_variables(scope="q_model")
    def _q_model(self,x_input,c_input,dim_z,n_h_2,w_gen_size,w_dis_size,w_cla_size_x,w_cla_size_c,
                 n_particles,is_subx,is_subc):
        bn_m2  = zs.BayesianNet()  # VAE latent: z
        bn_ws  = zs.BayesianNet()  # p-model decoder weight variational dropout
        with tf.variable_scope('encoder'):  # q(z|x)
            h = x_input
            for i in range(len(w_gen_size)-1):
                h = tf.layers.dense(h,w_gen_size[i+1],name='enc_' + str(i),activation=tf.nn.elu)
            z_mean = tf.layers.dense(h,dim_z,name='enc_mean',activation=tf.nn.elu)
            z_logstd = tf.layers.dense(h,dim_z,name='enc_logstd')  # no activation: log-std can be any real number
            bn_m2.normal('z',z_mean,logstd=z_logstd,n_samples=n_particles,group_ndims=1)
            bn_m2.deterministic('z_mean_out',z_mean)
            bn_m2.deterministic('z_logstd_out',z_logstd)
        with tf.variable_scope('decoder'):  # p(x|z,y)
            for i in range(len(w_dis_size)-1):
                with tf.variable_scope('layer_dis_' + str(i)):
                    logit_alpha = tf.get_variable('logit_alpha',[w_dis_size[i]])
                    std = tf.sqrt(tf.nn.sigmoid(logit_alpha) + 1E-10)
                    std = tf.tile(tf.expand_dims(std,0),[tf.shape(x_input)[0],1])
                    bn_ws.normal('w_dis_' + str(i),1.0,std=std,n_samples=n_particles,group_ndims=1)
        with tf.variable_scope('merge'):  # q(y|x,c)
            hx = x_input
            for i in range(len(w_cla_size_x)-1):
                hx = tf.layers.dense(hx,w_cla_size_x[i+1],name='x_mer_' + str(i),activation=tf.nn.softplus)
            hc = c_input
            for i in range(len(w_cla_size_c)-1):
                hc = tf.layers.dense(hc,w_cla_size_c[i+1],name='c_mer_' + str(i),activation=tf.nn.softplus)
            hs_x = tf.layers.dense(hx,n_h_2,name='x_to_merge')
            hs_c = tf.layers.dense(hc,n_h_2,name='c_to_merge')
            hs_x = tf.cond(is_subx,lambda: hs_x,lambda: tf.stop_gradient(hs_x * 0.0))
            hs_c = tf.cond(is_subc,lambda: hs_c,lambda: tf.stop_gradient(hs_c * 0.0))
            y_sub = tf.layers.dense(hs_x + hs_c,2,name='merge_output',activation=tf.nn.sigmoid)  # per-class sigmoid scores in (0, 1)
        return bn_m2,bn_ws,y_sub
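    # Sketch of the inference model above (notation assumed, not quoted from the
    # paper): q(z|x) = N(z_mean, exp(z_logstd)^2), q(y|x,c) is the sigmoid "merge"
    # head, and each decoder layer i carries a multiplicative-noise weight posterior
    # q(w_i) = N(1, sigmoid(logit_alpha_i)), i.e. Gaussian variational dropout:
    #     h * w_i  ~=  h * (1 + sqrt(alpha_i) * eps),  eps ~ N(0, 1)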

    @zs.meta_bayesian_net(scope="p_model",reuse_variables=True)
    def _p_model(self,y_input,n_h,dim_z,w_dis_size,n_particles,is_vae):
        '''single p-model; `log_joint` is overridden in `build_graph` to select which latent nodes enter each ELBO'''
        bn = zs.BayesianNet()
        with tf.variable_scope('decoder'):  # p(x|z,y)
            with tf.variable_scope('input_layers'):
                y_logits = tf.zeros([tf.shape(y_input)[0], 2])  # y prior ~ Cat(uniform)
                y = bn.onehot_categorical('y', y_logits)
                z_mean = tf.zeros([tf.shape(y_input)[0], dim_z])  # z prior ~ N(0, I)
                z = bn.normal('z', z_mean, std=1., group_ndims=1,n_samples=n_particles)
                h_from_y = tf.layers.dense(tf.cast(y,tf.float32), n_h,name='y_dec',activation=tf.nn.elu)
                h_from_z = tf.layers.dense(z, n_h,name='h_dec',activation=tf.nn.elu)
                h = tf.nn.softplus(h_from_y + h_from_z)
            for i in range(len(w_dis_size)-1):
                w_mean = tf.ones([w_dis_size[i]])
                w = bn.normal('w_dis_' + str(i),w_mean,std=1.0,n_samples=n_particles,group_ndims=1)
                if i == len(w_dis_size) - 2:
                    with tf.variable_scope('output_layer'):
                        h = tf.cond(is_vae,lambda: h, lambda: h*w)
                        h = tf.layers.dense(h,w_dis_size[i+1],name='w_dis_' + str(i))
                else:
                    with tf.variable_scope('hidden_' + str(i)):
                        h = tf.cond(is_vae,lambda: h, lambda: h*w)
                        h = tf.layers.dense(h,w_dis_size[i+1],name='w_dis_' + str(i),activation=tf.nn.elu)
            x_mean = h
            bn.normal('x_recon', x_mean, std=1., group_ndims=1)
            bn.deterministic('x_recon_out',tf.reduce_mean(x_mean,axis=0))
        return bn
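    # Sketch of the generative model above, with priors as coded: y ~ Cat(1/2, 1/2),
    # z ~ N(0, I), w_i ~ N(1, I) per decoder layer, and x | y, z, w ~ N(x_mean, I)
    # where x_mean is the decoder output; with `is_vae` True the weight noise is
    # bypassed and the decoder reduces to a plain conditional VAE.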

    def build_graph(self):
        # complete labeled data triplets (x,c,y)
        self.x = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_x],name='labeled_expression')
        self.c = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_c],name='labeled_clinical')
        self.y = tf.compat.v1.placeholder(tf.float32,shape=[None],name='labeled_class')
        # unlabeled (mostly censored) samples (x, c)
        self.x_u = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_x],name='unlabeled_expression')
        self.c_u = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_c],name='unlabeled_clinical')
        # unlabeled (mostly censored) samples with expression only (x)
        self.x_x = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_x],name='only_expression')

        self.n_particles = tf.compat.v1.placeholder(tf.int32,shape=[],name='n_particles')
        self.is_training = tf.compat.v1.placeholder(tf.bool,shape=[],name='is_training')  # unused; kept for interface compatibility

        # ========================================= M2 VAE =============================================== #
        # labeled samples used to build q-model & train auxiliary classifier (fully-supervised fashion)
        bn_m2,bn_ws,self.y_x = self._q_model(
            self.x,self.c,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        _,_,self.y_c = self._q_model(
            self.x,self.c,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(False,dtype=tf.bool),tf.constant(True,dtype=tf.bool))
        # p-model for x_recons given various types of data; q_ones/zeros are built accordingly
        p_model_m2 = self._p_model(
            self.y,self.n_h,
            self.dim_z,self.w_dis_size,self.n_particles,tf.constant(True,dtype=tf.bool))
        p_model_ws = self._p_model(
            self.y,self.n_h,
            self.dim_z,self.w_dis_size,self.n_particles,tf.constant(False,dtype=tf.bool))
        
        # L(x,y)
        y_onehot = tf.one_hot(indices=tf.cast(self.y,tf.int32),depth=2,on_value=1,off_value=0)
        def log_joint_decoder(bn):
            log_px  = bn.cond_log_prob('x_recon')
            log_py  = bn.cond_log_prob('y')
            log_pz  = bn.cond_log_prob('z')
            return log_px + log_py + log_pz
        p_model_m2.log_joint = log_joint_decoder
        self.lb_l_m2 = zs.variational.elbo(
            p_model_m2,observed={'x_recon': self.x,'y':y_onehot},variational=bn_m2,axis=0)
        self.lb_l_m2_sgvb = self.lb_l_m2.sgvb()
        def log_joint_weight(bn):
            w_names = ['w_dis_' + str(i) for i in range(len(self.w_dis_size)-1)]
            log_pws = bn.cond_log_prob(w_names)
            return tf.add_n(log_pws)
        p_model_ws.log_joint = log_joint_weight
        self.lb_l2_ws = zs.variational.elbo(
            p_model_ws,observed={'x_recon': self.x,'y':y_onehot},variational=bn_ws,axis=0)
        self.lb_l2_ws_sgvb = self.lb_l2_ws.sgvb()
        self.lb_l_sgvb = self.lb_l_m2_sgvb * self.num_lab_m2 + self.lb_l2_ws_sgvb
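        # Labeled bound, in the spirit of the M2 model (a sketch of the intent, not a
        # line-by-line derivation): for an observed pair (x, y),
        #   -L(x, y) = E_{q(z|x)}[log p(x|y,z)] + log p(y) - KL(q(z|x) || p(z)),
        # estimated by SGVB; `lb_l2_ws_sgvb` adds the analogous bound over the decoder
        # weight noise, and the labeled part is scaled by the labeled-sample count.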
        
        # U(x) = U'(x_u,c_u) + U'(x_x); use all unlabeled data
        # 1. U'(x_u,c_u): use both x_u & c_u for ensemble prediction - q(yu|x_u,c_u)
        y_onehot_ones = tf.one_hot(indices=tf.cast(tf.ones_like(self.x_u[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        y_onehot_zeros = tf.one_hot(indices=tf.cast(tf.zeros_like(self.x_u[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        q_model_ones_m2_xc,q_model_ones_ws_xc,yu_ones_xc_x = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        _,_,yu_ones_xc_c = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(False,dtype=tf.bool),tf.constant(True,dtype=tf.bool))
        q_model_zeros_m2_xc,q_model_zeros_ws_xc,yu_zeros_xc_x = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        _,_,yu_zeros_xc_c = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(False,dtype=tf.bool),tf.constant(True,dtype=tf.bool))
        # ensemble prediction
        yu_ones_xc = 0.5 * yu_ones_xc_x + 0.5 * yu_ones_xc_c
        yu_zeros_xc = 0.5 * yu_zeros_xc_x + 0.5 * yu_zeros_xc_c
        self.lb_z_ones_m2_xc = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_u,'y':y_onehot_ones},variational=q_model_ones_m2_xc,axis=0)
        self.lb_z_ones_ws_xc = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_u,'y':y_onehot_ones},variational=q_model_ones_ws_xc,axis=0)
        self.lb_z_zeros_m2_xc = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_u,'y':y_onehot_zeros},variational=q_model_zeros_m2_xc,axis=0)
        self.lb_z_zeros_ws_xc = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_u,'y':y_onehot_zeros},variational=q_model_zeros_ws_xc,axis=0)
        self.lb_z_m2_xc = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_m2_xc,axis=1),tf.expand_dims(self.lb_z_ones_m2_xc,axis=1)],axis=1)
        self.lb_z_ws_xc = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_ws_xc,axis=1),tf.expand_dims(self.lb_z_ones_ws_xc,axis=1)],axis=1)
        self.yu_xc = tf.concat(
            [tf.expand_dims(yu_zeros_xc[:,1],axis=1),tf.expand_dims(yu_ones_xc[:,1],axis=1)],axis=1)
        qy_u_xc = self.yu_xc / tf.reduce_sum(self.yu_xc, 1, keepdims=True)
        log_qy_u_xc = tf.log(qy_u_xc + 1E-10)
        self.lb_u_m2_xc = tf.reduce_sum(qy_u_xc * (self.lb_z_m2_xc - log_qy_u_xc), 1)
        self.lb_u_ws_xc = tf.reduce_sum(qy_u_xc * (self.lb_z_ws_xc - log_qy_u_xc), 1)
        self.lb_u_sgvb_xc = - (self.lb_u_m2_xc * self.num_ulab_m2 + self.lb_u_ws_xc) 

        # 2. U'(x_x): use only x_x for single prediction - q(yu|x_x)
        y_onehot_ones_2 = tf.one_hot(indices=tf.cast(tf.ones_like(self.x_x[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        y_onehot_zeros_2 = tf.one_hot(indices=tf.cast(tf.zeros_like(self.x_x[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        q_model_ones_m2_x,q_model_ones_ws_x,yu_ones_x = self._q_model(
            self.x_x,tf.ones([tf.shape(self.x_x)[0],self.dim_c]),self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        q_model_zeros_m2_x,q_model_zeros_ws_x,yu_zeros_x = self._q_model(
            self.x_x,tf.ones([tf.shape(self.x_x)[0],self.dim_c]),self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        self.lb_z_ones_m2_x = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_x,'y':y_onehot_ones_2},variational=q_model_ones_m2_x,axis=0)
        self.lb_z_ones_ws_x = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_x,'y':y_onehot_ones_2},variational=q_model_ones_ws_x,axis=0)
        self.lb_z_zeros_m2_x = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_x,'y':y_onehot_zeros_2},variational=q_model_zeros_m2_x,axis=0)
        self.lb_z_zeros_ws_x = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_x,'y':y_onehot_zeros_2},variational=q_model_zeros_ws_x,axis=0)
        self.lb_z_m2_x = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_m2_x,axis=1),tf.expand_dims(self.lb_z_ones_m2_x,axis=1)],axis=1)
        self.lb_z_ws_x = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_ws_x,axis=1),tf.expand_dims(self.lb_z_ones_ws_x,axis=1)],axis=1)
        # single prediction
        self.yu_x = tf.concat(
            [tf.expand_dims(yu_zeros_x[:,1],axis=1),tf.expand_dims(yu_ones_x[:,1],axis=1)],axis=1)
        qy_u_x = self.yu_x / tf.reduce_sum(self.yu_x, 1, keepdims=True)
        log_qy_u_x = tf.log(qy_u_x + 1E-10)
        self.lb_u_m2_x = tf.reduce_sum(qy_u_x * (self.lb_z_m2_x - log_qy_u_x), 1)
        self.lb_u_ws_x = tf.reduce_sum(qy_u_x * (self.lb_z_ws_x - log_qy_u_x), 1)
        self.lb_u_sgvb_x = - (self.lb_u_m2_x * self.num_ulab_m2_2 + self.lb_u_ws_x)

        # 3. U(x) = U'(x_u,c_u) + U'(x_x)
        self.lb_u_sgvb = self.lb_u_sgvb_xc + self.lb_u_sgvb_x
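        # Unlabeled bound (sketch): without an observed y, the labeled bound is summed
        # over both candidate labels weighted by q(y|.), plus the entropy of q(y|.):
        #   -U(x) = sum_y q(y|x) * (-L(x, y)) + H(q(y|x)),
        # which is what the `qy_u_* * (lb_z_* - log_qy_u_*)` reductions compute for
        # the (x_u, c_u) ensemble stream and the expression-only (x_x) stream.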
        
        # auxiliary classifiers for x and c (using labeled triplets to train)
        onehot_cat_x = zs.distributions.OnehotCategorical(self.y_x)
        self.log_qy_x = onehot_cat_x.log_prob(tf.one_hot(indices=tf.cast(self.y,tf.int32),
                        depth=2,on_value=1,off_value=0))
        onehot_cat_c = zs.distributions.OnehotCategorical(self.y_c)
        self.log_qy_c = onehot_cat_c.log_prob(tf.one_hot(indices=tf.cast(self.y,tf.int32),
                        depth=2,on_value=1,off_value=0))
        # ensemble prediction and then calculate log-probability
        self.y_xc = 0.5 * self.y_x + 0.5 * self.y_c
        onehot_cat_xc = zs.distributions.OnehotCategorical(self.y_xc)
        self.log_qy_xc = onehot_cat_xc.log_prob(tf.one_hot(indices=tf.cast(self.y,tf.int32),
                        depth=2,on_value=1,off_value=0))

        # overall loss; following original paper, L/U are added directly while `cost_m2_aux` is weighted
        cost_m2_lxy = tf.reduce_mean(self.lb_l_sgvb)
        cost_m2_ux  = tf.reduce_mean(self.lb_u_sgvb)
        cost_m2_aux = tf.reduce_mean(self.log_qy_xc)
        self.cost_m2 = cost_m2_lxy - self.beta * cost_m2_aux + cost_m2_ux
        self.lb_l = self.lb_l_m2_sgvb  # observing reconstruction
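        # Overall cost being minimised (sketch): labeled cost + unlabeled cost plus a
        # supervised classification term on the labeled triplets,
        #   J ~= L + U + beta * E[-log q(y|x,c)],
        # mirroring the weighted classifier loss of the original semi-supervised M2
        # objective.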

        # ======================================== reconstructions ===========================================
        self.z_mean = bn_m2.get('z_mean_out')
        self.z_logstd = bn_m2.get('z_logstd_out')
        self.x_recon_q = self.lb_l_m2.bn['x_recon_out']  # observing z ~ q(z|x)

        z_sample = tf.random.normal(
            shape=tf.shape(bn_m2.get('z')),mean=self.z_mean,stddev=tf.exp(self.z_logstd),seed=0)
        lb_n = zs.variational.elbo(p_model_m2,
            observed={'y':y_onehot,'z':z_sample},variational=bn_m2,axis=0)
        self.x_recon_n = lb_n.bn['x_recon_out']  # observing a re-sampled z ~ N(z_mean, exp(z_logstd))

        # ========================== regularization & optimizers ========================================= #
        dec_m2 = tf.trainable_variables(scope='p_model/decoder')    # variational dropout (m2)
        dec_ws = tf.trainable_variables(scope='p_model_1/decoder')  # variational dropout (ws)
        dec_var = tf.trainable_variables(scope='q_model/decoder')   # tf.layers.dense
        dec_list = dec_ws + dec_m2 + dec_var
        enc_list = tf.trainable_variables(scope='q_model/encoder')  # tf.layers.dense
        mer_list = tf.trainable_variables(scope='q_model/merge')    # tf.layers.dense
        def get_vec_rms(var):
            return tf.reduce_mean(tf.sqrt(tf.reduce_mean(tf.expand_dims(tf.square(var),0),axis=1) + 1E-10))
        def get_mat_rms(var):
            return tf.reduce_mean(tf.sqrt(tf.reduce_mean(tf.square(var),axis=1) + 1E-10))
        enc_reg,dec_reg,mer_reg = 0.0,0.0,0.0
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/encoder'):
            if len(np.shape(var)) == 1:  enc_reg += get_vec_rms(var)
            else:  enc_reg += get_mat_rms(var)
        enc_reg /= len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/encoder'))
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='p_model/decoder'):
            if len(np.shape(var)) == 1:  dec_reg += get_vec_rms(var)
            else:  dec_reg += get_mat_rms(var)
        dec_reg /= len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='p_model/decoder'))
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/merge'):
            if len(np.shape(var)) == 1:  mer_reg += get_vec_rms(var)
            else:  mer_reg += get_mat_rms(var)
        mer_reg /= len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/merge'))
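        # The penalty above amounts to a group-wise RMS (L2-style) regulariser: the
        # row-wise root-mean-square of each weight matrix (element-wise RMS for bias
        # vectors), averaged per scope; `lamb` scales it in the clipped-gradient
        # update built below.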

        optimizer_m2 = tf.compat.v1.train.RMSPropOptimizer(learning_rate=self.learning_rate_m2)
        grads_m2,norm_m2 = tf.clip_by_global_norm(tf.gradients(
            self.cost_m2 + self.lamb * (enc_reg + dec_reg + mer_reg),
            enc_list + dec_list + mer_list),5)
        self.infer_op_m2 = optimizer_m2.apply_gradients(zip(grads_m2,enc_list + dec_list + mer_list))

def parse_args():
    '''
        example:
        python3 ../src/train_scan_nsclc.py
            --w_gen_hidden 12 --w_dis_hidden 8 --w_cla_hidden_x 12
            --dim_z 5 --n_h 10 --n_h_2 5 --beta 0.1 --lamb 0.1
            --epochs 100 --split_num 3 &
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--w_gen_hidden',default=[12],nargs='+',type=int,help='q-model hidden')
    parser.add_argument('--w_dis_hidden',default=[8],nargs='+',type=int,help='p-model hidden')
    parser.add_argument('--w_cla_hidden_x',default=[12],nargs='+',type=int,help='classifier Ax hidden')
    parser.add_argument('--w_cla_hidden_c',default=[],nargs='+',type=int,help='classifier Ac hidden')
    parser.add_argument('--dim_z',default=5,type=int,help='latent dimension')
    parser.add_argument('--n_h',default=10,type=int,help='latent input transformation dimension')
    parser.add_argument('--n_h_2',default=5,type=int,help='latent input transformation dimension')
    
    parser.add_argument('--beta',type=float,default=0.1,help='classification regularization constant in M2 U(x)')
    parser.add_argument('--lamb',type=float,default=0.1,help='weight L2-regularization constant')
    parser.add_argument('--n_particles_train',type=int,default=50)
    parser.add_argument('--n_particles_test',type=int,default=500)
    parser.add_argument('--batch_size',type=int,default=32)
    parser.add_argument('--learning_rate_m2',type=float,default=0.01)
    parser.add_argument('--std',type=float,default=1.0)
    parser.add_argument('--epochs',type=int,default=100,help='total epochs')

    parser.add_argument('--data_path',default='../data/nsclc/nsclc_')
    parser.add_argument('--udata_path',default='../data/nsclc/nsclc_unlabeled')
    parser.add_argument('--split_num',type=int,default=3)

    return parser.parse_args()

def main(args):
    with tf.Graph().as_default():
        tf.compat.v1.set_random_seed(1237)  # 1237
        np.random.seed(1234)  # 1234

        data = np.load(args.data_path + str(args.split_num) + '.npz',allow_pickle=True)
        data_unlabeled = np.load(args.udata_path + '.npz',allow_pickle=True)
        x_w_full = (data_unlabeled['x_w_full'] - data['x_mean']) / data['x_scale']
        c_w_full = (data_unlabeled['c_w_full'] - data['c_mean']) / data['c_scale']
        x_n_full = (data_unlabeled['x_n_full'] - data['x_mean']) / data['x_scale']
        c_n_full = (data_unlabeled['c_n_full'] - data['c_mean']) / data['c_scale']

        ########## reduce available unlabeled data ##########
        # keep_per = 0.2  # also replace = True while training
        # x_w_full = x_w_full.copy()[:int(np.shape(x_w_full)[0] * keep_per),:]
        # c_w_full = c_w_full.copy()[:int(np.shape(c_w_full)[0] * keep_per),:]
        # x_n_full = x_n_full.copy()[:int(np.shape(x_n_full)[0] * keep_per),:]

        # extract data from preprocessed data
        x_train,x_valid,x_test = data['x_train'],data['x_valid'],data['x_test']
        c_train,c_valid,c_test = data['c_train'],data['c_valid'],data['c_test']
        y_train,y_valid,y_test = data['y_train'],data['y_valid'],data['y_test']
        o_train,o_valid,o_test = data['o_train'],data['o_valid'],data['o_test']
        e_train,e_valid,e_test = data['e_train'],data['e_valid'],data['e_test']

        # after 4-fold CV, retrain the model on (train + valid) with the best CV hyper-parameters
        x_train = np.concatenate((x_train,x_valid),axis=0)
        c_train = np.concatenate((c_train,c_valid),axis=0)
        y_train = np.concatenate((y_train,y_valid))
        o_train = np.concatenate((o_train,o_valid),axis=0)

        num_lab_m2 = np.shape(x_train)[0]      # #(x,c,y)
        num_ulab_m2 = np.shape(x_w_full)[0]    # #(x,c)
        num_ulab_m2_2 = np.shape(x_n_full)[0]  # #(x)
        num_lab_batch = args.batch_size
        num_ulab_batch = args.batch_size

        model = m2_aae(
            input_dims=[np.shape(x_train)[1],np.shape(c_train)[1]],
            hidden_dims=[args.w_gen_hidden,args.w_dis_hidden,args.w_cla_hidden_x,args.w_cla_hidden_c],
            latent_dims=args.dim_z,
            coeffs=[args.beta,args.lamb],
            learning_rates=[args.learning_rate_m2],
            n_hs=[args.n_h,args.n_h_2],
            n_samples=[num_lab_m2,num_ulab_m2,num_ulab_m2_2,args.batch_size])
        model.build_graph()

        train_dict,valid_dict,test_dict = {},{},{}
        saver = tf.compat.v1.train.Saver()  # model checkpoint
        
        def get_performance(sess,feed_dict,y_label,o_label):
            merge = [model.lb_l,model.y_x,model.y_c,model.y_xc,model.x_recon_q,model.x_recon_n,
                        model.z_mean,model.z_logstd]
            out = sess.run(merge,feed_dict=feed_dict)
            lb = np.mean(out[0])

            def calc_metrics(idx):
                # AUC / macro-F1 / C-index / accuracy for one head's class-1 probabilities
                out[idx] = np.nan_to_num(out[idx])
                prob = out[idx][:,1]
                fpr,tpr,thr = roc_curve(y_label.astype(int),prob,pos_label=1)
                thr_best = thr[np.argmax(tpr - fpr)]  # threshold maximizing Youden's J
                label = np.where(prob < thr_best,0.0,1.0)
                f1  = f1_score(y_label.astype(int),label.astype(int),pos_label=1,average='macro')
                ci  = calc_c_index_benchmark(o_label,prob)
                acc = calc_acc_score(thr_best,y_label,prob)
                return [auc(fpr,tpr),f1,ci,acc]

            metrics_x  = calc_metrics(1)  # q(y|x)
            metrics_c  = calc_metrics(2)  # q(y|c)
            metrics_xc = calc_metrics(3)  # ensemble q(y|x,c)
            return lb,metrics_x,metrics_c,metrics_xc,out
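        # Each `metrics_*` list is indexed as [0]=AUC, [1]=macro-F1, [2]=C-index,
        # [3]=accuracy; the per-epoch print below reports them as AUC, UF1, ACC, CI.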

        # gpu configuration
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())
            for epoch in range(args.epochs):
                iters_I = int(np.floor(np.shape(x_train)[0] / float(args.batch_size)))
                for t in range(iters_I):
                    bidx_lab  = np.random.choice(np.arange(num_lab_m2),size=num_lab_batch,replace=False)
                    bidx_ulab = np.random.choice(np.arange(num_ulab_m2),size=num_ulab_batch,replace=False)
                    bidx_ulab_2 = np.random.choice(np.arange(num_ulab_m2_2),size=num_ulab_batch,replace=False)
                    
                    # use these two lines instead for decreasing unlabeled data
                    # bidx_ulab = np.random.choice(np.arange(num_ulab_m2),size=num_ulab_batch,replace=True)
                    # bidx_ulab_2 = np.random.choice(np.arange(num_ulab_m2_2),size=num_ulab_batch,replace=True)
                    
                    train_dict.update({model.x: x_train[bidx_lab,:]})
                    train_dict.update({model.c: c_train[bidx_lab,:]})
                    train_dict.update({model.y: y_train[bidx_lab]})
                    train_dict.update({model.x_u: x_w_full[bidx_ulab,:]})
                    train_dict.update({model.c_u: c_w_full[bidx_ulab,:]})
                    train_dict.update({model.x_x: x_n_full[bidx_ulab_2,:]})
                    train_dict.update({model.n_particles: args.n_particles_train})
                    sess.run(model.infer_op_m2,feed_dict=train_dict)
                train_dict_test = {model.x: x_train,model.c: c_train,model.y: y_train,
                    model.n_particles: args.n_particles_test}
                valid_dict_test = {model.x: x_valid,model.c: c_valid,model.y: y_valid,
                    model.n_particles: args.n_particles_test}
                test_dict_test = {model.x: x_test,model.c: c_test,model.y: y_test,
                    model.n_particles: args.n_particles_test}
                out_train = get_performance(sess,train_dict_test,y_train,o_train)
                out_valid = get_performance(sess,valid_dict_test,y_valid,o_valid)
                out_test  = get_performance(sess,test_dict_test,y_test,o_test)

                print('Epoch %d' % (epoch + 1))
                print(
                    'AUC_train_x = %.4f; AUC_valid_x = %.4f; AUC_test_x = %.4f;\n'
                    'AUC_train_c = %.4f; AUC_valid_c = %.4f; AUC_test_c = %.4f;\n'
                    'AUC_train_xc = %.4f; AUC_valid_xc = %.4f; AUC_test_xc = %.4f;\n'

                    'UF1_train_x = %.4f; UF1_valid_x = %.4f; UF1_test_x = %.4f;\n'
                    'UF1_train_c = %.4f; UF1_valid_c = %.4f; UF1_test_c = %.4f;\n'
                    'UF1_train_xc = %.4f; UF1_valid_xc = %.4f; UF1_test_xc = %.4f;\n'

                    'ACC_train_x = %.4f; ACC_valid_x = %.4f; ACC_test_x = %.4f;\n'
                    'ACC_train_c = %.4f; ACC_valid_c = %.4f; ACC_test_c = %.4f;\n'
                    'ACC_train_xc = %.4f; ACC_valid_xc = %.4f; ACC_test_xc = %.4f;\n'

                    'CI_train_x = %.4f; CI_valid_x = %.4f; CI_test_x = %.4f;\n'
                    'CI_train_c = %.4f; CI_valid_c = %.4f; CI_test_c = %.4f;\n'
                    'CI_train_xc = %.4f; CI_valid_xc = %.4f; CI_test_xc = %.4f;\n'
                    'lb_train = %.4f; lb_valid = %.4f; lb_test = %.4f\n' % (
                    out_train[1][0],out_valid[1][0],out_test[1][0],
                    out_train[2][0],out_valid[2][0],out_test[2][0],
                    out_train[3][0],out_valid[3][0],out_test[3][0],

                    out_train[1][1],out_valid[1][1],out_test[1][1],
                    out_train[2][1],out_valid[2][1],out_test[2][1],
                    out_train[3][1],out_valid[3][1],out_test[3][1],

                    out_train[1][3],out_valid[1][3],out_test[1][3],
                    out_train[2][3],out_valid[2][3],out_test[2][3],
                    out_train[3][3],out_valid[3][3],out_test[3][3],

                    out_train[1][2],out_valid[1][2],out_test[1][2],
                    out_train[2][2],out_valid[2][2],out_test[2][2],
                    out_train[3][2],out_valid[3][2],out_test[3][2],
                    out_train[0],out_valid[0],out_test[0]))
            
            save_path = saver.save(sess,'../model/scan/nsclc.ckpt')
            print('Model saved to %s' % save_path)
 
if __name__ == "__main__":
    args = parse_args()
    main(args)