# scan/utils/plot_properties.py
import numpy as np
import tensorflow as tf
import pickle

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib import rcParams, cycler
import seaborn as sns

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import roc_auc_score,auc,roc_curve,f1_score
from sklearn.metrics import accuracy_score,average_precision_score
from sklearn.metrics import precision_recall_curve
from sklearn.svm import SVC

from lifelines.utils import concordance_index
from lifelines import KaplanMeierFitter
from lifelines import CoxPHFitter
from lifelines.statistics import logrank_test

from scipy import interp
from scipy.stats import ks_2samp,wasserstein_distance

def softplus(x):
    '''Element-wise softplus, log(1 + exp(x)).

    Computed as np.logaddexp(0, x) so large inputs do not overflow exp()
    (the naive log(1 + exp(x)) returns inf once x exceeds roughly 709).
    '''
    return np.logaddexp(0.0,x)
def sigmoid(x):
    '''Element-wise logistic function 1 / (1 + exp(-x)).'''
    denominator = 1.0 + np.exp(-x)
    return 1.0 / denominator
def calc_c_index_benchmark(o_test,y_prob):
    '''Return the concordance index (C-index) of predicted scores against event times.

    A monotonically increasing transform of the scores does not change their
    ranking, and hence not the C-index; tanh is used so that large scores
    cannot cause arithmetic errors. The squashed scores are centred on their
    mean and negated before being handed to lifelines' concordance_index
    (presumably so that a higher predicted probability ranks as a shorter
    event time -- confirm against the model's convention).

    Args:
        o_test: event times (converted to a float column vector).
        y_prob: predicted scores aligned with o_test.
    '''
    event_times = np.array(np.expand_dims(o_test, axis=1), dtype=float)
    squashed = np.tanh(y_prob)
    centred = squashed - np.mean(squashed)
    return concordance_index(event_times, -centred)
def calc_acc_score(thr,y_label,y_prob):
    '''Return the accuracy of hard labels obtained by thresholding probabilities.

    A probability strictly below `thr` becomes label 0.0; every other value
    (including one equal to `thr`) becomes 1.0.
    '''
    hard_labels = np.array([0.0 if p < thr else 1.0 for p in y_prob])
    return accuracy_score(y_label, hard_labels)

def _draw_bootstrap_indices(num_items,num_bs_samples):
    '''Return `num_bs_samples` int32 arrays of `num_items` indices drawn with replacement.

    Uses the global NumPy RNG (np.random), so callers can seed it via np.random.seed.
    '''
    population = np.arange(num_items)
    # one vectorized draw per resample instead of `num_items` separate size=1 draws
    return [np.random.choice(population,size=num_items,replace=True).astype(np.int32)
            for _ in range(num_bs_samples)]

def create_bs_idx_indep(num_bs_samples=1000):
    '''Create and save bootstrap test-set indices for the independent validation sets.

    Reads '../data/nsclc/indep.npz' (key 'y_test') and
    '../data/breast/indep_valid_geos_21653.npz' (key 'labs') only to get the
    number of test samples, then writes '../data/bs_idx_indep.npz' with keys
    'bs_nsclc' and 'bs_breast' (num_bs_samples resamples each).
    '''
    # `with` closes the file handle that np.load keeps open for .npz archives
    with np.load('../data/nsclc/indep.npz',allow_pickle=True) as data_nsclc:
        num_nsclc = np.shape(data_nsclc['y_test'])[0]
    with np.load('../data/breast/indep_valid_geos_21653.npz',allow_pickle=True) as data_breast:
        num_breast = np.shape(data_breast['labs'])[0]

    np.savez_compressed('../data/bs_idx_indep.npz',
        bs_nsclc=_draw_bootstrap_indices(num_nsclc,num_bs_samples),
        bs_breast=_draw_bootstrap_indices(num_breast,num_bs_samples))

def create_bs_idx(num_bs_samples=1000):
    '''Create and save bootstrap test-set indices for the main NSCLC/breast test splits.

    Reads '../data/nsclc/nsclc_3.npz' and '../data/breast/breast_1.npz'
    (key 'y_test') only to get the number of test samples, then writes
    '../data/bs_idx.npz' with keys 'bs_nsclc' and 'bs_breast'
    (num_bs_samples resamples each).
    '''
    # `with` closes the file handle that np.load keeps open for .npz archives
    with np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True) as data_nsclc:
        num_nsclc = np.shape(data_nsclc['y_test'])[0]
    with np.load('../data/breast/breast_1.npz',allow_pickle=True) as data_breast:
        num_breast = np.shape(data_breast['y_test'])[0]

    np.savez_compressed('../data/bs_idx.npz',
        bs_nsclc=_draw_bootstrap_indices(num_nsclc,num_bs_samples),
        bs_breast=_draw_bootstrap_indices(num_breast,num_bs_samples))

def plot_metric_bars(mean_est,std_est,name):
    '''Grouped bar chart (mean with error bars) of five metrics for four models.

    Args:
        mean_est, std_est: indexed [model][metric] for the models SCAN, Bimodal,
            RF and SVM, with metric columns 0 = AUROC, 1 = macro F1, 2 = CI,
            3 = ACC, 4 = AUPRC.
        name: dataset name shown in the title.
    '''
    num_models = 4
    colors = ['black','navy','brown','orange']
    legend_handles = [Line2D([0], [0], color=c, lw=4, alpha=0.5) for c in colors]
    plt.legend(legend_handles, ['SCAN','Bimodal','RF','SVM'],
                loc='lower right',fontsize='x-large')

    bar_width = 0.8/(num_models/1.6)
    # plotting order (AUROC, AUPRC, macro F1, CI, ACC) -> metric column
    for group,metric in enumerate([0,4,1,2,3]):
        centres = np.linspace(2*group,2*group + 1,num_models) - 0.5
        heights = [mean_est[m][metric] for m in range(num_models)]
        errors = [std_est[m][metric] for m in range(num_models)]
        plt.bar(centres, heights, yerr=errors, align='center', alpha=0.5,
                color=colors, width=bar_width)
    plt.xticks(np.arange(5) * 2,['AUROC','AUPRC','macro F1','CI','ACC'],fontsize='x-large')
    plt.title('Performance metrics (' + name + ')',fontsize='x-large')
    plt.show()

