scan / src / test_scan_ens.py
test_scan_ens.py
Raw
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import tensorflow as tf
from six.moves import range, zip
import numpy as np
import zhusuan as zs
import argparse
from sklearn.metrics import roc_auc_score,auc,roc_curve,f1_score,accuracy_score
from lifelines.utils import concordance_index

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV
import inspect
from sklearn.metrics import accuracy_score,average_precision_score,precision_recall_curve


# for validation splitting
from sklearn.model_selection import StratifiedKFold

def calc_c_index_benchmark(o_test, y_prob):
    """Return the concordance index between ``o_test`` and a negated score.

    The score is ``tanh(y_prob)`` centred on its own mean; its negation is
    handed to ``lifelines.utils.concordance_index`` as the predicted values.

    Args:
        o_test: values to rank against (presumably observed survival times —
            confirm with the caller); a trailing axis of length 1 is added
            and the result is cast to float.
        y_prob: raw model outputs, one per entry of ``o_test``.

    Returns:
        Whatever ``concordance_index`` returns for these two arguments.
    """
    event_times = np.expand_dims(o_test, axis=1).astype(float)
    squashed = np.tanh(y_prob)
    centred = squashed - np.mean(squashed)
    return concordance_index(event_times, np.negative(centred))

def calc_acc_score(thr, y_label, y_prob):
    """Return the accuracy obtained by thresholding ``y_prob`` at ``thr``.

    Args:
        thr: decision threshold.
        y_label: reference labels passed unchanged to ``accuracy_score``.
        y_prob: predicted scores (array-like).

    Returns:
        ``accuracy_score(y_label, predicted)`` where ``predicted`` is 0.0 for
        scores strictly below ``thr`` and 1.0 otherwise.
    """
    # Vectorised thresholding: only a strict `score < thr` maps to class 0,
    # so ties and NaN scores stay at 1.0.
    label = np.where(np.asarray(y_prob) < thr, 0.0, 1.0)
    return accuracy_score(y_label, label)


