# scan/src/train_benchmark_all.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import argparse
import numpy as np
from sklearn.metrics import auc,roc_curve
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from lifelines.utils import concordance_index

from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score


def str2bool(value):
    """Convert a command-line token to ``bool``.

    ``argparse`` hands options over as raw strings, and ``bool('False')`` is
    ``True``, so an option declared without a ``type`` can never be switched
    off with ``--flag False``.  This converter accepts the usual spellings
    instead.  The empty string maps to ``False`` so ``--flag ''`` keeps
    working as before.

    Raises:
        argparse.ArgumentTypeError: for unrecognised values (argparse reports
            it as a usage error).
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % (value,))


def parse_args():
    """Parse the command line of the benchmark script.

    Returns:
        argparse.Namespace holding the model selectors (``is_svm``,
        ``is_rf``), the dataset selectors (``is_nsclc``, ``is_breast``,
        ``is_colon``; all real ``bool`` values) and the options below.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): main() as written uses a fixed split per dataset and
    # does not read --split_num.
    parser.add_argument('--split_num', type=int, default=0, help='number of validation split')  # SVM: 1; RF: 0
    parser.add_argument('--is_svm', type=str2bool, default=False, help='`True` if running SVM')
    parser.add_argument('--is_rf', type=str2bool, default=True, help='`True` if running RF')

    # SVM CV hyper-parameters
    # NOTE(review): not read by main() as written (per-dataset values are fixed there).
    parser.add_argument('--c', type=float, default=0.01)
    parser.add_argument('--gamma', type=float, default=0.01)

    # RF CV hyper-parameters
    # NOTE(review): not read by main() as written (per-dataset values are fixed there).
    parser.add_argument('--rf_max_depth', type=int, default=1)
    parser.add_argument('--rf_max_features', default='sqrt')  # [None, 'sqrt', 'log2']

    # Dataset selectors. --is_colon defaults to True, so it has to be
    # switched off (e.g. `--is_colon False`) when another dataset is chosen.
    parser.add_argument('--is_nsclc', type=str2bool, default=False)
    parser.add_argument('--is_breast', type=str2bool, default=False)
    parser.add_argument('--is_colon', type=str2bool, default=True)

    return parser.parse_args()

# Dataset name -> (npz path prefix, split index that is loaded). The split is
# fixed per dataset; the --split_num option is not consulted.
_DATASETS = {
    'nsclc': ('../data/nsclc/nsclc_', 0),
    'breast': ('../data/breast/breast_', 1),
    'colon': ('../data/colon/resplit/colon_', 0),  # need tuning
}
_DATASET_ORDER = ('nsclc', 'breast', 'colon')

# data['scale'] / data['mean'] of the breast file list the 10 clinical (c_*)
# entries first, followed by the x_* entries.
_BREAST_NUM_CLINICAL = 10

# Columns of the concatenated breast [x | c] matrix that are min-max scaled:
# 0, 2, 6 and 10..29.
# NOTE(review): after the concatenation the x_* columns come first and the
# clinical c_* columns last, so these indices may not be the clinical columns
# the original "skip the first num_clinical" logic suggests -- confirm against
# the data layout. Indices >= the column count raise IndexError.
_BREAST_MINMAX_COLS = [0, 2, 6] + list(range(10, 30))

# Per-dataset hyper-parameters.
_SVM_PARAMS = {
    'nsclc': dict(C=4, gamma=0.25),
    'breast': dict(C=10, gamma=0.001),
    'colon': dict(C=10, gamma=0.001),  # need tuning
}
_RF_PARAMS = {
    'nsclc': dict(max_depth=5, n_estimators=500),
    'breast': dict(max_features=9, max_depth=4, min_samples_leaf=10, n_estimators=1000),
    'colon': dict(max_depth=5, n_estimators=500),  # need tuning
}

# Datasets whose test scores are written to ../model/benchmarks/.
_SAVED_DATASETS = ('nsclc', 'breast')


def _select_dataset(args):
    """Return the single dataset enabled by the ``is_*`` flags of *args*.

    Raises:
        ValueError: unless exactly one of ``is_nsclc``, ``is_breast`` and
            ``is_colon`` is truthy. Several enabled datasets would otherwise
            mix preprocessing branches and mislabel saved predictions.
    """
    selected = [name for name in _DATASET_ORDER if getattr(args, 'is_' + name)]
    if len(selected) != 1:
        raise ValueError(
            'exactly one of --is_nsclc/--is_breast/--is_colon must be true '
            '(got: %s); note that --is_colon defaults to True and must be '
            'disabled when selecting another dataset'
            % (', '.join(selected) or 'none'))
    return selected[0]