def _plot_bootstrap_curve(x,mean_y,std_y,curve_label,baseline,title,xlabel,ylabel):
    '''Plot a bootstrap-mean curve with +-1 / +-2 std. dev. bands clipped to [0, 1].

    `baseline` is an (xs, ys) pair drawn as a dashed navy reference line.
    '''
    plt.plot(baseline[0],baseline[1],color='navy',linestyle='--')
    plt.plot(x,mean_y,color='b',label=curve_label,lw=2,alpha=.8)
    for k,alpha in ((1,.4),(2,.2)):
        plt.fill_between(x,np.maximum(mean_y - k*std_y,0),np.minimum(mean_y + k*std_y,1),
                         color='grey',alpha=alpha,label=r'$\pm$ %d std. dev.' % k)
    plt.legend(loc='lower right',fontsize='x-large')
    plt.xlim([0.0,1.0])
    plt.ylim([0.0,1.05])
    plt.grid(True)
    plt.title(title,fontsize='x-large')
    plt.ylabel(ylabel,fontsize='x-large')
    plt.xlabel(xlabel,fontsize='x-large')
    plt.show()


def get_bs_estimates(bs_idx,data,pred,thr=None,num_bs_samples=1000):
    '''Print metrics with bootstrap 95% intervals and plot bootstrap ROC / PR curves.

    Point estimates are computed on the full test set. Each bootstrap resample
    gives a delta (resample value - point estimate); the reported +- is half the
    width of the 2.5th-97.5th percentile range of those deltas.

    Args:
        bs_idx: sequence of index arrays, one per bootstrap resample (e.g. as
            saved by create_bs_idx); the first `num_bs_samples` are used.
        data: mapping with 'y_test' (binary 0/1 labels) and 'o_test' (event
            times, passed to calc_c_index_benchmark).
        pred: 2-D score array; column 1 is the positive-class score.
        thr: optional threshold on pred[:, 1] (strictly greater -> label 1);
            when None, hard labels are argmax over pred's columns.
        num_bs_samples: number of bootstrap resamples to evaluate.

    Returns:
        ([auc, macro_f1, c_index, accuracy, average_precision],
         [95% half-width for each of those, same order])
    '''
    y_true = np.asarray(data['y_test']).astype(int)
    o_test = data['o_test']
    score = pred[:,1]
    if thr is not None:
        y_hat = (score > thr).astype(int)
    else:
        y_hat = np.argmax(pred,axis=1).astype(int)

    def get_ci_dev(delta):
        '''Half-width of the central 95% percentile interval of `delta`.'''
        per025,per975 = np.percentile(delta,[2.5,97.5])
        return (per975 - per025)/2.0

    # point estimates on the full test set
    fpr,tpr,_ = roc_curve(y_true,score,pos_label=1)
    mean_auc = auc(fpr,tpr)
    mean_uf1 = f1_score(y_true,y_hat,pos_label=1,average='macro')
    mean_ci  = calc_c_index_benchmark(o_test,score)
    mean_acc = accuracy_score(y_true,y_hat)
    mean_prc = average_precision_score(y_true,score)

    # positive-class prevalence = precision of a random classifier (PRC baseline);
    # minlength=2 avoids an IndexError when no positives are present
    counts = np.bincount(y_true,minlength=2)
    ratio = counts[1] / (counts[0] + counts[1])

    mean_fpr = np.linspace(0,1,100)
    mean_rec = np.linspace(0,1,100)
    auc_bs,uf1_bs,ci_bs,acc_bs,prc_bs = [],[],[],[],[]
    tprs,prcs = [],[]
    for i in range(num_bs_samples):
        idx = bs_idx[i]
        y_true_bs,y_hat_bs,score_bs = y_true[idx],y_hat[idx],score[idx]

        fpr,tpr,_ = roc_curve(y_true_bs,score_bs,pos_label=1)
        auc_bs.append(auc(fpr,tpr))
        uf1_bs.append(f1_score(y_true_bs,y_hat_bs,pos_label=1,average='macro'))
        ci_bs.append(calc_c_index_benchmark(o_test[idx],score_bs))
        acc_bs.append(accuracy_score(y_true_bs,y_hat_bs))
        prc_bs.append(average_precision_score(y_true_bs,score_bs))

        # ROC resampled onto a common FPR grid, pinned to the origin.
        # np.interp replaces scipy.interp, a deprecated alias of it.
        tpr_grid = np.interp(mean_fpr,fpr,tpr)
        tpr_grid[0] = 0.0
        tprs.append(tpr_grid)

        # precision_recall_curve returns recall in decreasing order and np.interp
        # needs increasing x, so precision is interpolated on the (1 - recall) axis
        prc,rec,_ = precision_recall_curve(y_true_bs,score_bs)
        prcs.append(np.interp(mean_rec,1 - rec,prc))

    # half-widths of the delta distributions (bootstrap value - point estimate)
    dev_auc = get_ci_dev(np.array(auc_bs) - mean_auc)
    dev_uf1 = get_ci_dev(np.array(uf1_bs) - mean_uf1)
    dev_ci  = get_ci_dev(np.array(ci_bs)  - mean_ci)
    dev_acc = get_ci_dev(np.array(acc_bs) - mean_acc)
    dev_prc = get_ci_dev(np.array(prc_bs) - mean_prc)

    print('AUC = %.4f +- %.4f' % (mean_auc,dev_auc))
    print('UF1 = %.4f +- %.4f' % (mean_uf1,dev_uf1))
    print('CI  = %.4f +- %.4f' % (mean_ci,dev_ci))
    print('ACC = %.4f +- %.4f' % (mean_acc,dev_acc))
    print('PRC = %.4f +- %.4f' % (mean_prc,dev_prc))

    # ROC: mean curve ends at TPR = 1; chance diagonal as baseline
    mean_tpr = np.mean(tprs,axis=0)
    mean_tpr[-1] = 1.0
    _plot_bootstrap_curve(mean_fpr,mean_tpr,np.std(tprs,axis=0),'Mean ROC',
                          ([0,1],[0,1]),
                          r'Bootstrap AUROC = %0.4f $\pm$ %0.4f' % (mean_auc,dev_auc),
                          'False positive rate','True positive rate')

    # PRC: the last (1 - recall) grid point is recall = 0, pinned to precision 1;
    # reverse to map the (1 - recall) grid back onto the recall grid
    mean_prc_ = np.mean(prcs,axis=0)
    mean_prc_[-1] = 1.0
    _plot_bootstrap_curve(mean_rec,mean_prc_[::-1],np.std(prcs,axis=0)[::-1],'Mean PRC',
                          ([0,1],[ratio,ratio]),
                          r'Bootstrap AUPRC = %0.4f $\pm$ %0.4f' % (mean_prc,dev_prc),
                          'Recall','Precision')

    return [mean_auc,mean_uf1,mean_ci,mean_acc,mean_prc], \
           [dev_auc,dev_uf1,dev_ci,dev_acc,dev_prc]