class m2_aae(BaseEstimator):
    """Semi-supervised VAE ("M2" in this file's comments) with an x/c ensemble classifier.

    The class predictor is the average of two outputs of the same merge
    network: one evaluated with only the x branch enabled and one with only
    the c branch enabled (see ``build_graph``).

    Typical use (as done in ``main``): construct, call ``build_graph()`` inside
    a TensorFlow graph, then call ``fit`` / ``score`` / ``score_2`` with a
    session.  ``build_graph`` must run first because it creates the
    placeholders and ops the other methods reference.
    """
    def __init__(
        self,dim_x=0,dim_c=0,dim_z=0,w_gen_hidden=[],w_dis_hidden=[],w_cla_hidden_x=[],w_cla_hidden_c=[],
        beta=0.0,lamb=0.0,learning_rate_m2=0.0,n_h=0,n_h_2=0,
        num_lab_m2=0,num_ulab_m2=0,num_ulab_m2_2=0,batch_size=0,
        epochs=0,n_particles_train=0,n_particles_test=0):
        """Store every argument as a same-named attribute and derive layer sizes.

        Args:
            dim_x: width of the x placeholders.
            dim_c: width of the c placeholders.
            dim_z: latent dimension.
            w_gen_hidden: hidden sizes of the encoder.
            w_dis_hidden: hidden sizes of the decoder.
            w_cla_hidden_x: hidden sizes of the classifier's x branch.
            w_cla_hidden_c: hidden sizes of the classifier's c branch.
            beta: weight of the classification term in ``cost_m2``.
            lamb: weight of the RMS weight regulariser.
            learning_rate_m2: RMSProp learning rate.
            n_h: units of the decoder's input layers (``y_dec`` / ``h_dec``).
            n_h_2: units of the ``x_to_merge`` / ``c_to_merge`` layers.
            num_lab_m2: scale on the labeled bound; also the index range
                sampled for labeled batches in ``fit``.
            num_ulab_m2: same, for the (x_u, c_u) unlabeled set.
            num_ulab_m2_2: same, for the x-only unlabeled set.
            batch_size: rows drawn per set in each training step.
            epochs: number of training epochs.
            n_particles_train: value fed to ``n_particles`` during ``fit``.
            n_particles_test: value fed to ``n_particles`` during scoring.

        NOTE(review): the list defaults are mutable objects shared between
        calls; they are only concatenated below, never mutated in place.
        """
        
        # Snapshot of the frame's local variables at this point, i.e. the
        # constructor arguments; `self` is dropped before copying to attributes.
        args, _, _, values = inspect.getargvalues(inspect.currentframe())
        values.pop("self")
        for arg, val in values.items():
            setattr(self, arg, val)
            # print("{} = {}".format(arg,val))
        # Layer-size lists: entry 0 is the input width, later entries are layer widths.
        self.w_gen_size   = [self.dim_x] + self.w_gen_hidden + []  # encoder
        self.w_dis_size   = [self.n_h]   + self.w_dis_hidden + [self.dim_x]  # decoder
        self.w_cla_size_x = [self.dim_x] + self.w_cla_hidden_x + []  # aux classifier x
        self.w_cla_size_c = [self.dim_c] + self.w_cla_hidden_c + []  # aux classifier c

    @zs.reuse_variables(scope="q_model")
    def _q_model(
        x_input,c_input,dim_z,n_h_2,w_gen_size,w_dis_size,
        w_cla_size_x,w_cla_size_c,n_particles,is_subx,is_subc):
        """Build the variational side: encoder, decoder-weight noise and classifier.

        NOTE(review): there is no ``self`` parameter although the function is
        invoked as ``self._q_model(x, c, ...)`` with ``x_input`` first;
        presumably the object returned by ``zs.reuse_variables`` is not bound
        as a method — confirm against ZhuSuan.

        Args:
            x_input: x tensor fed to the encoder and to the classifier's x branch.
            c_input: c tensor fed to the classifier's c branch.
            dim_z: latent dimension.
            n_h_2: width of the two ``*_to_merge`` layers.
            w_gen_size: encoder sizes (only entries 1.. are used here).
            w_dis_size: decoder sizes; one noise node per decoder layer.
            w_cla_size_x: x-branch sizes (only entries 1.. are used here).
            w_cla_size_c: c-branch sizes (only entries 1.. are used here).
            n_particles: number of samples drawn for the stochastic nodes.
            is_subx: boolean tensor; False replaces the x contribution by zeros.
            is_subc: boolean tensor; False replaces the c contribution by zeros.

        Returns:
            ``(bn_m2, bn_ws, y_sub)``: the net holding ``z``, the net holding
            the ``w_dis_i`` noise nodes, and the 2-unit sigmoid classifier output.
        """
        bn_m2  = zs.BayesianNet()  # VAE latent: z
        bn_ws  = zs.BayesianNet()  # p-model decoder weight variational dropout
        with tf.variable_scope('encoder'):  # q(z|x)
            h = x_input
            for i in range(len(w_gen_size)-1):
                h = tf.layers.dense(h,w_gen_size[i+1],name='enc_' + str(i),activation=tf.nn.elu)
            z_mean = tf.layers.dense(h,dim_z,name='enc_mean',activation=tf.nn.elu)
            z_logstd = tf.layers.dense(h,dim_z,name='enc_logstd')  # no activation: a log-std may be any real number
            bn_m2.normal('z',z_mean,logstd=z_logstd,n_samples=n_particles,group_ndims=1)
            # Expose the Gaussian parameters so build_graph can fetch them by name.
            bn_m2.deterministic('z_mean_out',z_mean)
            bn_m2.deterministic('z_logstd_out',z_logstd)
        with tf.variable_scope('decoder'):  # p(x|z,y)
            # One multiplicative-noise node per decoder layer: mean 1.0 and a
            # learned std = sqrt(sigmoid(logit_alpha) + 1e-10), repeated for every row of x_input.
            for i in range(len(w_dis_size)-1):
                with tf.variable_scope('layer_dis_' + str(i)):
                    logit_alpha = tf.get_variable('logit_alpha',[w_dis_size[i]])
                    std = tf.sqrt(tf.nn.sigmoid(logit_alpha) + 1E-10)
                    std = tf.tile(tf.expand_dims(std,0),[tf.shape(x_input)[0],1])
                    bn_ws.normal('w_dis_' + str(i),1.0,std=std,n_samples=n_particles,group_ndims=1)
        with tf.variable_scope('merge'):  # q(y|x,c)
            hx = x_input
            for i in range(len(w_cla_size_x)-1):
                hx = tf.layers.dense(hx,w_cla_size_x[i+1],name='x_mer_' + str(i),activation=tf.nn.softplus)
            hc = c_input
            for i in range(len(w_cla_size_c)-1):
                hc = tf.layers.dense(hc,w_cla_size_c[i+1],name='c_mer_' + str(i),activation=tf.nn.softplus)
            hs_x = tf.layers.dense(hx,n_h_2,name='x_to_merge')
            hs_c = tf.layers.dense(hc,n_h_2,name='c_to_merge')
            # Branch gating: a disabled branch contributes zeros and receives no gradient.
            hs_x = tf.cond(is_subx,lambda: hs_x,lambda: tf.stop_gradient(hs_x * 0.0))
            hs_c = tf.cond(is_subc,lambda: hs_c,lambda: tf.stop_gradient(hs_c * 0.0))
            # Two independent sigmoid units; the rows are not normalised to sum to 1 here.
            y_sub = tf.layers.dense(hs_x + hs_c,2,name='merge_output',activation=tf.nn.sigmoid)  # probability
        return bn_m2,bn_ws,y_sub

    @zs.meta_bayesian_net(scope="p_model",reuse_variables=True)
    def _p_model(self,y_input,n_h,dim_z,w_dis_size,n_particles,is_vae):
        '''Build the generative net: priors on y and z, noisy decoder, x reconstruction.

        Only one p-model definition is used; the ``log_joint`` assigned in
        ``build_graph`` selects which nodes enter each lower bound.

        Args:
            y_input: tensor used only for its leading (batch) dimension.
            n_h: width of the ``y_dec`` / ``h_dec`` input layers.
            dim_z: latent dimension.
            w_dis_size: decoder sizes; entries 1.. are the dense layer widths.
            n_particles: number of samples for ``z`` and the ``w_dis_i`` nodes.
            is_vae: boolean tensor; True bypasses the multiplicative noise
                ``w`` on each layer input, False applies ``h * w``.

        Returns:
            The ``zs.BayesianNet`` with nodes ``y``, ``z``, ``w_dis_i``,
            ``x_recon`` and the deterministic ``x_recon_out``.
        '''
        bn = zs.BayesianNet()
        with tf.variable_scope('decoder'):  # p(x|z,y)
            with tf.variable_scope('input_layers'):
                y_logits = tf.zeros([tf.shape(y_input)[0], 2])  # y prior ~ Cat(uniform)
                y = bn.onehot_categorical('y', y_logits)
                z_mean = tf.zeros([tf.shape(y_input)[0], dim_z])  # z prior ~ N(O,I)
                z = bn.normal('z', z_mean, std=1., group_ndims=1,n_samples=n_particles)
                h_from_y = tf.layers.dense(tf.cast(y,tf.float32), n_h,name='y_dec',activation=tf.nn.elu)
                h_from_c = tf.layers.dense(z, n_h,name='h_dec',activation=tf.nn.elu)
                h = tf.nn.softplus(h_from_y + h_from_c)
            for i in range(len(w_dis_size)-1):
                # Prior on the layer's multiplicative noise: mean 1, std 1.
                w_mean = tf.ones([w_dis_size[i]])
                w = bn.normal('w_dis_' + str(i),w_mean,std=1.0,n_samples=n_particles,group_ndims=1)
                if i == len(w_dis_size) - 2:
                    # Last layer: linear output (no activation).
                    with tf.variable_scope('output_layer'):
                        h = tf.cond(is_vae,lambda: h, lambda: h*w)
                        h = tf.layers.dense(h,w_dis_size[i+1],name='w_dis_' + str(i))
                else:
                    with tf.variable_scope('hidden_' + str(i)):
                        h = tf.cond(is_vae,lambda: h, lambda: h*w)
                        h = tf.layers.dense(h,w_dis_size[i+1],name='w_dis_' + str(i),activation=tf.nn.elu)
            x_mean = h
            bn.normal('x_recon', x_mean, std=1., group_ndims=1)
            # Mean over axis 0 (presumably the particle axis — confirm).
            bn.deterministic('x_recon_out',tf.reduce_mean(x_mean,axis=0))
        return bn

    def build_graph(self):
        """Create placeholders, lower bounds, the total cost and the training op.

        Attributes created here and used by other methods: the placeholders
        ``x``, ``c``, ``y``, ``x_u``, ``c_u``, ``x_x``, ``n_particles``; the
        predictions ``y_x``, ``y_c``, ``y_xc``; the scalar ``cost_m2``; and the
        training op ``infer_op_m2``.
        """
        # complete labeled data triplets (x,c,y)
        self.x = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_x],name='labeled_expression')
        self.c = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_c],name='labeled_clinincal')
        self.y = tf.compat.v1.placeholder(tf.float32,shape=[None],name='labeled_class')
        # unlabeled (most censored) samples (x,c)
        self.x_u = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_x],name='unlabeled_expression')
        self.c_u = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_c],name='unlabeled_clinincal')
        # unlabeled (most censored) samples (x)
        self.x_x = tf.compat.v1.placeholder(tf.float32,shape=[None,self.dim_x],name='only_expression')

        self.n_particles = tf.compat.v1.placeholder(tf.int32,shape=[],name='n_particles')
        # Not fed by fit/score/score_2 in this class.
        self.is_training = tf.compat.v1.placeholder(tf.bool,shape=[],name='is_training')  # deprecated

        # ========================================= M2 VAE =============================================== #
        # labeled samples used to build q-model & train auxiliary classifier (fully-supervised fashion)
        # First call: x branch only (is_subx=True, is_subc=False) -> y_x.
        bn_m2,bn_ws,self.y_x = self._q_model(
            self.x,self.c,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        # Second call: c branch only (is_subx=False, is_subc=True) -> y_c.
        _,_,self.y_c = self._q_model(
            self.x,self.c,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(False,dtype=tf.bool),tf.constant(True,dtype=tf.bool))
        # p-model for x_recons given various types of data; q_ones/zeros are built accordingly
        # p_model_m2 is built with is_vae=True (noise bypassed), p_model_ws with is_vae=False (noise applied).
        p_model_m2 = self._p_model(
            self.y,self.n_h,
            self.dim_z,self.w_dis_size,self.n_particles,tf.constant(True,dtype=tf.bool))
        p_model_ws = self._p_model(
            self.y,self.n_h,
            self.dim_z,self.w_dis_size,self.n_particles,tf.constant(False,dtype=tf.bool))
        
        # L(x,y)
        y_onehot = tf.one_hot(indices=tf.cast(self.y,tf.int32),depth=2,on_value=1,off_value=0)
        # Joint used for the z bound: log p(x|.) + log p(y) + log p(z); the weight nodes are left out.
        def log_joint_decoder(bn):
            log_px  = bn.cond_log_prob('x_recon')
            log_py  = bn.cond_log_prob('y')
            log_pz  = bn.cond_log_prob('z')
            return log_px + log_py + log_pz
        p_model_m2.log_joint = log_joint_decoder
        self.lb_l_m2 = zs.variational.elbo(
            p_model_m2,observed={'x_recon': self.x,'y':y_onehot},variational=bn_m2,axis=0)
        self.lb_l_m2_sgvb = self.lb_l_m2.sgvb()
        # Joint used for the weight bound: sum of the log-priors of all w_dis_i nodes only.
        def log_joint_weight(bn):
            w_names = ['w_dis_' + str(i) for i in range(len(self.w_dis_size)-1)]
            log_pws = bn.cond_log_prob(w_names)
            return tf.add_n(log_pws)
        p_model_ws.log_joint = log_joint_weight
        self.lb_l2_ws = zs.variational.elbo(
            p_model_ws,observed={'x_recon': self.x,'y':y_onehot},variational=bn_ws,axis=0)
        self.lb_l2_ws_sgvb = self.lb_l2_ws.sgvb()
        # Labeled objective: z term scaled by the labeled-set size, plus the weight term.
        self.lb_l_sgvb = self.lb_l_m2_sgvb * self.num_lab_m2 + self.lb_l2_ws_sgvb
        
        # U(x) = U'(x_u,c_u) + U'(x_x); use all unlabeled data
        # 1. U'(x_u,c_u): use both x_u & c_u for ensemble prediction - q(yu|x_u,c_u)
        y_onehot_ones = tf.one_hot(indices=tf.cast(tf.ones_like(self.x_u[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        y_onehot_zeros = tf.one_hot(indices=tf.cast(tf.zeros_like(self.x_u[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        # NOTE: the "ones" and "zeros" q-model calls below receive identical arguments;
        # the assumed label only enters through the `observed` dicts of the elbo calls.
        q_model_ones_m2_xc,q_model_ones_ws_xc,yu_ones_xc_x = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        _,_,yu_ones_xc_c = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(False,dtype=tf.bool),tf.constant(True,dtype=tf.bool))
        q_model_zeros_m2_xc,q_model_zeros_ws_xc,yu_zeros_xc_x = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        _,_,yu_zeros_xc_c = self._q_model(
            self.x_u,self.c_u,self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(False,dtype=tf.bool),tf.constant(True,dtype=tf.bool))
        # ensemble prediction
        yu_ones_xc = 0.5 * yu_ones_xc_x + 0.5 * yu_ones_xc_c
        yu_zeros_xc = 0.5 * yu_zeros_xc_x + 0.5 * yu_zeros_xc_c
        self.lb_z_ones_m2_xc = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_u,'y':y_onehot_ones},variational=q_model_ones_m2_xc,axis=0)
        self.lb_z_ones_ws_xc = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_u,'y':y_onehot_ones},variational=q_model_ones_ws_xc,axis=0)
        self.lb_z_zeros_m2_xc = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_u,'y':y_onehot_zeros},variational=q_model_zeros_m2_xc,axis=0)
        self.lb_z_zeros_ws_xc = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_u,'y':y_onehot_zeros},variational=q_model_zeros_ws_xc,axis=0)
        # Stack the per-label bounds as columns [y=0, y=1].
        self.lb_z_m2_xc = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_m2_xc,axis=1),tf.expand_dims(self.lb_z_ones_m2_xc,axis=1)],axis=1)
        self.lb_z_ws_xc = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_ws_xc,axis=1),tf.expand_dims(self.lb_z_ones_ws_xc,axis=1)],axis=1)
        # NOTE(review): both columns take index 1 of outputs produced by identical
        # q-model calls, so the two normalised entries below look like they are
        # always equal — confirm whether `[:,0]` was intended for the zeros column.
        self.yu_xc = tf.concat(
            [tf.expand_dims(yu_zeros_xc[:,1],axis=1),tf.expand_dims(yu_ones_xc[:,1],axis=1)],axis=1)
        qy_u_xc = self.yu_xc / tf.reduce_sum(self.yu_xc, 1, keepdims=True)
        log_qy_u_xc = tf.log(qy_u_xc + 1E-10)  # epsilon guards log(0)
        # Expectation over the label weights of (bound - log q(y)), summed over the two labels.
        self.lb_u_m2_xc = tf.reduce_sum(qy_u_xc * (self.lb_z_m2_xc - log_qy_u_xc), 1)
        self.lb_u_ws_xc = tf.reduce_sum(qy_u_xc * (self.lb_z_ws_xc - log_qy_u_xc), 1)
        # Negated so that it is added to the cost; z term scaled by the unlabeled-set size.
        self.lb_u_sgvb_xc = - (self.lb_u_m2_xc * self.num_ulab_m2 + self.lb_u_ws_xc) 

        # 2. U'(x_x): use only x_x for single prediction - q(yu|x_x)
        y_onehot_ones_2 = tf.one_hot(indices=tf.cast(tf.ones_like(self.x_x[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        y_onehot_zeros_2 = tf.one_hot(indices=tf.cast(tf.zeros_like(self.x_x[:,0]),tf.int32),
                                    depth=2,on_value=1,off_value=0)
        # No c is available here: a tensor of ones stands in for c and the c branch
        # is disabled (is_subc=False), so only the x branch drives the prediction.
        q_model_ones_m2_x,q_model_ones_ws_x,yu_ones_x = self._q_model(
            self.x_x,tf.ones([tf.shape(self.x_x)[0],self.dim_c]),self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        q_model_zeros_m2_x,q_model_zeros_ws_x,yu_zeros_x = self._q_model(
            self.x_x,tf.ones([tf.shape(self.x_x)[0],self.dim_c]),self.dim_z,self.n_h_2,
            self.w_gen_size,self.w_dis_size,self.w_cla_size_x,self.w_cla_size_c,self.n_particles,
            tf.constant(True,dtype=tf.bool),tf.constant(False,dtype=tf.bool))
        self.lb_z_ones_m2_x = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_x,'y':y_onehot_ones_2},variational=q_model_ones_m2_x,axis=0)
        self.lb_z_ones_ws_x = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_x,'y':y_onehot_ones_2},variational=q_model_ones_ws_x,axis=0)
        self.lb_z_zeros_m2_x = zs.variational.elbo(p_model_m2,
            observed={'x_recon': self.x_x,'y':y_onehot_zeros_2},variational=q_model_zeros_m2_x,axis=0)
        self.lb_z_zeros_ws_x = zs.variational.elbo(p_model_ws,
            observed={'x_recon': self.x_x,'y':y_onehot_zeros_2},variational=q_model_zeros_ws_x,axis=0)
        self.lb_z_m2_x = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_m2_x,axis=1),tf.expand_dims(self.lb_z_ones_m2_x,axis=1)],axis=1)
        self.lb_z_ws_x = tf.concat(
            [tf.expand_dims(self.lb_z_zeros_ws_x,axis=1),tf.expand_dims(self.lb_z_ones_ws_x,axis=1)],axis=1)
        # single prediction
        # NOTE(review): same pattern as yu_xc above — both columns use index 1.
        self.yu_x = tf.concat(
            [tf.expand_dims(yu_zeros_x[:,1],axis=1),tf.expand_dims(yu_ones_x[:,1],axis=1)],axis=1)
        qy_u_x = self.yu_x / tf.reduce_sum(self.yu_x, 1, keepdims=True)
        log_qy_u_x = tf.log(qy_u_x + 1E-10)
        self.lb_u_m2_x = tf.reduce_sum(qy_u_x * (self.lb_z_m2_x - log_qy_u_x), 1)
        self.lb_u_ws_x = tf.reduce_sum(qy_u_x * (self.lb_z_ws_x - log_qy_u_x), 1)
        self.lb_u_sgvb_x = - (self.lb_u_m2_x * self.num_ulab_m2_2 + self.lb_u_ws_x)

        # 3. U(x) = U'(x_u,c_u) + U'(x_x)
        self.lb_u_sgvb = self.lb_u_sgvb_xc + self.lb_u_sgvb_x
        
        # auxiliary classifiers for x and c (using labeled triplets to train)
        # NOTE(review): y_x / y_c / y_xc are sigmoid outputs; presumably the first
        # argument of zs.distributions.OnehotCategorical is `logits` — confirm that
        # passing probabilities here is intended.
        onehot_cat_x = zs.distributions.OnehotCategorical(self.y_x)
        self.log_qy_x = onehot_cat_x.log_prob(tf.one_hot(indices=tf.cast(self.y,tf.int32),
                        depth=2,on_value=1,off_value=0))
        onehot_cat_c = zs.distributions.OnehotCategorical(self.y_c)
        self.log_qy_c = onehot_cat_c.log_prob(tf.one_hot(indices=tf.cast(self.y,tf.int32),
                        depth=2,on_value=1,off_value=0))
        # ensemble prediction and then calculate log-probability
        self.y_xc = 0.5 * self.y_x + 0.5 * self.y_c
        onehot_cat_xc = zs.distributions.OnehotCategorical(self.y_xc)
        self.log_qy_xc = onehot_cat_xc.log_prob(tf.one_hot(indices=tf.cast(self.y,tf.int32),
                        depth=2,on_value=1,off_value=0))

        # overall loss; following original paper, L/U are added directly while `cost_m2_aux` is weighted
        # Only the ensemble log-probability (log_qy_xc) enters the cost; log_qy_x / log_qy_c are kept as attributes.
        cost_m2_lxy = tf.reduce_mean(self.lb_l_sgvb)
        cost_m2_ux  = tf.reduce_mean(self.lb_u_sgvb)
        cost_m2_aux = tf.reduce_mean(self.log_qy_xc)
        self.cost_m2 = cost_m2_lxy - self.beta * cost_m2_aux + cost_m2_ux
        self.lb_l = self.lb_l_m2_sgvb  # observing reconstruction

        # ======================================== reconstructions ===========================================
        self.z_mean = bn_m2.get('z_mean_out')
        self.z_logstd = bn_m2.get('z_logstd_out')
        self.x_recon_q = self.lb_l_m2.bn['x_recon_out']  # observing z ~ q(z|x)

        # z_sample is drawn with the encoder's mean and stddev exp(z_logstd) and a fixed op seed.
        z_sample = tf.random.normal(
            shape=tf.shape(bn_m2.get('z')),mean=self.z_mean,stddev=tf.exp(self.z_logstd),seed=0)
        lb_n = zs.variational.elbo(p_model_m2,
            observed={'y':y_onehot,'z':z_sample},variational=bn_m2,axis=0)
        self.x_recon_n = lb_n.bn['x_recon_out']  # reconstruction with z fixed to z_sample (drawn from N(z_mean, exp(z_logstd)), not the N(0,I) prior)

        # ========================== regularization & optimizers ========================================= #
        # NOTE(review): scope names below rely on how the decorators name their variable
        # scopes; presumably the second _p_model instance is uniquified to 'p_model_1' — confirm.
        dec_m2 = tf.trainable_variables(scope='p_model/decoder')    # decoder variables of the first p-model (p_model_m2)
        dec_ws = tf.trainable_variables(scope='p_model_1/decoder')  # decoder variables of the second p-model (p_model_ws)
        dec_var = tf.trainable_variables(scope='q_model/decoder')   # the `logit_alpha` noise parameters created in _q_model
        dec_list = dec_ws + dec_m2 + dec_var
        enc_list = tf.trainable_variables(scope='q_model/encoder')  # tf.dense.layers
        mer_list = tf.trainable_variables(scope='q_model/merge')    # tf.dense.layers
        # RMS of a rank-1 variable (a leading axis is added so the mean runs over its entries).
        def get_vec_rms(var):
            return tf.reduce_mean(tf.sqrt(tf.reduce_mean(tf.expand_dims(tf.square(var),0),axis=1) + 1E-10))
        # Row-wise RMS (mean over axis 1) averaged over rows.
        def get_mat_rms(var):
            return tf.reduce_mean(tf.sqrt(tf.reduce_mean(tf.square(var),axis=1) + 1E-10))
        # Each regulariser is the average RMS over the variables found in its scope.
        # The decoder term only covers 'p_model/decoder'.
        enc_reg,dec_reg,mer_reg = 0.0,0.0,0.0
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/encoder'):
            if len(np.shape(var)) == 1:  enc_reg += get_vec_rms(var)
            else:  enc_reg += get_mat_rms(var)
        enc_reg /= len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/encoder'))
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='p_model/decoder'):
            if len(np.shape(var)) == 1:  dec_reg += get_vec_rms(var)
            else:  dec_reg += get_mat_rms(var)
        dec_reg /= len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='p_model/decoder'))
        for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/merge'):
            if len(np.shape(var)) == 1:  mer_reg += get_vec_rms(var)
            else:  mer_reg += get_mat_rms(var)
        mer_reg /= len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope='q_model/merge'))

        # RMSProp on cost + lamb * regularisers, with gradients clipped to a global norm of 5.
        optimizer_m2 = tf.compat.v1.train.RMSPropOptimizer(learning_rate=self.learning_rate_m2)
        grads_m2,norm_m2 = tf.clip_by_global_norm(tf.gradients(
            self.cost_m2 + self.lamb * (enc_reg + dec_reg + mer_reg),
            enc_list + dec_list + mer_list),5)
        self.infer_op_m2 = optimizer_m2.apply_gradients(zip(grads_m2,enc_list + dec_list + mer_list))

    def fit(self,X,sess):
        """Run ``self.epochs`` epochs of training steps with ``infer_op_m2``.

        Args:
            X: mapping providing ``'x_train'``, ``'c_train'``, ``'y_train'``
                (labeled), ``'x_w_full'``, ``'c_w_full'`` (unlabeled x and c)
                and ``'x_n_full'`` (unlabeled x only).
            sess: session in which the training op is run.

        Returns:
            ``self``.
        """
        train_dict = {}  # reused feed dict; every entry is overwritten each step
        for epoch in range(self.epochs):
            print(epoch)
            # Steps per epoch: number of full batches in the labeled training set.
            iters_I = int(np.floor(np.shape(X['x_train'])[0] / float(self.batch_size)))
            for t in range(iters_I):
                # Each step draws an independent batch of distinct row indices from every set.
                # The index ranges come from num_lab_m2 / num_ulab_m2 / num_ulab_m2_2,
                # not from the arrays in X, so these must match the array lengths.
                # NOTE(review): replace=False presumably requires each range to hold
                # at least batch_size indices — confirm for small data sets.
                bidx_lab  = np.random.choice(np.arange(self.num_lab_m2),size=self.batch_size,replace=False)
                bidx_ulab = np.random.choice(np.arange(self.num_ulab_m2),size=self.batch_size,replace=False)
                bidx_ulab_2 = np.random.choice(np.arange(self.num_ulab_m2_2),size=self.batch_size,replace=False)
                train_dict.update({self.x: X['x_train'][bidx_lab,:]})
                train_dict.update({self.c: X['c_train'][bidx_lab,:]})
                train_dict.update({self.y: X['y_train'][bidx_lab]})
                train_dict.update({self.x_u: X['x_w_full'][bidx_ulab,:]})
                train_dict.update({self.c_u: X['c_w_full'][bidx_ulab,:]})
                train_dict.update({self.x_x: X['x_n_full'][bidx_ulab_2,:]})
                train_dict.update({self.n_particles: self.n_particles_train})
                sess.run(self.infer_op_m2,feed_dict=train_dict)
        return self

    def score(self,X,sess):  # validation
        """Return the ROC AUC of the ensemble prediction on the validation split.

        Args:
            X: mapping providing ``'x_valid'``, ``'c_valid'`` and ``'y_valid'``.
            sess: session used to evaluate ``self.y_xc``.

        Returns:
            Area under the ROC curve, using column 1 of ``y_xc`` as the score
            for the positive class (label 1); NaNs are replaced via
            ``np.nan_to_num`` first.
        """
        test_dict_test = {
            self.x: X['x_valid'],self.c: X['c_valid'],self.y: X['y_valid'],
            self.n_particles: self.n_particles_test}
        prob = sess.run(self.y_xc,feed_dict=test_dict_test)
        prob = np.nan_to_num(prob[:,1])
        fpr,tpr,thr = roc_curve(X['y_valid'].astype(int),prob,pos_label=1)
        valid_auc = auc(fpr,tpr)
        return valid_auc

    def score_2(self,X,sess):  # test
        """Return the ROC AUC of the ensemble prediction on the test split.

        Same computation as ``score`` but reads ``'x_test'``, ``'c_test'`` and
        ``'y_test'`` from ``X``.
        """
        test_dict_test = {
            self.x: X['x_test'],self.c: X['c_test'],self.y: X['y_test'],
            self.n_particles: self.n_particles_test}
        prob = sess.run(self.y_xc,feed_dict=test_dict_test)
        prob = np.nan_to_num(prob[:,1])
        fpr,tpr,thr = roc_curve(X['y_test'].astype(int),prob,pos_label=1)
        valid_auc = auc(fpr,tpr)  # despite the name, this is the test-split AUC
        return valid_auc

