#!/usr/bin/env python # -*- coding: utf-8 -*- from __future__ import absolute_import from __future__ import print_function from __future__ import division import argparse import numpy as np from sklearn.metrics import auc,roc_curve from sklearn.svm import SVC from sklearn.metrics import accuracy_score from sklearn.ensemble import RandomForestClassifier from lifelines.utils import concordance_index from lifelines import KaplanMeierFitter from lifelines.statistics import logrank_test import matplotlib.pyplot as plt from sklearn.metrics import f1_score def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--split_num',type=int,default=0,help='number of validation split') # SVM: 1; RF: 0 parser.add_argument('--is_svm',default=False,help='`True` if running SVM') parser.add_argument('--is_rf',default=True,help='`True` if running RF') # SVM CV hyper-parameters parser.add_argument('--c',type=float,default=0.01) parser.add_argument('--gamma',type=float,default=0.01) # RF CV hyper-parameters parser.add_argument('--rf_max_depth',type=int,default=1) parser.add_argument('--rf_max_features',default='sqrt') # [None, 'sqrt', 'log2'] parser.add_argument('--is_nsclc',default=False) parser.add_argument('--is_breast',default=False) parser.add_argument('--is_colon',default=True) return parser.parse_args() def main(args): np.random.seed(1234) if args.is_nsclc: data_path = '../data/nsclc/nsclc_' split_num = 0 if args.is_breast: data_path = '../data/breast/breast_' split_num = 1 if args.is_colon: # need tuning data_path = '../data/colon/resplit/colon_' split_num = 0 data = np.load(data_path + str(split_num) + '.npz',allow_pickle=True) x_train,x_valid,x_test = data['x_train'],data['x_valid'],data['x_test'] c_train,c_valid,c_test = data['c_train'],data['c_valid'],data['c_test'] y_train,y_valid,y_test = data['y_train'],data['y_valid'],data['y_test'] if args.is_nsclc or args.is_colon: x_train = np.concatenate((x_train,x_valid),axis=0) c_train = np.concatenate((c_train,c_valid),axis=0) y_train = np.concatenate((y_train,y_valid)) x_train = np.concatenate((x_train,c_train),axis=1) x_valid = np.concatenate((x_valid,c_valid),axis=1) x_test = np.concatenate((x_test ,c_test ),axis=1) if args.is_breast: x_train = x_train * data['scale'][10:] + data['mean'][10:] x_test = x_test * data['scale'][10:] + data['mean'][10:] c_train = c_train * data['scale'][:10] + data['mean'][:10] c_test = c_test * data['scale'][:10] + data['mean'][:10] x_train = np.concatenate((x_train,x_valid),axis=0) c_train = np.concatenate((c_train,c_valid),axis=0) y_train = np.concatenate((y_train,y_valid)) x_train = np.concatenate((x_train,c_train),axis=1) x_valid = np.concatenate((x_valid,c_valid),axis=1) x_test = np.concatenate((x_test ,c_test ),axis=1) # normalization num_clinical = 10 MIN = x_train.min(axis=0) MAX = x_train.max(axis=0) for col in [0,2,6]: x_train[:,col] = (x_train[:,col] - MIN[col]) / (MAX[col] - MIN[col]) x_test[:,col] = (x_test[:,col] - MIN[col]) / (MAX[col] - MIN[col]) for col in range(30): if col < num_clinical: continue x_train[:,col] = (x_train[:,col] - MIN[col]) / (MAX[col] - MIN[col]) x_test[:,col] = (x_test[:,col] - MIN[col]) / (MAX[col] - MIN[col]) if args.is_svm: if args.is_nsclc: model = SVC(C=4,gamma=0.25) if args.is_breast: model = SVC(C=10,gamma=0.001) if args.is_colon: model = SVC(C=10,gamma=0.001) # need tuning model = model.fit(x_train,y_train.astype(int)) y_pred = model.predict(x_test) fpr,tpr,thr = roc_curve(y_test.astype(int),y_pred) print(auc(fpr,tpr)) if args.is_nsclc: np.savez_compressed('../model/benchmarks/nsclc_svm_logits.npz',pred=y_pred,thr=0.0) if args.is_breast: np.savez_compressed('../model/benchmarks/breast_svm_logits.npz',pred=y_pred,thr=0.0) if args.is_rf: if args.is_nsclc: model = RandomForestClassifier(max_depth=5,n_estimators=500) if args.is_breast: model = RandomForestClassifier(max_features=9,max_depth=4,min_samples_leaf=10,n_estimators=1000) if args.is_colon: model = RandomForestClassifier(max_depth=5,n_estimators=500) # need tuning model = model.fit(x_train,y_train.astype(int)) y_pred = model.predict_proba(x_test)[:,1] fpr,tpr,thr = roc_curve(y_test.astype(int),y_pred) print(auc(fpr,tpr)) if args.is_nsclc: np.savez_compressed('../model/benchmarks/nsclc_rf_logits.npz',pred=y_pred,thr=0.5) if args.is_breast: np.savez_compressed('../model/benchmarks/breast_rf_logits.npz',pred=y_pred,thr=0.5) if __name__ == "__main__": args = parse_args() main(args)