def plot_km(time,pred,thr,event=None):
    '''Kaplan-Meier curves for high- vs low-risk groups, followed by a log-rank test.

    Samples with `pred >= thr` form the 'high risk' group, the rest 'low risk'.
    Survival functions are evaluated on the timeline 0, 1, ..., 60.

    Args:
        time: durations, one per sample.
        pred: risk scores aligned with `time`, compared element-wise to `thr`.
        thr: threshold separating high (>= thr) from low risk.
        event: optional event indicators aligned with `time`; None (the
            default) or an empty sequence treats every sample as an observed event.
    '''
    # NOTE(review): reseeds the global NumPy RNG although nothing visible here draws
    # random numbers; kept in case callers rely on the side effect
    np.random.seed(5)
    time = np.array(np.expand_dims(time,axis=1),dtype=float)
    if event is None or len(event) == 0:
        event = np.ones_like(time)
    else:
        event = np.array(np.expand_dims(event,axis=1),dtype=float)
    t = np.linspace(0,60,61)
    idx = pred >= thr  # True -> 'high risk'

    kmf = KaplanMeierFitter()
    kmf.fit(time[idx],event[idx],timeline=t,label='high risk')
    ax = kmf.plot()
    kmf.fit(time[~idx],event[~idx],timeline=t,label='low risk')
    kmf.plot(ax=ax)
    plt.legend(loc='lower left',fontsize='x-large')
    plt.title('Survival plot',fontsize='x-large')
    plt.show(block=True)
    logrank = logrank_test(time[idx],time[~idx],event[idx],event[~idx])
    logrank.print_summary()

def plot_subgroup(group0,group1,title='title',unit='unit'):
    '''Overlay the distributions of one predicted value for two prognosis subgroups.

    Args:
        group0: values for the 'good prognosis' subgroup.
        group1: values for the 'poor prognosis' subgroup.
        title: plot title; also forms the x-axis label together with `unit`.
        unit: unit shown in parentheses on the x-axis label.
    '''
    # NOTE(review): seaborn.distplot is deprecated upstream (histplot/kdeplot replace it)
    style = {'hist_kws': {'alpha': .4}, 'kde_kws': {'linewidth': 2}}
    for values,legend_name in ((group0,'good prognosis'),(group1,'poor prognosis')):
        sns.distplot(values,label=legend_name,**style)
    plt.title(title,fontsize='x-large')
    plt.xlabel(title + ' (' + unit + ')',fontsize='x-large')
    plt.legend(fontsize='x-large')
    plt.show()

def plot_z_interp(z_mean,prob,label,add_idx=0):
    '''Plot latent-code distributions of the two label subgroups for selected dimensions.

    Two latent dimensions are shown: the one whose two-sample KS test between
    the subgroups gives the largest p-value (the hardest to separate), and
    `add_idx`. For each, the good- (label < 0.5) and poor- (label >= 0.5)
    prognosis distributions are overlaid. Vertical lines mark the latent value
    of the sample with the highest (red), median (navy) and lowest (green)
    predicted probability.

    Args:
        z_mean: 2-D array of latent codes, one row per sample.
        prob: predicted probabilities aligned with z_mean's rows (flattened;
            assumes one value per sample -- confirm for callers).
        label: labels aligned with z_mean's rows.
        add_idx: extra 0-based latent dimension to plot for illustration.
    '''
    # identify subgroups with labels
    idx_1 = [i for i,lab in enumerate(label) if lab >= 0.5]
    idx_0 = [i for i,lab in enumerate(label) if lab <  0.5]
    z_mean_0,z_mean_1 = z_mean[idx_0,:],z_mean[idx_1,:]

    # single representative samples: max / min / closest-to-median probability.
    # np.median of an even-length array averages the two middle values and usually
    # equals no sample exactly, so an exact-equality search would find nothing.
    prob = np.ravel(prob)
    max_idx = int(np.argmax(prob))
    min_idx = int(np.argmin(prob))
    med_idx = int(np.argmin(np.abs(prob - np.median(prob))))

    # theoretically the hardest dimension to separate: largest KS p-value
    p_vals = np.array([ks_2samp(z_mean_0[:,d],z_mean_1[:,d]).pvalue
                       for d in range(np.shape(z_mean)[1])])
    best_idx = int(np.argmax(p_vals))

    # include add_idx for better illustration
    kwargs = dict(hist_kws={'alpha':.4},kde_kws={'linewidth':2})
    for idx in [best_idx,add_idx]:
        sns.distplot(z_mean_0[:,idx],label='good prognosis',**kwargs)
        sns.distplot(z_mean_1[:,idx],label='poor prognosis',**kwargs)
        plt.axvline(x=z_mean[max_idx,idx],ymin=0,ymax=1,color='red')
        plt.axvline(x=z_mean[med_idx,idx],ymin=0,ymax=1,color='navy')
        plt.axvline(x=z_mean[min_idx,idx],ymin=0,ymax=1,color='green')
        plt.legend()
        plt.title('Autoencoder latent distribution (' + str(idx + 1) + ')')  # index starts from 1
        plt.show()

def get_feat_imp(weight_dict,w_names,v_name,tick_labels,method):
    '''Plot per-input feature importance from learned weights (connection weights / Garson).

    The matrices weight_dict[name] for name in `w_names` are chained by matrix
    product into W; W's columns are scaled by column 1 of weight_dict[v_name[0]].
    The result is then post-processed per `method`:
        'cw'         -- raw products (connection weights);
        'garson'     -- each column divided by its sum of absolute values (signs kept);
        'garson_abs' -- absolute products divided by the same column sums.
    Each row's importance is the sum across columns; one bar per row (labelled by
    `tick_labels`), coloured by the quartile of |importance|.

    Raises:
        ValueError: if `method` is not one of the names above.
    '''
    method_names = {
        'cw': 'Connection Weights',
        'garson': 'Garson\'s Algorithm',
        'garson_abs': 'Garson\'s Algorithm (Abs.)',
    }
    # validate before plotting; previously an unknown method failed late with an
    # UnboundLocalError on method_disp
    if method not in method_names:
        raise ValueError('unknown method %r; expected one of %s' % (method,sorted(method_names)))
    method_disp = method_names[method]

    W = weight_dict[w_names[0]]
    for w_name in w_names[1:]:
        W = np.matmul(W,weight_dict[w_name])

    # broadcasting the row vector V[:,1] across W's rows replaces an explicit np.tile
    cw = np.multiply(W,weight_dict[v_name[0]][:,1])

    if method == 'garson':
        cw = np.divide(cw,np.sum(np.abs(cw),axis=0))
    elif method == 'garson_abs':
        cw = np.divide(np.abs(cw),np.sum(np.abs(cw),axis=0))
    cw = np.sum(cw,axis=1)

    plt.figure()
    abs_imp = np.abs(cw)
    per25,per50,per75 = np.percentile(abs_imp,[25,50,75])
    colors = []
    for imp in abs_imp:
        if imp <= per25:    colors.append('navy')
        elif imp <= per50:  colors.append('brown')
        elif imp <= per75:  colors.append('orange')
        else:  colors.append('red')
    custom_lines = [Line2D([0], [0], color=c, lw=4, alpha=0.5)
                    for c in ('navy','brown','orange','red')]
    plt.legend(custom_lines, ['Low','Inter. (low)','Inter. (high)','High'])

    conf_axis = np.arange(len(cw)) + 1
    basis_axis = np.arange(len(cw) + 2)
    plt.plot(basis_axis,np.zeros_like(basis_axis),alpha=0.25,color='grey')
    plt.bar(conf_axis,cw,width=0.8,align='center',alpha=0.5,color=colors,label='feature importance')
    plt.xticks(conf_axis,tick_labels,rotation='vertical')
    plt.subplots_adjust(bottom=0.20)
    plt.ylabel('Feature Importance ')
    plt.title('Feature Importance (' + method_disp + ')')
    plt.show()

