import numpy as np import tensorflow as tf import pickle import matplotlib.pyplot as plt from matplotlib.lines import Line2D from matplotlib import rcParams, cycler import seaborn as sns from sklearn.model_selection import StratifiedShuffleSplit from sklearn.metrics import roc_auc_score,auc,roc_curve,f1_score from sklearn.metrics import accuracy_score,average_precision_score from sklearn.metrics import precision_recall_curve from sklearn.svm import SVC from lifelines.utils import concordance_index from lifelines import KaplanMeierFitter from lifelines import CoxPHFitter from lifelines.statistics import logrank_test from scipy import interp from scipy.stats import ks_2samp,wasserstein_distance def softplus(x): return np.log(np.add(1.0,np.exp(x))) def sigmoid(x): return np.divide(1.0,np.add(1.0,np.exp(-x))) def calc_c_index_benchmark(o_test,y_prob): '''calculate C-index given event time and predicted probs''' time = np.array(np.expand_dims(o_test,axis=1),dtype=float) # passing predicted probs through monotonically-increasing function gives the same CI prob = np.tanh(y_prob) # tanh prevents large values causing arithemetic errors # prob = np.exp(y_prob) mean = np.mean(prob) y_score = (prob - mean) return concordance_index(time,-y_score) def calc_acc_score(thr,y_label,y_prob): '''calculate accuracy scores''' label = 1.0 * np.ones_like(y_prob) for i in range(np.shape(label)[0]): if y_prob[i] < thr: label[i] = 0.0 return accuracy_score(y_label,label) def create_bs_idx_indep(num_bs_samples=1000): '''create bootstrapped test sets for independent validation sets''' bs_nsclc,bs_breast = [],[] # nsclc data_nsclc = np.load('../data/nsclc/indep.npz',allow_pickle=True) label_nsclc = data_nsclc['y_test'] index_nsclc = np.arange(np.shape(label_nsclc)[0]) for i in range(num_bs_samples): bs_idx = [] for j in range(len(index_nsclc)): bs_idx.extend(np.random.choice(index_nsclc,size=1)) # draw with replacement bs_nsclc.append(np.array(bs_idx,dtype=np.int32)) # breast data_breast = np.load('../data/breast/indep_valid_geos_21653.npz',allow_pickle=True) label_breast = data_breast['labs'] index_breast = np.arange(np.shape(label_breast)[0]) for i in range(num_bs_samples): bs_idx = [] for j in range(len(index_breast)): bs_idx.extend(np.random.choice(index_breast,size=1)) # draw with replacement bs_breast.append(np.array(bs_idx,dtype=np.int32)) np.savez_compressed('../data/bs_idx_indep.npz', bs_nsclc=bs_nsclc,bs_breast=bs_breast) def create_bs_idx(num_bs_samples=1000): '''create bootstrapped test sets''' bs_nsclc,bs_breast = [],[] # nsclc data_nsclc = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True) label_nsclc = data_nsclc['y_test'] index_nsclc = np.arange(np.shape(label_nsclc)[0]) for i in range(num_bs_samples): bs_idx = [] for j in range(len(index_nsclc)): bs_idx.extend(np.random.choice(index_nsclc,size=1)) # draw with replacement bs_nsclc.append(np.array(bs_idx,dtype=np.int32)) # breast data_breast = np.load('../data/breast/breast_1.npz',allow_pickle=True) label_breast = data_breast['y_test'] index_breast = np.arange(np.shape(label_breast)[0]) for i in range(num_bs_samples): bs_idx = [] for j in range(len(index_breast)): bs_idx.extend(np.random.choice(index_breast,size=1)) # draw with replacement bs_breast.append(np.array(bs_idx,dtype=np.int32)) np.savez_compressed('../data/bs_idx.npz', bs_nsclc=bs_nsclc,bs_breast=bs_breast) def plot_metric_bars(mean_est,std_est,name): '''summarize performance metrics with bar charts''' num_models = 4 custom_lines = [Line2D([0], [0], color='black', lw=4, alpha=0.5), Line2D([0], [0], color='navy', lw=4, alpha=0.5), Line2D([0], [0], color='brown', lw=4, alpha=0.5), Line2D([0], [0], color='orange', lw=4, alpha=0.5)] plt.legend(custom_lines, ['SCAN','Bimodal','RF','SVM'], loc='lower right',fontsize='x-large') mean_aucs,mean_prcs,mean_uf1s,mean_cis,mean_accs = [],[],[],[],[] std_aucs,std_prcs,std_uf1s,std_cis,std_accs = [],[],[],[],[] for i in range(num_models): mean_aucs.append(mean_est[i][0]) mean_prcs.append(mean_est[i][4]) mean_uf1s.append(mean_est[i][1]) mean_cis.append(mean_est[i][2]) mean_accs.append(mean_est[i][3]) std_aucs.append(std_est[i][0]) std_prcs.append(std_est[i][4]) std_uf1s.append(std_est[i][1]) std_cis.append(std_est[i][2]) std_accs.append(std_est[i][3]) colors=['black','navy','brown','orange'] plt.bar(np.linspace(0,1,num_models) - 0.5, mean_aucs, yerr=std_aucs, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(2,3,num_models) - 0.5, mean_prcs, yerr=std_prcs, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(4,5,num_models) - 0.5, mean_uf1s, yerr=std_uf1s, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(6,7,num_models) - 0.5, mean_cis, yerr=std_cis, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(8,9,num_models) - 0.5, mean_accs, yerr=std_accs, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) tick_labels = ['AUROC','AUPRC','macro F1','CI','ACC'] plt.xticks(np.arange(5) * 2,tick_labels,fontsize='x-large') plt.title('Performance metrics (' + name + ')',fontsize='x-large') plt.show() def get_bs_estimates(bs_idx,data,pred,thr=None,num_bs_samples=1000): '''get 95% CI bootstrap estimates''' label = data['y_test'] if thr != None: Label = np.zeros(np.shape(label)[0]) for i in range(len(Label)): if pred[i,1] > thr: Label[i] = 1 else: Label = np.argmax(pred,axis=1) o_test = data['o_test'] acc_bs,auc_bs,ci_bs,uf1_bs,prc_bs = [],[],[],[],[] tprs,prcs = [],[] mean_fpr = np.linspace(0,1,100) mean_rec = np.linspace(0,1,100) ratio = np.bincount(label.astype(int)) ratio = ratio[1] / (ratio[0] + ratio[1]) # random PRC # mean estimates fpr,tpr,thr = roc_curve(label.astype(int),pred[:,1],pos_label=1) mean_auc = auc(fpr,tpr) mean_uf1 = f1_score(label.astype(int),Label.astype(int),pos_label=1,average='macro') mean_ci = calc_c_index_benchmark(o_test,pred[:,1]) mean_acc = accuracy_score(label.astype(int),Label.astype(int)) mean_prc = average_precision_score(label.astype(int),pred[:,1]) for i in range(num_bs_samples): label_bs = label[bs_idx[i]] Label_bs = Label[bs_idx[i]] pred_bs = pred[bs_idx[i],:] o_test_bs = o_test[bs_idx[i]] fpr,tpr,thr = roc_curve(label_bs.astype(int),pred_bs[:,1],pos_label=1) auc_test_bs = auc(fpr,tpr) uf1_test_bs = f1_score(label_bs.astype(int),Label_bs.astype(int),pos_label=1,average='macro') ci_test_bs = calc_c_index_benchmark(o_test_bs,pred_bs[:,1]) acc_test_bs = accuracy_score(label_bs.astype(int),Label_bs.astype(int)) prc_test_bs = average_precision_score(label_bs.astype(int),pred_bs[:,1]) auc_bs.append(auc_test_bs) uf1_bs.append(uf1_test_bs) ci_bs.append(ci_test_bs) acc_bs.append(acc_test_bs) prc_bs.append(prc_test_bs) # ROC tprs.append(interp(mean_fpr,fpr,tpr)) tprs[-1][0] = 0.0 # PRC prc,rec,_ = precision_recall_curve(label_bs.astype(int),pred_bs[:,1]) prcs.append(interp(mean_rec,1 - rec,prc)) # compute the \delta distribution auc_bs = np.array(auc_bs) - mean_auc uf1_bs = np.array(uf1_bs) - mean_uf1 ci_bs = np.array(ci_bs) - mean_ci acc_bs = np.array(acc_bs) - mean_acc prc_bs = np.array(prc_bs) - mean_prc def get_ci_dev(delta): per025 = np.percentile(delta,2.5) per975 = np.percentile(delta,97.5) return (per975 - per025)/2.0 print('AUC = %.4f +- %.4f' % (mean_auc,get_ci_dev(auc_bs))) print('UF1 = %.4f +- %.4f' % (mean_uf1,get_ci_dev(uf1_bs))) print('CI = %.4f +- %.4f' % (mean_ci,get_ci_dev(ci_bs))) print('ACC = %.4f +- %.4f' % (mean_acc,get_ci_dev(acc_bs))) print('PRC = %.4f +- %.4f' % (mean_prc,get_ci_dev(prc_bs))) # plot ROC curves plt.plot([0,1],[0,1],color='navy',linestyle='--') mean_tpr = np.mean(tprs,axis=0) mean_tpr[-1] = 1.0 plt.plot(mean_fpr, mean_tpr, color='b', label='Mean ROC', lw=2, alpha=.8) std_tpr = np.std(tprs, axis=0) tprs_upper = np.minimum(mean_tpr + std_tpr, 1) tprs_lower = np.maximum(mean_tpr - std_tpr, 0) plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.4, label='$\pm$ 1 std. dev.') tprs_upper = np.minimum(mean_tpr + 2*std_tpr, 1) tprs_lower = np.maximum(mean_tpr - 2*std_tpr, 0) plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2, label='$\pm$ 2 std. dev.') plt.legend(loc='lower right',fontsize='x-large') plt.xlim([0.0,1.0]) plt.ylim([0.0,1.05]) plt.grid(True) plt.title('Bootstrap AUROC = %0.4f $\pm$ %0.4f' % (mean_auc, get_ci_dev(auc_bs)),fontsize='x-large') plt.ylabel('True positive rate',fontsize='x-large') plt.xlabel('False positive rate',fontsize='x-large') plt.show() # plot PRC curve mean_prc_ = np.mean(prcs,axis=0) mean_prc_[-1] = 1.0 mean_prc_ = mean_prc_[::-1] std_prc = np.std(prcs, axis=0) std_prc = std_prc[::-1] plt.plot(mean_rec, mean_prc_, color='b', label='Mean PRC', lw=2, alpha=.8) prcs_upper = np.minimum(mean_prc_ + std_prc, 1) prcs_lower = np.maximum(mean_prc_ - std_prc, 0) plt.fill_between(mean_rec, prcs_lower, prcs_upper, color='grey', alpha=.4, label='$\pm$ 1 std. dev.') prcs_upper = np.minimum(mean_prc_ + 2*std_prc, 1) prcs_lower = np.maximum(mean_prc_ - 2*std_prc, 0) plt.fill_between(mean_rec, prcs_lower, prcs_upper, color='grey', alpha=.4, label='$\pm$ 2 std. dev.') random_x = np.arange(len(mean_prc_)) random_y = np.ones_like(random_x) * ratio plt.plot(random_x,random_y,color='navy',linestyle='--') plt.legend(loc='lower right',fontsize='x-large') plt.xlim([0.0,1.0]) plt.ylim([0.0,1.05]) plt.grid(True) plt.title('Bootstrap AUPRC = %0.4f $\pm$ %0.4f' % (mean_prc, get_ci_dev(prc_bs)),fontsize='x-large') plt.ylabel('Precision',fontsize='x-large') plt.xlabel('Recall',fontsize='x-large') plt.show() return [mean_auc,mean_uf1,mean_ci,mean_acc,mean_prc], \ [get_ci_dev(auc_bs),get_ci_dev(uf1_bs),get_ci_dev(ci_bs),get_ci_dev(acc_bs),get_ci_dev(prc_bs)] def plot_km(time,pred,thr,event=[]): '''KM-analysis''' np.random.seed(5) time = np.array(np.expand_dims(time,axis=1),dtype=float) if len(event) == 0: event = np.ones_like(time) else: event = np.array(np.expand_dims(event,axis=1),dtype=float) t = np.linspace(0,60,61) idx = pred >= thr kmf = KaplanMeierFitter() kmf.fit(time[idx],event[idx],timeline=t,label='high risk') ax = kmf.plot() kmf.fit(time[~idx],event[~idx],timeline=t,label='low risk') kmf.plot(ax=ax) plt.legend(loc='lower left',fontsize='x-large') #plt.grid(True) plt.title('Survival plot',fontsize='x-large') plt.show(block=True) logrank = logrank_test(time[idx],time[~idx],event[idx],event[~idx]) logrank.print_summary() def plot_subgroup(group0,group1,title='title',unit='unit'): '''plot the predicted value distributions for subgroups''' kwargs = dict(hist_kws={'alpha':.4},kde_kws={'linewidth':2}) sns.distplot(group0,label='good prognosis',**kwargs) sns.distplot(group1,label='poor prognosis',**kwargs) plt.title(title,fontsize='x-large') plt.xlabel(title + ' (' + unit + ')',fontsize='x-large') plt.legend(fontsize='x-large') plt.show() def plot_z_interp(z_mean,prob,label,add_idx=0): '''plot learned latent codes interpolation''' # identify subgroups with labels idx_1 = [idx for idx,lab in enumerate(label) if lab >= 0.5] idx_0 = [idx for idx,lab in enumerate(label) if lab < 0.5] z_mean_0,z_mean_1 = z_mean[idx_0,:],z_mean[idx_1,:] # identify max/min/med prob acts max_prob,min_prob,med_prob = np.max(prob),np.min(prob),np.median(prob) max_idx = np.argwhere(prob == max_prob) min_idx = np.argwhere(prob == min_prob) med_idx = np.argwhere(prob == med_prob) # theoretically the hardest one to separate dim_z = np.shape(z_mean)[1] p_vals = [] for i in range(dim_z): t_stat,p_val = ks_2samp(z_mean_0[:,i],z_mean_1[:,i]) p_vals.append(p_val) p_vals = np.array(p_vals) best_idx = np.argmax(p_vals) # print(best_idx) # include add_idx for better illustration for idx in [best_idx,add_idx]: kwargs = dict(hist_kws={'alpha':.4},kde_kws={'linewidth':2}) sns.distplot(z_mean_0[:,idx],label='good prognosis',**kwargs) sns.distplot(z_mean_1[:,idx],label='poor prognosis',**kwargs) plt.axvline(x=z_mean[max_idx,idx],ymin=0,ymax=1,color='red') plt.axvline(x=z_mean[med_idx,idx],ymin=0,ymax=1,color='navy') plt.axvline(x=z_mean[min_idx,idx],ymin=0,ymax=1,color='green') plt.legend() plt.title('Autoencoder latent distribution (' + str(idx + 1) + ')') # index starts from 1 plt.show() def get_feat_imp(weight_dict,w_names,v_name,tick_labels,method): '''connection weight analysis (Garson's algo. w/o absolute value)''' if len(w_names) == 1: W = weight_dict[w_names[0]] else: W = weight_dict[w_names[0]] for w_name in w_names[1:]: # W = softplus(np.matmul(W,weight_dict[w_name])) W = np.matmul(W,weight_dict[w_name]) # V = sigmoid(weight_dict[v_name[0]]) V = weight_dict[v_name[0]] V = np.transpose(np.expand_dims(V[:,1],axis=1)) V = np.tile(V,[np.shape(W)[0],1]) cw = np.multiply(W,V) if method == 'cw': method_disp = 'Connection Weights' elif method == 'garson': method_disp = 'Garson\'s Algorithm' norm = np.sum(np.abs(cw),axis=0) cw = np.divide(cw,norm) elif method == 'garson_abs': method_disp = 'Garson\'s Algorithm (Abs.)' norm = np.sum(np.abs(cw),axis=0) cw = np.divide(np.abs(cw),norm) cw = np.sum(cw,axis=1) plt.figure() abs_imp,colors = np.abs(cw),[] per25,per50,per75 = np.percentile(abs_imp,25),np.percentile(abs_imp,50),np.percentile(abs_imp,75) for idx,imp in enumerate(abs_imp): if imp <= per25: colors.append('navy') elif imp <= per50: colors.append('brown') elif imp <= per75: colors.append('orange') else: colors.append('red') # cmap = plt.cm.coolwarm # color = cmap(0.5) custom_lines = [Line2D([0], [0], color='navy', lw=4, alpha=0.5), Line2D([0], [0], color='brown', lw=4, alpha=0.5), Line2D([0], [0], color='orange', lw=4, alpha=0.5), Line2D([0], [0], color='red', lw=4, alpha=0.5)] plt.legend(custom_lines, ['Low','Inter. (low)','Inter. (high)','High']) conf_axis = np.arange(len(cw)) + 1 basis_axis = np.arange(len(cw) + 2) basis = np.zeros_like(basis_axis) plt.plot(basis_axis,basis,alpha=0.25,color='grey') plt.bar(conf_axis,cw,width=0.8,align='center',alpha=0.5,color=colors,label='feature importance') plt.xticks(conf_axis,tick_labels,rotation='vertical') plt.subplots_adjust(bottom=0.20) plt.ylabel('Feature Importance ') plt.title('Feature Importance (' + method_disp + ')') plt.show() def predict_np(x,w_type,is_x): '''make predictions with learned weights (q-model)''' with open('../model/scan/' + w_type + '_wq_merge.p','rb') as fp: wq_merge = pickle.load(fp) # load q-model (classifier) weights if w_type == 'breast': if is_x: w_names = ['x_mer_0','x_mer_1'] else: w_names = ['c_mer_0'] if w_type == 'nsclc': if is_x: w_names = ['x_mer_0'] else: w_names = [] for name in w_names: x = np.matmul(x,wq_merge[name]) x = softplus(x) # output layer if is_x: x = np.matmul(x,wq_merge['x_to_merge']) else: x = np.matmul(x,wq_merge['c_to_merge']) x = np.matmul(x,wq_merge['merge_output']) x = sigmoid(x) return x def partial_dep_plot(data,label,tick_labels,w_type,is_x,n_steps=100): '''partial dependency plot''' plt.figure() diffs = [] # difference in AUC scores for targ_idx in range(np.shape(data)[1]): x_max = np.max(data[:,targ_idx]) x_min = np.min(data[:,targ_idx]) # x_step = np.linspace(x_min,x_max,n_steps) n_scale = 2.0 # fold increase/decrease x_scale = np.linspace(-n_scale,n_scale,n_steps) aucs = [] for step in range(n_steps): x_new = data.copy() # for i in range(np.shape(data)[0]): x_new[i,targ_idx] = x_step[step] for i in range(np.shape(data)[0]): x_new[i,targ_idx] += x_new[i,targ_idx] * x_scale[step] logit = predict_np(x_new,w_type,is_x) fpr,tpr,thr = roc_curve(label.astype(int),logit[:,1],pos_label=1) auc_new = auc(fpr,tpr) aucs.append(auc_new) # plt.plot(np.arange(n_steps),aucs,label=str(targ_idx)) print('diff = ' + str(np.max(aucs) - np.min(aucs))) diffs.append(aucs[0] - aucs[-1]) diffs = np.array(diffs) print(np.argsort(-diffs)) abs_diff,colors = np.abs(diffs),[] per25,per50,per75 = np.percentile(abs_diff,25),np.percentile(abs_diff,50),np.percentile(abs_diff,75) for idx,diff in enumerate(abs_diff): if diff <= per25: colors.append('navy') elif diff <= per50: colors.append('brown') elif diff <= per75: colors.append('orange') else: colors.append('red') custom_lines = [Line2D([0], [0], color='navy', lw=4, alpha=0.5), Line2D([0], [0], color='brown', lw=4, alpha=0.5), Line2D([0], [0], color='orange', lw=4, alpha=0.5), Line2D([0], [0], color='red', lw=4, alpha=0.5)] plt.legend(custom_lines, ['Low','Inter. (low)','Inter. (high)','High']) basis_axis = np.arange(len(diffs) + 2) basis = np.zeros_like(basis_axis) plt.plot(basis_axis,basis,alpha=0.25,color='grey') plt.title('Partial dependency plot') plt.ylabel('AUROC score differences') plt.xticks(np.arange(np.shape(data)[1]),tick_labels,rotation='vertical') plt.subplots_adjust(bottom=0.20) plt.bar(np.arange(np.shape(data)[1]),diffs,width=0.8,align='center',color=colors,alpha=0.5) # plt.legend() plt.show() def plot_metric_bars_ens(mean_est,std_est): '''summarize performance metrics with bar charts''' num_models = 4 custom_lines = [Line2D([0], [0], color='black', lw=4, alpha=0.5), Line2D([0], [0], color='navy', lw=4, alpha=0.5), Line2D([0], [0], color='brown', lw=4, alpha=0.5), Line2D([0], [0], color='orange', lw=4, alpha=0.5)] plt.legend(custom_lines, ['M2 (breast)','Bimoal (breast)','M2 (nsclc)','Bimodal (nsclc)'],loc='lower right') mean_aucs,mean_prcs,mean_uf1s,mean_cis,mean_accs = [],[],[],[],[] std_aucs,std_prcs,std_uf1s,std_cis,std_accs = [],[],[],[],[] for i in range(num_models): mean_aucs.append(mean_est[i][0]) mean_prcs.append(mean_est[i][4]) mean_uf1s.append(mean_est[i][1]) mean_cis.append(mean_est[i][2]) mean_accs.append(mean_est[i][3]) std_aucs.append(std_est[i][0]) std_prcs.append(std_est[i][4]) std_uf1s.append(std_est[i][1]) std_cis.append(std_est[i][2]) std_accs.append(std_est[i][3]) colors=['black','navy','brown','orange'] plt.bar(np.linspace(0,1,num_models) - 0.5, mean_aucs, yerr=std_aucs, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(2,3,num_models) - 0.5, mean_prcs, yerr=std_prcs, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(4,5,num_models) - 0.5, mean_uf1s, yerr=std_uf1s, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(6,7,num_models) - 0.5, mean_cis, yerr=std_cis, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(8,9,num_models) - 0.5, mean_accs, yerr=std_accs, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) tick_labels = ['AUC','AUPRC','macro F1','CI','ACC'] plt.xticks(np.arange(5) * 2,tick_labels) plt.title('Performance metrics (ensemble)') plt.show() def plot_metric_bars_2(mean_est,std_est,name): '''summarize performance metrics with bar charts''' num_models = 6 custom_lines = [Line2D([0], [0], color='black', lw=4, alpha=0.5), Line2D([0], [0], color='navy', lw=4, alpha=0.5), Line2D([0], [0], color='brown', lw=4, alpha=0.5), Line2D([0], [0], color='orange', lw=4, alpha=0.5), Line2D([0], [0], color='green', lw=4, alpha=0.5), Line2D([0], [0], color='pink', lw=4, alpha=0.5)] plt.legend(custom_lines, ['100%','80%','60%','40%','20%','0%'], loc='lower right',fontsize='x-large') mean_aucs,mean_prcs,mean_uf1s,mean_cis,mean_accs = [],[],[],[],[] std_aucs,std_prcs,std_uf1s,std_cis,std_accs = [],[],[],[],[] for i in range(num_models): mean_aucs.append(mean_est[0][i]) mean_prcs.append(mean_est[3][i]) mean_uf1s.append(mean_est[1][i]) mean_cis.append(mean_est[2][i]) mean_accs.append(mean_est[4][i]) std_aucs.append(std_est[0][i]) std_prcs.append(std_est[3][i]) std_uf1s.append(std_est[1][i]) std_cis.append(std_est[2][i]) std_accs.append(std_est[4][i]) colors=['black','navy','brown','orange','green','pink'] plt.bar(np.linspace(0,1,num_models) - 0.5, mean_aucs, yerr=std_aucs, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(2,3,num_models) - 0.5, mean_prcs, yerr=std_prcs, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(4,5,num_models) - 0.5, mean_uf1s, yerr=std_uf1s, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(6,7,num_models) - 0.5, mean_cis, yerr=std_cis, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(8,9,num_models) - 0.5, mean_accs, yerr=std_accs, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) tick_labels = ['AUC','AUPRC','macro F1','CI','ACC'] plt.xticks(np.arange(5) * 2,tick_labels,fontsize='x-large') plt.title('Performance with decreased unlabeled data (' + name + ')', fontsize='x-large') plt.ylim([0.30,1.00]) plt.show() def plot_metric_bars_3(mean_est,std_est): '''summarize performance metrics with bar charts''' num_models = 6 custom_lines = [Line2D([0], [0], color='black', lw=4, alpha=0.5), Line2D([0], [0], color='navy', lw=4, alpha=0.5), Line2D([0], [0], color='brown', lw=4, alpha=0.5), Line2D([0], [0], color='orange', lw=4, alpha=0.5), Line2D([0], [0], color='green', lw=4, alpha=0.5), Line2D([0], [0], color='pink', lw=4, alpha=0.5)] plt.legend(custom_lines, ['100%','80%','60%','40%','20%','0%'], loc='lower right',fontsize='x-large') mean_aucs_breast,mean_prcs_breast,mean_cis_breast = [],[],[] std_aucs_breast,std_prcs_breast,std_cis_breast = [],[],[] mean_aucs_nsclc,mean_prcs_nsclc,mean_cis_nsclc = [],[],[] std_aucs_nsclc,std_prcs_nsclc,std_cis_nsclc = [],[],[] for i in range(num_models): mean_aucs_breast.append(mean_est[0][i]) mean_prcs_breast.append(mean_est[1][i]) mean_cis_breast.append(mean_est[2][i]) mean_aucs_nsclc.append(mean_est[3][i]) mean_prcs_nsclc.append(mean_est[4][i]) mean_cis_nsclc.append(mean_est[5][i]) std_aucs_breast.append(std_est[0][i]) std_prcs_breast.append(std_est[1][i]) std_cis_breast.append(std_est[2][i]) std_aucs_nsclc.append(std_est[3][i]) std_prcs_nsclc.append(std_est[4][i]) std_cis_nsclc.append(std_est[5][i]) colors=['black','navy','brown','orange','green','pink'] plt.bar(np.linspace(0,1,num_models) - 0.5, mean_aucs_breast, yerr=std_aucs_breast, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(2,3,num_models) - 0.5, mean_prcs_breast, yerr=std_prcs_breast, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(4,5,num_models) - 0.5, mean_cis_breast, yerr=std_cis_breast, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(6,7,num_models) - 0.5, mean_aucs_nsclc, yerr=std_aucs_nsclc, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(8,9,num_models) - 0.5, mean_prcs_nsclc, yerr=std_prcs_nsclc, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) plt.bar(np.linspace(10,11,num_models) - 0.5, mean_cis_nsclc, yerr=std_cis_nsclc, align='center', alpha=0.5, color=colors, width=0.8/(num_models/1.6)) tick_labels = [ 'AUROC\n(breast)','AUPRC\n(breast)','CI\n(breast)', 'AUROC\n(nsclc)','AUPRC\n(nsclc)','CI\n(nsclc)'] plt.xticks(np.arange(6) * 2,tick_labels,fontsize='x-large') plt.title('Performance with decreased unlabeled data', fontsize='x-large') plt.ylim([0.30,1.00]) plt.show() def main(): if False: # init bootstrap indices # only do once and fixed for the following experiments create_bs_idx_indep(num_bs_samples=1000) create_bs_idx(num_bs_samples=1000) if False: # plot_metric_bars (subfigure a) is_breast = False is_nsclc = not is_breast if is_breast: name = 'breast' mean_est_m2,std_est_m2 = [0.8173,0.7255,0.6902,0.7265,0.7701],[0.0798,0.0833,0.0443,0.0855,0.1172] mean_est_bi,std_est_bi = [0.7771,0.7434,0.6733,0.7436,0.7571],[0.0846,0.0806,0.0468,0.0812,0.1144] mean_est_rf,std_est_rf = [0.7836,0.6987,0.6828,0.7009,0.7661],[0.0825,0.0808,0.0455,0.0812,0.1123] mean_est_svm,std_est_svm = [0.7435,0.7431,0.6445,0.7521,0.6935],[0.0750,0.0802,0.0395,0.0769,0.0956] if is_nsclc: name = 'nsclc' mean_est_m2,std_est_m2 = [0.8046,0.7270,0.6103,0.7485,0.6083],[0.0661,0.0718,0.0470,0.0673,0.1285] mean_est_bi,std_est_bi = [0.7867,0.6895,0.6012,0.7193,0.5650],[0.0675,0.0724,0.0440,0.0673,0.1455] mean_est_rf,std_est_rf = [0.7941,0.6790,0.5882,0.7719,0.5979],[0.0651,0.0792,0.0480,0.0614,0.1351] mean_est_svm,std_est_svm = [0.5716,0.5713,0.5201,0.6901,0.3522],[0.0693,0.0827,0.0360,0.0673,0.0926] mean_est = [mean_est_m2,mean_est_bi,mean_est_rf,mean_est_svm] std_est = [std_est_m2,std_est_bi,std_est_rf,std_est_svm] plot_metric_bars(mean_est,std_est,name) if False: # get_bs_estimates (single) (subfig b,c; supp J) # you can also use these codes for ablation study & less unlabeled data bs_idx_all = np.load('../data/bs_idx.npz',allow_pickle=True) is_breast = False is_nsclc = not is_breast if is_breast: data = np.load('../data/breast/breast_1.npz',allow_pickle=True) bs_idx = bs_idx_all['bs_breast'] if is_nsclc: data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True) bs_idx = bs_idx_all['bs_nsclc'] # scan if is_breast: logits = np.load('../model/scan/breast_logits.npz',allow_pickle=True) if is_nsclc: logits = np.load('../model/scan/nsclc_logits.npz',allow_pickle=True) pred = logits['y_xc_logit'] thr = logits['thr_best_xc'] mean_est_m2,std_est_m2 = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000) # # Bimodal # if is_breast: logits = np.load('../model/bimodal/breast_bimodal_logits.npz',allow_pickle=True) # if is_nsclc: logits = np.load('../model/bimodal/nsclc_bimodal_logits.npz',allow_pickle=True) # if is_breast: # pred = np.expand_dims(logits['pred'],axis=1) # pred_slack = np.ones_like(pred) * (-999) # pred = np.concatenate((pred_slack,pred),axis=1) # if is_nsclc: pred = logits['pred'] # thr = logits['thr'] # mean_est_bimodal,std_est_bimodal = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000) # # RF # if is_breast: logits = np.load('../model/benchmarks/breast_rf_logits.npz',allow_pickle=True) # if is_nsclc: logits = np.load('../model/benchmarks/nsclc_rf_logits.npz',allow_pickle=True) # pred = np.expand_dims(logits['pred'],axis=1) # thr = logits['thr'] # pred_slack = np.ones_like(pred) * (-999) # pred = np.concatenate((pred_slack,pred),axis=1) # mean_est_rf,std_est_rf = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000) # # SVM # if is_breast: logits = np.load('../model/benchmarks/breast_svm_logits.npz',allow_pickle=True) # if is_nsclc: logits = np.load('../model/benchmarks/nsclc_svm_logits.npz',allow_pickle=True) # pred = np.expand_dims(logits['pred'],axis=1) # thr = logits['thr'] # pred_slack = np.ones_like(pred) * (-999) # pred = np.concatenate((pred_slack,pred),axis=1) # mean_est_svm,std_est_svm = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000) if False: # plot_km (subfig d) is_breast = False is_nsclc = not is_breast if is_breast: data = np.load('../data/breast/breast_1.npz',allow_pickle=True) logits = np.load('../model/scan/breast_logits.npz',allow_pickle=True) if is_nsclc: data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True) logits = np.load('../model/scan/nsclc_logits.npz',allow_pickle=True) time = data['o_test'] pred = logits['y_xc_logit'][:,1] thr = logits['thr_best_xc'] if is_breast: plot_km(time,pred,thr,event=[]) if is_nsclc: event = data['e_test'] plot_km(time,pred,thr,event=event) if False: # plot_subgroup (subfig e,f) is_breast = False is_nsclc = not is_breast if is_breast: data = np.load('../data/breast/breast_1.npz',allow_pickle=True) logits = np.load('../model/scan/breast_logits.npz',allow_pickle=True) if is_nsclc: data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True) logits = np.load('../model/scan/nsclc_logits.npz',allow_pickle=True) prob = logits['y_xc_logit'] thr = logits['thr_best_xc'] # prediction probability versus real label (subfig f) idx0_label = [idx for idx,targ in enumerate(data['y_test']) if targ < 0.5] idx1_label = [idx for idx,targ in enumerate(data['y_test']) if targ > 0.5] group0 = prob[idx0_label,1] group1 = prob[idx1_label,1] plot_subgroup(group0,group1,title='Prediction probability',unit='probability') # predicted probability (subfig e) idx0_prob = [idx for idx,targ in enumerate(prob[:,1]) if targ < thr] idx1_prob = [idx for idx,targ in enumerate(prob[:,1]) if targ > thr] group0 = data['o_test'][idx0_label] group1 = data['o_test'][idx1_label] print(np.median(group0)) print(np.median(group1)) plot_subgroup(group0,group1,title='Survival time',unit='months') if False: # plot_z_interp (supp F) is_breast = False is_nsclc = not is_breast if is_breast: z_data = np.load('../model/scan/breast_z.npz',allow_pickle=True) logits = np.load('../model/scan/breast_logits.npz',allow_pickle=True) data = np.load('../data/breast/breast_1.npz',allow_pickle=True) if is_nsclc: z_data = np.load('../model/scan/nsclc_z.npz',allow_pickle=True) logits = np.load('../model/scan/nsclc_logits.npz',allow_pickle=True) data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True) z_mean,z_logstd = z_data['z_mean'],z_data['z_logstd'] prob = logits['y_xc_logit'][:,1] label = data['y_test'] if is_breast: plot_z_interp(z_mean,prob,label,add_idx=9) if is_nsclc: plot_z_interp(z_mean,prob,label,add_idx=1) if False: # get_feat_imp (supp G) is_breast,is_x = False,False is_nsclc,is_c = not is_breast,not is_x if is_breast: with open('../model/scan/breast_wq_merge.p','rb') as fp: wq_merge = pickle.load(fp) clinical_feature = ['Age', 'Menopausal \nState', 'Size', 'Radio \nTherapy', 'Chemotherapy', 'Hormone \nTherapy', 'Neoplasm \nHistologic \nGrade', 'Cellularity', 'Surgery-\nbreast \nconserving', 'Surgery-\nmastectomy'] gene_feature = ['ESR1','PGR','ERBB2','MKI67','PLAU', 'ELAVL1','EGFR','BTRC','FBXO6','SHMT2','KRAS','SRPK2', 'YWHAQ','PDHA1','EWSR1','ZDHHC17','ENO1','DBN1','PLK1','GSK3B'] if is_x: w_names = ['x_mer_0','x_mer_1','x_to_merge'] if is_c: w_names = ['c_mer_0','c_to_merge'] v_name = ['merge_output'] if is_x: get_feat_imp(wq_merge,w_names,v_name,gene_feature,'cw') if is_c: get_feat_imp(wq_merge,w_names,v_name,clinical_feature,'cw') if is_nsclc: with open('../model/scan/nsclc_wq_merge.p','rb') as fp: wq_merge = pickle.load(fp) gene_feature = ['EPCAM','HIF1A','PKM','PTK7','ALCAM','CADM1','SLC2A1', 'CUL1','CUL3','EGFR','ELAVL1','GRB2','NRF1','RNF2','RPA2'] clinical_feature = ['Age','Gender','Stage'] if is_x: w_names = ['x_mer_0','x_to_merge'] if is_c: w_names = ['c_to_merge'] v_name = ['merge_output'] if is_x: get_feat_imp(wq_merge,w_names,v_name,gene_feature,'cw') if is_c: get_feat_imp(wq_merge,w_names,v_name,clinical_feature,'cw') if False: # partial_dep_plot (supp H) is_breast,is_x = False,True is_nsclc,is_c = not is_breast,not is_x if is_breast: data = np.load('../data/breast/breast_1.npz',allow_pickle=True) clinical_feature = ['Age', 'Menopausal \nState', 'Size', 'Radio \nTherapy', 'Chemotherapy', 'Hormone \nTherapy', 'Neoplasm \nHistologic \nGrade', 'Cellularity', 'Surgery-\nbreast \nconserving', 'Surgery-\nmastectomy'] gene_feature = ['ESR1','PGR','ERBB2','MKI67','PLAU', 'ELAVL1','EGFR','BTRC','FBXO6','SHMT2','KRAS','SRPK2', 'YWHAQ','PDHA1','EWSR1','ZDHHC17','ENO1','DBN1','PLK1','GSK3B'] label = data['y_test'] if is_x: data = data['x_test'] partial_dep_plot(data,label,gene_feature,'breast',True) if is_c: data = data['c_test'] partial_dep_plot(data,label,clinical_feature,'breast',False) if is_nsclc: data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True) gene_feature = ['EPCAM','HIF1A','PKM','PTK7','ALCAM','CADM1','SLC2A1', 'CUL1','CUL3','EGFR','ELAVL1','GRB2','NRF1','RNF2','RPA2'] clinical_feature = ['Age','Gender','Stage'] label = data['y_test'] if is_x: data = data['x_test'] partial_dep_plot(data,label,gene_feature,'nsclc',True) if is_c: data = data['c_test'] partial_dep_plot(data,label,clinical_feature,'nsclc',False) if False: # independent validation sets (supp I) bs_idx_all = np.load('../data/bs_idx_indep.npz',allow_pickle=True) is_breast = False is_nsclc = not is_breast if is_nsclc: data = np.load('../data/nsclc/indep.npz',allow_pickle=True) bs_idx = bs_idx_all['bs_nsclc'] # scan logits = np.load('../model/scan/nsclc_indep_logits.npz',allow_pickle=True) ori_logits = np.load('../model/scan/nsclc_logits.npz',allow_pickle=True) # scan (ensemble) # logits = np.load('../model/scan_ens/nsclc/nsclc_indep_logits.npz',allow_pickle=True) # ori_logits = np.load('../model/scan_ens/nsclc/nsclc_logits.npz',allow_pickle=True) # bimodal # logits = np.load('../model/bimodal/nsclc_bimodal_indep_logits.npz',allow_pickle=True) # ori_logits = np.load('../model/bimodal/nsclc_bimodal_logits.npz',allow_pickle=True) # bimodal (ensemble) # logits = np.load('../model/bimodal_ens/nsclc/nsclc_bimodal_indep_logits.npz',allow_pickle=True) # ori_logits = np.load('../model/bimodal_ens/nsclc/nsclc_bimodal_logits.npz',allow_pickle=True) if is_breast: data = np.load('../data/breast/indep.npz',allow_pickle=True) bs_idx = bs_idx_all['bs_breast'] # scan logits = np.load('../model/scan/breast_indep_logits.npz',allow_pickle=True) ori_logits = np.load('../model/scan/breast_logits.npz',allow_pickle=True) # scan (ensemble) # logits = np.load('../model/scan_ens/breast/breast_indep_logits.npz',allow_pickle=True) # ori_logits = np.load('../model/scan_ens/breast/breast_logits.npz',allow_pickle=True) # bimodal # logits = np.load('../model/bimodal/breast_bimodal_indep_logits.npz',allow_pickle=True) # ori_logits = np.load('../model/bimodal/breast_bimodal_logits.npz',allow_pickle=True) # bimodal (ensemble) # logits = np.load('../model/bimodal_ens/breast/breast_bimodal_indep_logits.npz',allow_pickle=True) # ori_logits = np.load('../model/bimodal_ens/breast/breast_bimodal_logits.npz',allow_pickle=True) pred = logits['y_xc_logit'] # scan thr = ori_logits['thr_best_xc'] # scan # pred = logits['pred'] # thr = ori_logits['thr'] # pred = 1.0 - pred ################################################# # if is_breast: # pred = np.expand_dims(logits['pred'],axis=1) # pred_slack = np.ones_like(pred) * (-999) # pred = np.concatenate((pred_slack,pred),axis=1) # if is_nsclc: pred = logits['pred'] # thr = logits['thr'] ############################################# # need this for ensemble scan pred = np.expand_dims(pred,axis=1) pred_slack = np.ones_like(pred) * (-999) pred = np.concatenate((pred_slack,pred),axis=1) mean_est_m2,std_est_m2 = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000) if False: # get_bs_estimates (ensemble) (supp I) bs_idx_all = np.load('../data/bs_idx.npz',allow_pickle=True) is_breast = False is_nsclc = not is_breast is_m2 = True is_bimodal = not is_m2 if is_breast: data = np.load('../data/breast/breast_1.npz',allow_pickle=True) bs_idx = bs_idx_all['bs_breast'] if is_nsclc: data = np.load('../data/nsclc/nsclc_3.npz',allow_pickle=True) bs_idx = bs_idx_all['bs_nsclc'] if is_breast and is_m2: logits = np.load('../model/scan_ens/breast/breast_logits.npz',allow_pickle=True) if is_breast and is_bimodal: logits = np.load('../model/bimodal_ens/breast/breast_bimodal_logits.npz',allow_pickle=True) if is_nsclc and is_m2: logits = np.load('../model/scan_ens/nsclc/nsclc_logits.npz',allow_pickle=True) if is_nsclc and is_bimodal: logits = np.load('../model/bimodal_ens/nsclc/nsclc_bimodal_logits.npz',allow_pickle=True) if is_m2: pred = logits['y_xc_logit'] thr = logits['thr_best_xc'] if is_bimodal: pred = logits['pred'] thr = logits['thr'] pred = np.expand_dims(pred,axis=1) pred_slack = np.ones_like(pred) * (-999) pred = np.concatenate((pred_slack,pred),axis=1) if is_bimodal and is_nsclc: pred = 1.0 - pred mean_est_m2,std_est_m2 = get_bs_estimates(bs_idx,data,pred,thr=thr,num_bs_samples=1000) if False: # plot metric bars with less unlabeled data (supp J) is_breast = False is_nsclc = not is_breast if is_breast: name = 'breast' mean_est_auc = [0.8173,0.8170,0.8173,0.8073,0.8073,0.7968] mean_est_f1 = [0.7255,0.7347,0.7431,0.7350,0.7436,0.7179] mean_est_ci = [0.6902,0.6904,0.6903,0.6909,0.6879,0.6813] mean_est_prc = [0.7701,0.7694,0.7686,0.7535,0.7523,0.7434] mean_est_acc = [0.7265,0.7350,0.7436,0.7350,0.7436,0.7179] std_est_auc = [0.0798,0.0794,0.0789,0.0809,0.0823,0.0834] std_est_f1 = [0.0833,0.0806,0.0807,0.0776,0.0806,0.0806] std_est_ci = [0.0443,0.0436,0.0442,0.0459,0.0453,0.0456] std_est_prc = [0.1172,0.1163,0.1183,0.1195,0.1203,0.1211] std_est_acc = [0.0855,0.0769,0.0769,0.0769,0.0812,0.0769] if is_nsclc: name = 'nsclc' mean_est_auc = [0.8046,0.8072,0.8145,0.8153,0.8109,0.8054] mean_est_f1 = [0.7270,0.7153,0.7250,0.7260,0.7129,0.6681] mean_est_ci = [0.6103,0.5930,0.6037,0.6022,0.5929,0.6125] mean_est_prc = [0.6083,0.6078,0.6294,0.6321,0.6216,0.5874] mean_est_acc = [0.7485,0.7427,0.7485,0.7544,0.7544,0.7076] std_est_auc = [0.0661,0.0653,0.0662,0.0640,0.0667,0.0661] std_est_f1 = [0.0718,0.0698,0.0723,0.0688,0.0726,0.0734] std_est_ci = [0.0470,0.0488,0.0471,0.0485,0.0494,0.0467] std_est_prc = [0.1285,0.1378,0.1271,0.1256,0.1353,0.1352] std_est_acc = [0.0673,0.0643,0.0673,0.0643,0.0614,0.0644] mean_est = [mean_est_auc,mean_est_f1,mean_est_ci,mean_est_prc,mean_est_acc] std_est = [std_est_auc,std_est_f1,std_est_ci,std_est_prc,std_est_acc] plot_metric_bars_2(mean_est,std_est,name) if False: # plot metric bars with less unlabeled data (supp J) (breast + nsclc) mean_est_auc_breast = [0.8173,0.8170,0.8173,0.8073,0.8073,0.7968] std_est_auc_breast = [0.0798,0.0794,0.0789,0.0809,0.0823,0.0834] mean_est_prc_breast = [0.7701,0.7694,0.7686,0.7535,0.7523,0.7434] std_est_prc_breast = [0.1172,0.1163,0.1183,0.1195,0.1203,0.1211] mean_est_ci_breast = [0.6902,0.6904,0.6903,0.6909,0.6879,0.6813] std_est_ci_breast = [0.0443,0.0436,0.0442,0.0459,0.0453,0.0456] mean_est_auc_nsclc = [0.8046,0.8072,0.8145,0.8153,0.8109,0.8054] std_est_auc_nsclc = [0.0661,0.0653,0.0662,0.0640,0.0667,0.0661] mean_est_prc_nsclc = [0.6083,0.6078,0.6294,0.6321,0.6216,0.5874] std_est_prc_nsclc = [0.1285,0.1378,0.1271,0.1256,0.1353,0.1352] mean_est_ci_nsclc = [0.6103,0.5930,0.6037,0.6022,0.5929,0.6125] std_est_ci_nsclc = [0.0470,0.0488,0.0471,0.0485,0.0494,0.0467] mean_est = [mean_est_auc_breast,mean_est_prc_breast,mean_est_ci_breast, mean_est_auc_nsclc,mean_est_prc_nsclc,mean_est_ci_nsclc] std_est = [std_est_auc_breast,std_est_prc_breast,std_est_ci_breast, std_est_auc_nsclc,std_est_prc_nsclc,std_est_ci_nsclc] plot_metric_bars_3(mean_est,std_est) if __name__ == '__main__': main()