def parse_args():
    """Parse the command-line options of the ensemble evaluation script.

    The active defaults are the NSCLC hyper-parameters; the breast
    configuration is kept below as commented-out alternatives.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    # nsclc hyper-parameters
    parser.add_argument('--w_gen_hidden',default=[12],nargs='+',type=int,help='q-model hidden')
    parser.add_argument('--w_dis_hidden',default=[8],nargs='+',type=int,help='p-model hidden')
    parser.add_argument('--w_cla_hidden_x',default=[12],nargs='+',type=int,help='classifier Ax hidden')
    parser.add_argument('--w_cla_hidden_c',default=[],nargs='+',type=int,help='classifier Ac hidden')
    parser.add_argument('--dim_z',default=5,type=int,help='latent dimension')
    parser.add_argument('--n_h',default=10,type=int,help='latent input transformation dimension')
    parser.add_argument('--n_h_2',default=5,type=int,help='latent input transformation dimension')
    
    parser.add_argument('--beta',type=float,default=0.1,help='classification regularization constant in M2 U(x)')
    parser.add_argument('--lamb',type=float,default=0.1,help='weight L2-regulatization constant')
    parser.add_argument('--n_particles_train',type=int,default=50)
    parser.add_argument('--n_particles_test',type=int,default=500)
    parser.add_argument('--batch_size',type=int,default=32)
    parser.add_argument('--learning_rate_m2',type=float,default=0.01)
    # type=float keeps a command-line value consistent with the float default
    # (argparse would otherwise hand back a str).
    parser.add_argument('--std',type=float,default=1.0)
    parser.add_argument('--epochs',type=int,default=100,help='total epochs')

    parser.add_argument('--data_path',default='../data/nsclc/nsclc_')
    parser.add_argument('--udata_path',default='../data/nsclc/nsclc_unlabeled')
    # Left untyped on purpose: the value is only ever str()-ed into the data
    # file name, so the int default and a str from the command line behave alike.
    parser.add_argument('--split_num',default=3)

    # breast hyper-parameters
    # parser.add_argument('--w_gen_hidden',default=[10,10,5],nargs='+',type=int,help='q-model hidden')  # [15,10]
    # parser.add_argument('--w_dis_hidden',default=[5],nargs='+',type=int,help='p-model hidden')  # [10,15]
    # parser.add_argument('--w_cla_hidden_x',default=[10,5],nargs='+',type=int,help='classifier Ax hidden')  # [10,5]
    # parser.add_argument('--w_cla_hidden_c',default=[10],nargs='+',type=int,help='classifier Ac hidden')  # [10,5]
    # parser.add_argument('--dim_z',default=10,type=int,help='latent dimension')
    # parser.add_argument('--n_h',default=5,type=int,help='decoder latent transformation dimension')  # 10
    # parser.add_argument('--n_h_2',default=5,type=int,help='classifier latent transformation dimension')

    # parser.add_argument('--beta',type=float,default=0.1,help='classification loss weight')  # 0.1
    # parser.add_argument('--lamb',type=float,default=0.1,help='weight L2-reg constant')  # 1E-3
    # parser.add_argument('--n_particles_train',type=int,default=50)
    # parser.add_argument('--n_particles_test',type=int,default=500)
    # parser.add_argument('--batch_size',type=int,default=64)
    # parser.add_argument('--learning_rate_m2',type=float,default=0.01)
    # parser.add_argument('--std',default=1.0)
    # parser.add_argument('--epochs',type=int,default=300,help='total epochs')  # 150
 
    # parser.add_argument('--data_path',default='../data/breast/breast_')
    # parser.add_argument('--split_num',default=1)
    # parser.add_argument('--udata_path',default='../data/breast/breast_unlabeled')
    # parser.add_argument('--count',default=0)

    # random seeds (will change in each ensemble through another bash file)
    parser.add_argument('--np_seed',type=int,default=0)  # 1234
    parser.add_argument('--tf_seed',type=int,default=0)  # 1237

    return parser.parse_args()

def _mean_prob_and_threshold(y_true, member_probs):
    """Average member predictions and pick the ROC cut-off maximising ``tpr - fpr``.

    Args:
        y_true: array of reference labels (cast to int for ``roc_curve``).
        member_probs: list with one score vector per ensemble member.

    Returns:
        ``(mean_prob, best_thr)``: the element-wise mean over members and the
        ``roc_curve`` threshold at which ``tpr - fpr`` is largest.
    """
    mean_prob = np.mean(np.array(member_probs), axis=0)
    fpr, tpr, thr = roc_curve(y_true.astype(int), mean_prob, pos_label=1)
    return mean_prob, thr[np.argmax(np.subtract(tpr, fpr))]


def main(args):
    """Restore every ensemble member, average its test predictions and save them.

    For each checkpoint the test AUC is printed and the class-1 outputs of the
    x-only, c-only and ensemble predictors are collected.  The member-averaged
    scores and their best ROC thresholds are written to
    ``../model/scan_ens/nsclc/nsclc_logits.npz``.

    Args:
        args: namespace from ``parse_args`` (seeds, data paths, split number
            and model hyper-parameters are read from it).
    """
    n_members = 200
    ckpt_pattern = '../model/scan_ens/nsclc/ens_{0}/nsclc_{0}.ckpt'

    with tf.Graph().as_default():
        tf.compat.v1.set_random_seed(args.tf_seed)
        np.random.seed(args.np_seed)

        # Read every array up front so the archives can be closed right away.
        with np.load(args.data_path + str(args.split_num) + '.npz', allow_pickle=True) as npz:
            data_dict = {key: npz[key] for key in npz.files}
        with np.load(args.udata_path + '.npz', allow_pickle=True) as npz_u:
            raw_x_w = npz_u['x_w_full']
            raw_c_w = npz_u['c_w_full']
            raw_x_n = npz_u['x_n_full']

        # Unlabeled data are normalised with the statistics stored in the split file.
        # (For the breast data the statistics were stored as single 'mean'/'scale'
        # vectors: entries [:10] for c and [10:] for x.)
        x_w_full = (raw_x_w - data_dict['x_mean']) / data_dict['x_scale']
        c_w_full = (raw_c_w - data_dict['c_mean']) / data_dict['c_scale']
        x_n_full = (raw_x_n - data_dict['x_mean']) / data_dict['x_scale']
        data_dict['x_w_full'] = x_w_full
        data_dict['c_w_full'] = c_w_full
        data_dict['x_n_full'] = x_n_full

        # Training and validation splits are merged into one labeled set.
        for prefix in ('x', 'c', 'y', 'o'):
            data_dict[prefix + '_train'] = np.concatenate(
                (data_dict[prefix + '_train'], data_dict[prefix + '_valid']), axis=0)

        model = m2_aae(
            dim_x=np.shape(data_dict['x_train'])[1],
            dim_c=np.shape(data_dict['c_train'])[1],
            dim_z=args.dim_z,
            w_gen_hidden=args.w_gen_hidden, w_dis_hidden=args.w_dis_hidden,
            w_cla_hidden_x=args.w_cla_hidden_x, w_cla_hidden_c=args.w_cla_hidden_c,
            beta=args.beta, lamb=args.lamb, learning_rate_m2=args.learning_rate_m2,
            n_h=args.n_h, n_h_2=args.n_h_2,
            num_lab_m2=np.shape(data_dict['x_train'])[0],
            num_ulab_m2=np.shape(x_w_full)[0],
            num_ulab_m2_2=np.shape(x_n_full)[0],
            batch_size=args.batch_size, epochs=args.epochs,
            n_particles_train=args.n_particles_train,
            n_particles_test=args.n_particles_test)
        model.build_graph()

        saver = tf.compat.v1.train.Saver()
        # gpu configuration
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        # The evaluation inputs do not change between members, so build them once.
        # To score an independent cohort instead, build this feed dict from that
        # cohort's arrays normalised with this split's statistics (nsclc:
        # '../data/nsclc/indep.npz'; breast:
        # '../data/breast/indep_valid_geos_21653.npz' with keys 'exprs'/'labs' and
        # model.c fed as np.zeros((83,10))) and change the output file accordingly
        # ('nsclc_indep_logits.npz' / 'breast_indep_logits.npz', the latter with
        # y_x_logit only).
        y_test = data_dict['y_test']
        test_feed = {
            model.x: data_dict['x_test'],
            model.c: data_dict['c_test'],
            model.y: y_test,
            model.n_particles: model.n_particles_test}

        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            probs_xc, probs_x, probs_c = [], [], []
            for i in range(n_members):
                saver.restore(sess, ckpt_pattern.format(i))
                test_auc = model.score_2(data_dict, sess)
                print('test_auc = %.4f' % test_auc)

                # One run yields all three predictors; column 1 is the class-1 output.
                prob_xc, prob_x, prob_c = sess.run(
                    [model.y_xc, model.y_x, model.y_c], feed_dict=test_feed)
                probs_xc.append(prob_xc[:, 1])
                probs_x.append(prob_x[:, 1])
                probs_c.append(prob_c[:, 1])

            probs_x, thr_best_x = _mean_prob_and_threshold(y_test, probs_x)
            print(probs_x)

            # c-only and ensemble outputs are not available for breast
            probs_c, thr_best_c = _mean_prob_and_threshold(y_test, probs_c)
            print(probs_c)

            probs_xc, thr_best_xc = _mean_prob_and_threshold(y_test, probs_xc)
            print(probs_xc)

            np.savez_compressed(
                '../model/scan_ens/nsclc/nsclc_logits.npz',
                y_x_logit=probs_x, y_c_logit=probs_c, y_xc_logit=probs_xc,
                thr_best_x=thr_best_x, thr_best_c=thr_best_c, thr_best_xc=thr_best_xc)



# Script entry point: parse the command-line options, then run the ensemble evaluation.
if __name__ == "__main__":
    args = parse_args()
    main(args)