def predict_np(x,w_type,is_x):
    '''Forward pass of the q-model (classifier) in NumPy using pickled weights.

    Loads '../model/scan/<w_type>_wq_merge.p' and applies the branch's hidden
    layers (matmul + softplus), then the branch-to-merge matrix, the
    'merge_output' matrix and a sigmoid.

    Args:
        x: input matrix (rows = samples).
        w_type: 'breast' or 'nsclc'; selects the weight file and hidden layers.
        is_x: truthy for the x branch ('x_mer_*' / 'x_to_merge' weights),
            falsy for the c branch ('c_mer_*' / 'c_to_merge').

    Raises:
        ValueError: if `w_type` is not 'breast' or 'nsclc'.
    '''
    # hidden-layer weight names per (w_type, is_x)
    hidden_layers = {
        ('breast',True):  ['x_mer_0','x_mer_1'],
        ('breast',False): ['c_mer_0'],
        ('nsclc',True):   ['x_mer_0'],
        ('nsclc',False):  [],
    }
    key = (w_type,bool(is_x))
    if key not in hidden_layers:
        raise ValueError('unknown w_type %r; expected "breast" or "nsclc"' % (w_type,))

    # NOTE: pickle can execute arbitrary code -- only load trusted model files
    with open('../model/scan/' + w_type + '_wq_merge.p','rb') as fp:
        wq_merge = pickle.load(fp)  # load q-model (classifier) weights

    for name in hidden_layers[key]:
        x = softplus(np.matmul(x,wq_merge[name]))

    # output layer
    x = np.matmul(x,wq_merge['x_to_merge' if is_x else 'c_to_merge'])
    x = np.matmul(x,wq_merge['merge_output'])
    return sigmoid(x)
def partial_dep_plot(data,label,tick_labels,w_type,is_x,n_steps=100,n_scale=2.0):
    '''Bar chart of how the q-model's AUROC changes when a single feature is rescaled.

    For each column x of `data`, the column is replaced by x + x*s for `n_steps`
    values of s evenly spaced in [-n_scale, n_scale] (other columns unchanged),
    and predict_np is re-run. The recorded score is
    AUROC(s = -n_scale) - AUROC(s = +n_scale). For each feature, the spread
    (max - min) of its AUROCs is printed, followed by the feature ranking by
    score. Bars are coloured by the quartile of |score|.

    Args:
        data: 2-D feature matrix (rows = samples).
        label: array of binary labels aligned with data's rows.
        tick_labels: one x-tick label per feature.
        w_type, is_x: forwarded to predict_np.
        n_steps: number of scale values evaluated per feature.
        n_scale: half-width of the scale range (default 2.0, the previous constant).
    '''
    plt.figure()
    y_true = label.astype(int)
    x_scale = np.linspace(-n_scale,n_scale,n_steps)
    diffs = []  # difference in AUC scores
    for targ_idx in range(np.shape(data)[1]):
        x_new = data.copy()
        base_col = data[:,targ_idx]
        aucs = []
        for s in x_scale:
            # rewritten from the untouched original column each step (same
            # arithmetic as the former per-element `x += x * s`)
            x_new[:,targ_idx] = base_col + base_col * s
            prob = predict_np(x_new,w_type,is_x)
            fpr,tpr,_ = roc_curve(y_true,prob[:,1],pos_label=1)
            aucs.append(auc(fpr,tpr))
        print('diff = ' + str(np.max(aucs) - np.min(aucs)))
        diffs.append(aucs[0] - aucs[-1])
    diffs = np.array(diffs)
    print(np.argsort(-diffs))

    abs_diff = np.abs(diffs)
    per25,per50,per75 = np.percentile(abs_diff,[25,50,75])
    colors = []
    for diff in abs_diff:
        if diff <= per25:    colors.append('navy')
        elif diff <= per50:  colors.append('brown')
        elif diff <= per75:  colors.append('orange')
        else:  colors.append('red')
    custom_lines = [Line2D([0], [0], color=c, lw=4, alpha=0.5)
                    for c in ('navy','brown','orange','red')]
    plt.legend(custom_lines, ['Low','Inter. (low)','Inter. (high)','High'])
    basis_axis = np.arange(len(diffs) + 2)
    plt.plot(basis_axis,np.zeros_like(basis_axis),alpha=0.25,color='grey')
    plt.title('Partial dependency plot')
    plt.ylabel('AUROC score differences')
    feat_axis = np.arange(np.shape(data)[1])
    plt.xticks(feat_axis,tick_labels,rotation='vertical')
    plt.subplots_adjust(bottom=0.20)
    plt.bar(feat_axis,diffs,width=0.8,align='center',color=colors,alpha=0.5)
    plt.show()