def _minmax_breast(x_train, x_test):
    """Min-max scale ``_BREAST_MINMAX_COLS`` in place using training min/max."""
    cols = _BREAST_MINMAX_COLS
    lo = x_train[:, cols].min(axis=0)
    span = x_train[:, cols].max(axis=0) - lo
    # A constant training column would give 0/0 = NaN, which SVC and
    # RandomForestClassifier reject; such columns are only shifted.
    span = np.where(span == 0, 1.0, span)
    x_train[:, cols] = (x_train[:, cols] - lo) / span
    x_test[:, cols] = (x_test[:, cols] - lo) / span
    return x_train, x_test


def _load_features(dataset):
    """Load the configured split of *dataset* and build the feature matrices.

    Validation rows are appended to the training rows (this script does no
    model selection). The c_* covariates are appended as extra columns after
    the x_* features.

    Returns:
        (x_train, y_train, x_test, y_test)
    """
    prefix, split_num = _DATASETS[dataset]
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays;
    # only use with trusted data files.
    data = np.load(prefix + str(split_num) + '.npz', allow_pickle=True)
    x_train, x_valid, x_test = data['x_train'], data['x_valid'], data['x_test']
    c_train, c_valid, c_test = data['c_train'], data['c_valid'], data['c_test']
    y_train, y_valid, y_test = data['y_train'], data['y_valid'], data['y_test']

    if dataset == 'breast':
        # Map the stored features back through value * scale + mean.
        # The validation split is mapped as well; otherwise transformed
        # validation rows would be stacked under raw training rows.
        # NOTE(review): assumes x_valid/c_valid were stored with the same
        # transform as train/test -- confirm.
        nc = _BREAST_NUM_CLINICAL
        x_scale, x_mean = data['scale'][nc:], data['mean'][nc:]
        c_scale, c_mean = data['scale'][:nc], data['mean'][:nc]
        x_train, x_valid, x_test = [a * x_scale + x_mean for a in (x_train, x_valid, x_test)]
        c_train, c_valid, c_test = [a * c_scale + c_mean for a in (c_train, c_valid, c_test)]

    x_train = np.concatenate((x_train, x_valid), axis=0)
    c_train = np.concatenate((c_train, c_valid), axis=0)
    y_train = np.concatenate((y_train, y_valid))
    x_train = np.concatenate((x_train, c_train), axis=1)
    x_test = np.concatenate((x_test, c_test), axis=1)

    if dataset == 'breast':
        x_train, x_test = _minmax_breast(x_train, x_test)
    return x_train, y_train, x_test, y_test


def _save_scores(dataset, model_name, scores, thr):
    """Save test *scores* and threshold *thr* for datasets in ``_SAVED_DATASETS``."""
    if dataset in _SAVED_DATASETS:
        np.savez_compressed(
            '../model/benchmarks/%s_%s_logits.npz' % (dataset, model_name),
            pred=scores, thr=thr)


def main(args):
    """Train the selected benchmark model(s) and print the test ROC AUC.

    The dataset is chosen by exactly one of ``args.is_nsclc``,
    ``args.is_breast`` and ``args.is_colon``. The models are chosen by
    ``args.is_svm`` and ``args.is_rf``, which may both be set. For NSCLC and
    breast the test scores are saved to
    ``../model/benchmarks/<dataset>_<svm|rf>_logits.npz``.

    Raises:
        ValueError: if not exactly one dataset flag is set.
    """
    # RandomForestClassifier is built without random_state, so it draws from
    # NumPy's global RNG seeded here.
    np.random.seed(1234)

    dataset = _select_dataset(args)
    x_train, y_train, x_test, y_test = _load_features(dataset)
    y_train = y_train.astype(int)
    y_true = y_test.astype(int)

    if args.is_svm:
        model = SVC(**_SVM_PARAMS[dataset]).fit(x_train, y_train)
        # Use the continuous decision values rather than predict()'s hard
        # 0/1 labels: ROC AUC needs a ranking score. The saved threshold 0.0
        # is the SVC decision boundary for these values.
        y_pred = model.decision_function(x_test)
        fpr, tpr, _ = roc_curve(y_true, y_pred)
        print(auc(fpr, tpr))
        _save_scores(dataset, 'svm', y_pred, 0.0)

    if args.is_rf:
        model = RandomForestClassifier(**_RF_PARAMS[dataset]).fit(x_train, y_train)
        y_pred = model.predict_proba(x_test)[:, 1]  # P(class 1)
        fpr, tpr, _ = roc_curve(y_true, y_pred)
        print(auc(fpr, tpr))
        _save_scores(dataset, 'rf', y_pred, 0.5)

if __name__ == "__main__":
    # Script entry point: parse the command line, then run the benchmark(s).
    main(parse_args())