def plot_metric_bars_ens(mean_est,std_est):
    '''Grouped bar chart of five metrics for the ensemble comparison (4 model/dataset pairs).

    Args:
        mean_est, std_est: indexed [model][metric] for M2 (breast), Bimodal
            (breast), M2 (nsclc) and Bimodal (nsclc), with metric columns
            0 = AUC, 1 = macro F1, 2 = CI, 3 = ACC, 4 = AUPRC.
    '''
    num_models = 4
    colors = ['black','navy','brown','orange']
    custom_lines = [Line2D([0], [0], color=c, lw=4, alpha=0.5) for c in colors]
    # legend typo fixed: 'Bimoal (breast)' -> 'Bimodal (breast)'
    plt.legend(custom_lines, ['M2 (breast)','Bimodal (breast)','M2 (nsclc)','Bimodal (nsclc)'],loc='lower right')

    bar_width = 0.8/(num_models/1.6)
    # plotting order (AUC, AUPRC, macro F1, CI, ACC) -> metric column
    for group,metric in enumerate([0,4,1,2,3]):
        plt.bar(np.linspace(2*group,2*group + 1,num_models) - 0.5,
                [mean_est[m][metric] for m in range(num_models)],
                yerr=[std_est[m][metric] for m in range(num_models)],
                align='center', alpha=0.5, color=colors, width=bar_width)
    plt.xticks(np.arange(5) * 2,['AUC','AUPRC','macro F1','CI','ACC'])
    plt.title('Performance metrics (ensemble)')
    plt.show()

def plot_metric_bars_2(mean_est,std_est,name):
    '''Grouped bar chart of five metrics for six unlabeled-data fractions (100% ... 0%).

    Args:
        mean_est, std_est: indexed [metric][setting]; rows are read as
            0 = AUC, 1 = macro F1, 2 = CI, 3 = AUPRC, 4 = ACC.
            NOTE(review): this row order differs from get_bs_estimates'
            (auc, f1, ci, acc, prc) -- confirm against the caller.
        name: dataset name shown in the title.
    '''
    num_models = 6
    colors = ['black','navy','brown','orange','green','pink']
    plt.legend([Line2D([0], [0], color=c, lw=4, alpha=0.5) for c in colors],
               ['100%','80%','60%','40%','20%','0%'],
               loc='lower right',fontsize='x-large')

    bar_width = 0.8/(num_models/1.6)
    # plotting order (AUC, AUPRC, macro F1, CI, ACC) -> metric row
    for group,row in enumerate((0,3,1,2,4)):
        plt.bar(np.linspace(2*group,2*group + 1,num_models) - 0.5,
                [mean_est[row][k] for k in range(num_models)],
                yerr=[std_est[row][k] for k in range(num_models)],
                align='center', alpha=0.5, color=colors, width=bar_width)

    plt.xticks(np.arange(5) * 2,['AUC','AUPRC','macro F1','CI','ACC'],fontsize='x-large')
    plt.title('Performance with decreased unlabeled data (' + name + ')',
                fontsize='x-large')
    plt.ylim([0.30,1.00])
    plt.show()

def plot_metric_bars_3(mean_est,std_est):
    '''Grouped bar chart of AUROC/AUPRC/CI on both datasets for six unlabeled-data fractions.

    Args:
        mean_est, std_est: indexed [row][setting], rows in plotting order
            0-2 = breast AUROC, AUPRC, CI and 3-5 = nsclc AUROC, AUPRC, CI;
            settings follow the legend 100%, 80%, ..., 0%.
    '''
    num_models = 6
    colors = ['black','navy','brown','orange','green','pink']
    plt.legend([Line2D([0], [0], color=c, lw=4, alpha=0.5) for c in colors],
               ['100%','80%','60%','40%','20%','0%'],
               loc='lower right',fontsize='x-large')

    bar_width = 0.8/(num_models/1.6)
    # one bar group per row, groups two x-units apart
    for row in range(6):
        plt.bar(np.linspace(2*row,2*row + 1,num_models) - 0.5,
                [mean_est[row][k] for k in range(num_models)],
                yerr=[std_est[row][k] for k in range(num_models)],
                align='center', alpha=0.5, color=colors, width=bar_width)

    tick_labels = [
        'AUROC\n(breast)','AUPRC\n(breast)','CI\n(breast)',
        'AUROC\n(nsclc)','AUPRC\n(nsclc)','CI\n(nsclc)']
    plt.xticks(np.arange(6) * 2,tick_labels,fontsize='x-large')
    plt.title('Performance with decreased unlabeled data',
                fontsize='x-large')
    plt.ylim([0.30,1.00])
    plt.show()

def _with_slack_column(pred):
    '''Return scores in the two-column layout where column 1 holds the scores.

    Arrays that are not 1-D (e.g. single-model scan logits, whose column 1 is
    what the rest of this file reads) are returned unchanged. A 1-D score
    vector (presumably ensemble / bimodal outputs -- confirm) gets a -999
    placeholder column prepended. Expanding an already 2-D array would yield a
    3-D array, so the ndim check matters.

    Args:
        pred: array-like of scores, 1-D or already 2-D.
    Returns:
        np.ndarray with the scores in column 1.
    '''
    pred = np.asarray(pred)
    if pred.ndim != 1:
        return pred
    pred = np.expand_dims(pred, axis=1)
    pred_slack = np.ones_like(pred) * (-999)
    return np.concatenate((pred_slack, pred), axis=1)


def _split_by_threshold(values, thr):
    '''Return (indices with value < thr, indices with value >= thr).

    The tie rule matches calc_acc_score, which predicts class 1 unless the
    probability is strictly below the threshold. NaNs fall in neither group.
    '''
    values = np.asarray(values)
    return np.flatnonzero(values < thr), np.flatnonzero(values >= thr)


def main():
    '''Manual driver for the figure panels (subfig a-f) and supplementary analyses (supp F-J).

    Every section is guarded by ``if False:``; flip the guard of the section
    you want to run to ``True``. Inside a section, flags such as ``is_breast``
    or ``is_x`` select the dataset / feature set, and commented-out lines are
    alternative model configurations. All paths are relative ('../data',
    '../model'), so they resolve against the current working directory.
    '''

    if False:  # init bootstrap indices
        # only do once and fixed for the following experiments
        create_bs_idx_indep(num_bs_samples=1000)
        create_bs_idx(num_bs_samples=1000)

    if False:  # plot_metric_bars (subfigure a)
        # hard-coded bootstrap mean/std per model: scan (m2), bimodal, RF, SVM
        is_breast = False
        is_nsclc = not is_breast
        if is_breast:
            name = 'breast'
            mean_est_m2,std_est_m2   = [0.8173,0.7255,0.6902,0.7265,0.7701],[0.0798,0.0833,0.0443,0.0855,0.1172]
            mean_est_bi,std_est_bi   = [0.7771,0.7434,0.6733,0.7436,0.7571],[0.0846,0.0806,0.0468,0.0812,0.1144]
            mean_est_rf,std_est_rf   = [0.7836,0.6987,0.6828,0.7009,0.7661],[0.0825,0.0808,0.0455,0.0812,0.1123]
            mean_est_svm,std_est_svm = [0.7435,0.7431,0.6445,0.7521,0.6935],[0.0750,0.0802,0.0395,0.0769,0.0956]
        if is_nsclc:
            name = 'nsclc'
            mean_est_m2,std_est_m2   = [0.8046,0.7270,0.6103,0.7485,0.6083],[0.0661,0.0718,0.0470,0.0673,0.1285]
            mean_est_bi,std_est_bi   = [0.7867,0.6895,0.6012,0.7193,0.5650],[0.0675,0.0724,0.0440,0.0673,0.1455]
            mean_est_rf,std_est_rf   = [0.7941,0.6790,0.5882,0.7719,0.5979],[0.0651,0.0792,0.0480,0.0614,0.1351]
            mean_est_svm,std_est_svm = [0.5716,0.5713,0.5201,0.6901,0.3522],[0.0693,0.0827,0.0360,0.0673,0.0926]

        mean_est = [mean_est_m2,mean_est_bi,mean_est_rf,mean_est_svm]
        std_est = [std_est_m2,std_est_bi,std_est_rf,std_est_svm]
        plot_metric_bars(mean_est,std_est,name)

    if False:  # get_bs_estimates (single) (subfig b,c; supp J)
        # you can also use these codes for ablation study & less unlabeled data
        bs_idx_all = np.load('../data/bs_idx.npz',allow_pickle=True)
        is_breast = False
        is_nsclc = not is_breast
        if is_breast:
            data = np.load('../data/breast/breast_1.npz',allow_pickle=True)
            bs_idx = bs_idx_all['bs_breast']
        if is_nsclc:
            data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True)
            bs_idx = bs_idx_all['bs_nsclc']

        # scan
        if is_breast:  logits = np.load('../model/scan/breast_logits.npz',allow_pickle=True)
        if is_nsclc:  logits = np.load('../model/scan/nsclc_logits.npz',allow_pickle=True)
        pred = logits['y_xc_logit']
        thr = logits['thr_best_xc']
        mean_est_m2,std_est_m2 = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000)

        # # Bimodal
        # if is_breast:  logits = np.load('../model/bimodal/breast_bimodal_logits.npz',allow_pickle=True)
        # if is_nsclc:  logits = np.load('../model/bimodal/nsclc_bimodal_logits.npz',allow_pickle=True)
        # if is_breast:
        #     pred = np.expand_dims(logits['pred'],axis=1)
        #     pred_slack = np.ones_like(pred) * (-999)
        #     pred = np.concatenate((pred_slack,pred),axis=1)
        # if is_nsclc:  pred = logits['pred']
        # thr = logits['thr']
        # mean_est_bimodal,std_est_bimodal = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000)

        # # RF
        # if is_breast:  logits = np.load('../model/benchmarks/breast_rf_logits.npz',allow_pickle=True)
        # if is_nsclc:  logits = np.load('../model/benchmarks/nsclc_rf_logits.npz',allow_pickle=True)
        # pred = np.expand_dims(logits['pred'],axis=1)
        # thr = logits['thr']
        # pred_slack = np.ones_like(pred) * (-999)
        # pred = np.concatenate((pred_slack,pred),axis=1)
        # mean_est_rf,std_est_rf = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000)

        # # SVM
        # if is_breast:  logits = np.load('../model/benchmarks/breast_svm_logits.npz',allow_pickle=True)
        # if is_nsclc:  logits = np.load('../model/benchmarks/nsclc_svm_logits.npz',allow_pickle=True)
        # pred = np.expand_dims(logits['pred'],axis=1)
        # thr = logits['thr']
        # pred_slack = np.ones_like(pred) * (-999)
        # pred = np.concatenate((pred_slack,pred),axis=1)
        # mean_est_svm,std_est_svm = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000)

    if False:  # plot_km (subfig d)
        is_breast = False
        is_nsclc = not is_breast

        if is_breast:
            data = np.load('../data/breast/breast_1.npz',allow_pickle=True)
            logits = np.load('../model/scan/breast_logits.npz',allow_pickle=True)
        if is_nsclc:
            data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True)
            logits = np.load('../model/scan/nsclc_logits.npz',allow_pickle=True)
        time = data['o_test']
        pred = logits['y_xc_logit'][:,1]
        thr  = logits['thr_best_xc']
        if is_breast:
            plot_km(time,pred,thr,event=[])
        if is_nsclc:
            event = data['e_test']
            plot_km(time,pred,thr,event=event)

    if False:  # plot_subgroup (subfig e,f)
        is_breast = False
        is_nsclc = not is_breast

        if is_breast:
            data = np.load('../data/breast/breast_1.npz',allow_pickle=True)
            logits = np.load('../model/scan/breast_logits.npz',allow_pickle=True)
        if is_nsclc:
            data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True)
            logits = np.load('../model/scan/nsclc_logits.npz',allow_pickle=True)
        prob = logits['y_xc_logit']
        thr = logits['thr_best_xc']

        # prediction probability versus real label (subfig f)
        idx0_label = [idx for idx,targ in enumerate(data['y_test']) if targ < 0.5]
        idx1_label = [idx for idx,targ in enumerate(data['y_test']) if targ > 0.5]
        group0 = prob[idx0_label,1]
        group1 = prob[idx1_label,1]
        plot_subgroup(group0,group1,title='Prediction probability',unit='probability')

        # survival time grouped by the model's predicted class (subfig e);
        # grouping must use the thresholded predictions, not the true labels
        idx0_prob,idx1_prob = _split_by_threshold(prob[:,1],thr)
        group0 = data['o_test'][idx0_prob]
        group1 = data['o_test'][idx1_prob]
        print(np.median(group0))
        print(np.median(group1))
        plot_subgroup(group0,group1,title='Survival time',unit='months')

    if False:  # plot_z_interp (supp F)
        is_breast = False
        is_nsclc = not is_breast
        if is_breast:
            z_data = np.load('../model/scan/breast_z.npz',allow_pickle=True)
            logits = np.load('../model/scan/breast_logits.npz',allow_pickle=True)
            data = np.load('../data/breast/breast_1.npz',allow_pickle=True)
        if is_nsclc:
            z_data = np.load('../model/scan/nsclc_z.npz',allow_pickle=True)
            logits = np.load('../model/scan/nsclc_logits.npz',allow_pickle=True)
            data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True)
        z_mean,z_logstd = z_data['z_mean'],z_data['z_logstd']
        prob = logits['y_xc_logit'][:,1]
        label = data['y_test']
        if is_breast:  plot_z_interp(z_mean,prob,label,add_idx=9)
        if is_nsclc:  plot_z_interp(z_mean,prob,label,add_idx=1)

    if False:  # get_feat_imp (supp G)
        # is_x selects gene-expression weights, is_c clinical weights
        is_breast,is_x = False,False
        is_nsclc,is_c = not is_breast,not is_x

        if is_breast:
            # NOTE(review): pickle.load on a local model artifact; never point this at untrusted files
            with open('../model/scan/breast_wq_merge.p','rb') as fp:  wq_merge = pickle.load(fp)
            clinical_feature = ['Age', 'Menopausal \nState', 'Size', 'Radio \nTherapy',
                'Chemotherapy', 'Hormone \nTherapy', 'Neoplasm \nHistologic \nGrade', 'Cellularity',
                'Surgery-\nbreast \nconserving', 'Surgery-\nmastectomy']
            gene_feature = ['ESR1','PGR','ERBB2','MKI67','PLAU',
                'ELAVL1','EGFR','BTRC','FBXO6','SHMT2','KRAS','SRPK2',
                'YWHAQ','PDHA1','EWSR1','ZDHHC17','ENO1','DBN1','PLK1','GSK3B']
            if is_x:  w_names = ['x_mer_0','x_mer_1','x_to_merge']
            if is_c:  w_names = ['c_mer_0','c_to_merge']
            v_name = ['merge_output']
            if is_x:  get_feat_imp(wq_merge,w_names,v_name,gene_feature,'cw')
            if is_c:  get_feat_imp(wq_merge,w_names,v_name,clinical_feature,'cw')

        if is_nsclc:
            with open('../model/scan/nsclc_wq_merge.p','rb') as fp:  wq_merge = pickle.load(fp)
            gene_feature = ['EPCAM','HIF1A','PKM','PTK7','ALCAM','CADM1','SLC2A1',
                    'CUL1','CUL3','EGFR','ELAVL1','GRB2','NRF1','RNF2','RPA2']
            clinical_feature = ['Age','Gender','Stage']
            if is_x:  w_names = ['x_mer_0','x_to_merge']
            if is_c:  w_names = ['c_to_merge']
            v_name = ['merge_output']
            if is_x:  get_feat_imp(wq_merge,w_names,v_name,gene_feature,'cw')
            if is_c:  get_feat_imp(wq_merge,w_names,v_name,clinical_feature,'cw')

    if False:  # partial_dep_plot (supp H)
        is_breast,is_x = False,True
        is_nsclc,is_c = not is_breast,not is_x

        if is_breast:
            data = np.load('../data/breast/breast_1.npz',allow_pickle=True)
            clinical_feature = ['Age', 'Menopausal \nState', 'Size', 'Radio \nTherapy',
                'Chemotherapy', 'Hormone \nTherapy', 'Neoplasm \nHistologic \nGrade', 'Cellularity',
                'Surgery-\nbreast \nconserving', 'Surgery-\nmastectomy']
            gene_feature = ['ESR1','PGR','ERBB2','MKI67','PLAU',
                'ELAVL1','EGFR','BTRC','FBXO6','SHMT2','KRAS','SRPK2',
                'YWHAQ','PDHA1','EWSR1','ZDHHC17','ENO1','DBN1','PLK1','GSK3B']
            label = data['y_test']
            if is_x:
                data = data['x_test']
                partial_dep_plot(data,label,gene_feature,'breast',True)
            if is_c:
                data = data['c_test']
                partial_dep_plot(data,label,clinical_feature,'breast',False)

        if is_nsclc:
            data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True)
            gene_feature = ['EPCAM','HIF1A','PKM','PTK7','ALCAM','CADM1','SLC2A1',
                    'CUL1','CUL3','EGFR','ELAVL1','GRB2','NRF1','RNF2','RPA2']
            clinical_feature = ['Age','Gender','Stage']
            label = data['y_test']
            if is_x:
                data = data['x_test']
                partial_dep_plot(data,label,gene_feature,'nsclc',True)
            if is_c:
                data = data['c_test']
                partial_dep_plot(data,label,clinical_feature,'nsclc',False)

    if False:  # independent validation sets (supp I)
        bs_idx_all = np.load('../data/bs_idx_indep.npz',allow_pickle=True)
        is_breast = False
        is_nsclc = not is_breast

        if is_nsclc:
            data = np.load('../data/nsclc/indep.npz',allow_pickle=True)
            bs_idx = bs_idx_all['bs_nsclc']

            # scan
            logits = np.load('../model/scan/nsclc_indep_logits.npz',allow_pickle=True)
            ori_logits = np.load('../model/scan/nsclc_logits.npz',allow_pickle=True)

            # scan (ensemble)
            # logits = np.load('../model/scan_ens/nsclc/nsclc_indep_logits.npz',allow_pickle=True)
            # ori_logits = np.load('../model/scan_ens/nsclc/nsclc_logits.npz',allow_pickle=True)

            # bimodal
            # logits = np.load('../model/bimodal/nsclc_bimodal_indep_logits.npz',allow_pickle=True)
            # ori_logits = np.load('../model/bimodal/nsclc_bimodal_logits.npz',allow_pickle=True)

            # bimodal (ensemble)
            # logits = np.load('../model/bimodal_ens/nsclc/nsclc_bimodal_indep_logits.npz',allow_pickle=True)
            # ori_logits = np.load('../model/bimodal_ens/nsclc/nsclc_bimodal_logits.npz',allow_pickle=True)


        if is_breast:
            data = np.load('../data/breast/indep.npz',allow_pickle=True)
            bs_idx = bs_idx_all['bs_breast']

            # scan
            logits = np.load('../model/scan/breast_indep_logits.npz',allow_pickle=True)
            ori_logits = np.load('../model/scan/breast_logits.npz',allow_pickle=True)

            # scan (ensemble)
            # logits = np.load('../model/scan_ens/breast/breast_indep_logits.npz',allow_pickle=True)
            # ori_logits = np.load('../model/scan_ens/breast/breast_logits.npz',allow_pickle=True)

            # bimodal
            # logits = np.load('../model/bimodal/breast_bimodal_indep_logits.npz',allow_pickle=True)
            # ori_logits = np.load('../model/bimodal/breast_bimodal_logits.npz',allow_pickle=True)

            # bimodal (ensemble)
            # logits = np.load('../model/bimodal_ens/breast/breast_bimodal_indep_logits.npz',allow_pickle=True)
            # ori_logits = np.load('../model/bimodal_ens/breast/breast_bimodal_logits.npz',allow_pickle=True)


        # threshold comes from the original (non-independent) run
        pred = logits['y_xc_logit']  # scan
        thr = ori_logits['thr_best_xc']  # scan

        # pred = logits['pred']
        # thr = ori_logits['thr']
        # pred = 1.0 - pred

        #################################################
        # if is_breast:
        #     pred = np.expand_dims(logits['pred'],axis=1)
        #     pred_slack = np.ones_like(pred) * (-999)
        #     pred = np.concatenate((pred_slack,pred),axis=1)
        # if is_nsclc:  pred = logits['pred']
        # thr = logits['thr']
        #############################################

        # 1-D scores (ensemble scan) get a slack column; 2-D single-model
        # logits are passed through unchanged instead of being made 3-D
        pred = _with_slack_column(pred)

        mean_est_m2,std_est_m2 = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000)

    if False:  # get_bs_estimates (ensemble) (supp I)
        bs_idx_all = np.load('../data/bs_idx.npz',allow_pickle=True)
        is_breast = False
        is_nsclc = not is_breast
        is_m2 = True
        is_bimodal = not is_m2
        if is_breast:
            data = np.load('../data/breast/breast_1.npz',allow_pickle=True)
            bs_idx = bs_idx_all['bs_breast']
        if is_nsclc:
            data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True)
            bs_idx = bs_idx_all['bs_nsclc']

        if is_breast and is_m2:
            logits = np.load('../model/scan_ens/breast/breast_logits.npz',allow_pickle=True)
        if is_breast and is_bimodal:
            logits = np.load('../model/bimodal_ens/breast/breast_bimodal_logits.npz',allow_pickle=True)
        if is_nsclc and is_m2:
            logits = np.load('../model/scan_ens/nsclc/nsclc_logits.npz',allow_pickle=True)
        if is_nsclc and is_bimodal:
            logits = np.load('../model/bimodal_ens/nsclc/nsclc_bimodal_logits.npz',allow_pickle=True)
        if is_m2:
            pred = logits['y_xc_logit']
            thr = logits['thr_best_xc']
        if is_bimodal:
            pred = logits['pred']
            thr = logits['thr']
        pred = _with_slack_column(pred)
        # NOTE(review): this flip also turns the -999 slack column into 1000 and
        # leaves thr unchanged; confirm get_bs_estimates expects both
        if is_bimodal and is_nsclc:  pred = 1.0 - pred
        mean_est_m2,std_est_m2 = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000)

    if False:  # plot metric bars with less unlabeled data (supp J)
        is_breast = False
        is_nsclc = not is_breast
        if is_breast:
            name = 'breast'
            mean_est_auc = [0.8173,0.8170,0.8173,0.8073,0.8073,0.7968]
            mean_est_f1  = [0.7255,0.7347,0.7431,0.7350,0.7436,0.7179]
            mean_est_ci  = [0.6902,0.6904,0.6903,0.6909,0.6879,0.6813]
            mean_est_prc = [0.7701,0.7694,0.7686,0.7535,0.7523,0.7434]
            mean_est_acc = [0.7265,0.7350,0.7436,0.7350,0.7436,0.7179]

            std_est_auc = [0.0798,0.0794,0.0789,0.0809,0.0823,0.0834]
            std_est_f1  = [0.0833,0.0806,0.0807,0.0776,0.0806,0.0806]
            std_est_ci  = [0.0443,0.0436,0.0442,0.0459,0.0453,0.0456]
            std_est_prc = [0.1172,0.1163,0.1183,0.1195,0.1203,0.1211]
            std_est_acc = [0.0855,0.0769,0.0769,0.0769,0.0812,0.0769]

        if is_nsclc:
            name = 'nsclc'
            mean_est_auc = [0.8046,0.8072,0.8145,0.8153,0.8109,0.8054]
            mean_est_f1  = [0.7270,0.7153,0.7250,0.7260,0.7129,0.6681]
            mean_est_ci  = [0.6103,0.5930,0.6037,0.6022,0.5929,0.6125]
            mean_est_prc = [0.6083,0.6078,0.6294,0.6321,0.6216,0.5874]
            mean_est_acc = [0.7485,0.7427,0.7485,0.7544,0.7544,0.7076]

            std_est_auc = [0.0661,0.0653,0.0662,0.0640,0.0667,0.0661]
            std_est_f1  = [0.0718,0.0698,0.0723,0.0688,0.0726,0.0734]
            std_est_ci  = [0.0470,0.0488,0.0471,0.0485,0.0494,0.0467]
            std_est_prc = [0.1285,0.1378,0.1271,0.1256,0.1353,0.1352]
            std_est_acc = [0.0673,0.0643,0.0673,0.0643,0.0614,0.0644]

        mean_est = [mean_est_auc,mean_est_f1,mean_est_ci,mean_est_prc,mean_est_acc]
        std_est = [std_est_auc,std_est_f1,std_est_ci,std_est_prc,std_est_acc]
        plot_metric_bars_2(mean_est,std_est,name)

    if False:  # plot metric bars with less unlabeled data (supp J) (breast + nsclc)

        mean_est_auc_breast = [0.8173,0.8170,0.8173,0.8073,0.8073,0.7968]
        std_est_auc_breast  = [0.0798,0.0794,0.0789,0.0809,0.0823,0.0834]
        mean_est_prc_breast = [0.7701,0.7694,0.7686,0.7535,0.7523,0.7434]
        std_est_prc_breast  = [0.1172,0.1163,0.1183,0.1195,0.1203,0.1211]
        mean_est_ci_breast  = [0.6902,0.6904,0.6903,0.6909,0.6879,0.6813]
        std_est_ci_breast   = [0.0443,0.0436,0.0442,0.0459,0.0453,0.0456]

        mean_est_auc_nsclc = [0.8046,0.8072,0.8145,0.8153,0.8109,0.8054]
        std_est_auc_nsclc  = [0.0661,0.0653,0.0662,0.0640,0.0667,0.0661]
        mean_est_prc_nsclc = [0.6083,0.6078,0.6294,0.6321,0.6216,0.5874]
        std_est_prc_nsclc  = [0.1285,0.1378,0.1271,0.1256,0.1353,0.1352]
        mean_est_ci_nsclc  = [0.6103,0.5930,0.6037,0.6022,0.5929,0.6125]
        std_est_ci_nsclc   = [0.0470,0.0488,0.0471,0.0485,0.0494,0.0467]

        # order must match the six tick labels in plot_metric_bars_3
        mean_est = [mean_est_auc_breast,mean_est_prc_breast,mean_est_ci_breast,
            mean_est_auc_nsclc,mean_est_prc_nsclc,mean_est_ci_nsclc]
        std_est = [std_est_auc_breast,std_est_prc_breast,std_est_ci_breast,
            std_est_auc_nsclc,std_est_prc_nsclc,std_est_ci_nsclc]

        plot_metric_bars_3(mean_est,std_est)

if __name__ == "__main__":
    # Run the experiment driver only when executed as a script, never on import.
    main()