# authentication-ACSAC/runners/runner_train.py
# -*- coding: utf-8 -*-
"""
Created on Sat Feb 19 19:22:25 2022

@author: Eidos
"""
import copy
import sys
import os
import wave
# Put the parent of this file's directory on sys.path (once) so that the
# `toolbox` imports further down can be resolved.
top_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
if top_path not in sys.path:
    sys.path.append(top_path)

import joblib
import numpy as np
import pandas as pd
import re
from sklearn import discriminant_analysis
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn import neighbors
from sklearn.ensemble import RandomForestClassifier
from sklearn import svm
# from thundersvm import SVC

import toolbox.info_detector_drone as idd
from toolbox.MFCC_extract import mfcc_extract
from toolbox.name_set import name_set_drone
from toolbox.name_set import name_set_csv
from toolbox import audio_processing as ap
# from toolbox.name_set import drone_set

class Trainer():
    """Select audio recordings, extract MFCC features and fit classifiers.

    The constructor walks ``args.originData_path``, keeps the files whose
    names pass the ``args.dic_choose`` / ``args.dic_aban`` filters, decodes
    them, keeps the training share of every recording and runs
    ``mfcc_extract`` on it.  The results are stored in
    ``self.wave_feature_all`` and ``self.wave_label_all``, which every
    ``*_train`` method fits on.
    """

    def __init__(self, args):
        """Build the training set described by ``args``.

        Attributes read from ``args``: ``dic_choose``, ``dic_aban``,
        ``originData_path`` and the ``mfcc`` dictionary (keys ``num_filter``,
        ``num_cep``, ``winlen``, ``winstep``, ``fs``, ``mfcc_d1_switch``,
        ``mfcc_d2_switch``).  ``args`` is deep-copied, so the caller's object
        is never modified.

        Calls ``sys.exit()`` when a filter dictionary is rejected by
        ``dic_quick_check`` or when no file is selected.
        """
        self.args = copy.deepcopy(args)
        # Intrinsic parameters
        self.audio = []
        self.audio_label = []
        self.audio_data = []
        # Carefully treat with this variable
        self.name_check = idd.FileNameProcessing(idd.name_set_drone)
        # Check the dic
        self.dic_quick_check(self.args.dic_choose, self.args.dic_aban,
                             self.name_check)
        # Load all the valid audio
        self.audio, self.audio_label = self.audio_select(
            self.args.originData_path,
            self.args.dic_choose,
            self.args.dic_aban,
            self.name_check)
        if not self.audio:
            print('No drone available in this condition!')
            sys.exit()
        # Audio -> array.  Each handle is closed as soon as it has been
        # decoded; otherwise one file descriptor per selected recording would
        # stay open for the lifetime of the trainer.  ``self.audio`` therefore
        # holds closed handles once the constructor returns.
        for handle in self.audio:
            try:
                self.audio_data.append(ap.audio_load(handle))
            finally:
                handle.close()
        # Split the train set and evaluate set
        self.train_data, self.train_label = self.train_eval_split(
            self.audio_data, self.audio_label)
        # MFCC
        mfcc_cfg = self.args.mfcc
        self.wave_feature_all, self.wave_label_all = mfcc_extract(
            self.train_data, self.train_label,
            num_filter=mfcc_cfg['num_filter'],
            num_cep=mfcc_cfg['num_cep'],
            winlen=mfcc_cfg['winlen'],
            winstep=mfcc_cfg['winstep'],
            fs=mfcc_cfg['fs'],
            mfcc_d1_switch=mfcc_cfg['mfcc_d1_switch'],
            mfcc_d2_switch=mfcc_cfg['mfcc_d2_switch'])

    def _fit(self, classifier):
        """Fit ``classifier`` on the stored features/labels and return it."""
        classifier.fit(self.wave_feature_all, self.wave_label_all.ravel())
        return classifier

    def qda_train(self):
        """Fit and return a quadratic discriminant analysis classifier."""
        print('Start QDA')
        classifier_qda = self._fit(
            discriminant_analysis.QuadraticDiscriminantAnalysis())
        print('Finish QDA')
        return classifier_qda

    def lda_train(self):
        """Fit and return a linear discriminant analysis classifier."""
        print('Start LDA')
        classifier_lda = self._fit(
            discriminant_analysis.LinearDiscriminantAnalysis())
        print('Finish LDA')
        return classifier_lda

    def lsvm_train(self):
        """Fit and return a ``LinearSVC`` (C=1, at most 1000 iterations)."""
        print('Start LSVM')
        classifier_lsvm = self._fit(svm.LinearSVC(C=1, max_iter=1000))
        print('Finish LSVM')
        return classifier_lsvm

    def svm_train(self):
        """Fit and return an RBF-kernel ``SVC`` (C=1, at most 1000 iterations)."""
        print('Start SVM')
        C = 1
        # ovo: one-vs-one strategy
        classifier_svm = self._fit(svm.SVC(C=C, kernel='rbf', gamma="auto",
                                           decision_function_shape='ovo',
                                           max_iter=1000))
        print('Finish SVM')
        return classifier_svm

    def knn_train(self):
        """Fit and return a k-nearest-neighbours classifier with k = 8."""
        print('Start KNN')
        k_value = 8
        classifier_knn = self._fit(neighbors.KNeighborsClassifier(k_value))
        print('Finish KNN')
        return classifier_knn

    def dt_train(self):
        """Fit and return a decision tree; prints its depth and leaf count."""
        print('Start DT')
        classifier_dt = self._fit(DecisionTreeClassifier(max_depth=20,
                                                         min_samples_leaf=15,
                                                         min_samples_split=15,
                                                         random_state=1))
        print("depth = %d" % classifier_dt.get_depth())
        print("leaves = %d" % classifier_dt.get_n_leaves())
        print('Finish DT')
        return classifier_dt

    def rf_train(self):
        """Fit and return a random forest using all CPU cores (``n_jobs=-1``).

        No ``random_state`` is set, so repeated runs may produce different
        forests.
        """
        print('Start RF')
        # 'sqrt' is what the former 'auto' value meant for classifiers;
        # 'auto' was deprecated in scikit-learn 1.1 and removed in 1.3.
        classifier_rf = self._fit(RandomForestClassifier(max_depth=15,
                                                         min_samples_leaf=10,
                                                         min_samples_split=10,
                                                         max_features='sqrt',
                                                         n_jobs=-1))
        print('Finish RF')
        return classifier_rf

    def gnb_train(self):
        """Fit and return a Gaussian naive Bayes classifier."""
        print('Start GNB')
        classifier_gnb = self._fit(GaussianNB())
        print('Finish GNB\n')
        return classifier_gnb

    def save_model(self, classifier):
        """Dump ``classifier`` with joblib to ``args.output_path/args.output_name``."""
        print('Start saving model')
        joblib.dump(classifier,
                    os.path.join(self.args.output_path, self.args.output_name))
        print('Finish')

    def save_mfcc_csv(self):
        """Abandoned: only prints a notice and writes nothing.

        The former implementation wrote the feature and label arrays as
        header-less, index-less CSV files under ``args.csv_savePath``.
        """
        print('save_mfcc_csv is abandon!!!')

    def audio_select(self, originData_path, dic_choose, dic_aban, name_check):
        """Open every wave file under ``originData_path`` that passes the filters.

        A file is kept when ``name_check.check_file_choose`` accepts its
        parsed name and ``name_check.check_file_aban`` does not.  The label
        is the second value returned by
        ``name_check.info_detector(file, "drone_No")``.

        Returns:
            ``(audio, audio_label)``: the opened ``wave`` readers and their
            labels, in matching order.  The readers are returned open; the
            caller is responsible for closing them.
        """
        # Store audio data and label
        audio = []
        audio_label = []
        for root, _dirs, files in os.walk(originData_path):
            for file_name in files:
                name, label = name_check.info_detector(file_name, "drone_No")
                if name_check.check_file_choose(name, dic_choose) \
                        and not name_check.check_file_aban(name, dic_aban):
                    audio.append(wave.open(os.path.join(root, file_name)))
                    audio_label.append(label)
                    print("Choose %s" % file_name)
                else:
                    print("abandon %s" % file_name)
        return audio, audio_label

    def train_eval_split(self, audio_data, audio_label):
        """Keep the training share of every recording.

        Each sequence contributes three slices: the first 20 %, the span from
        35 % to 65 %, and the last 20 %.  The two gaps in between (20-35 % and
        65-80 %) are the evaluation share and are not returned here.

        Returns:
            ``(train_data, train_label)`` where every label of
            ``audio_label`` is repeated three times, once per slice.
        """
        train_data = []
        for data in audio_data:
            length = len(data)
            train_data.append(data[0:int(length * 0.2)])
            train_data.append(data[int(length * 0.35):int(length * 0.65)])
            train_data.append(data[int(length * 0.8):])

        # One label per slice produced above.
        train_label = [label for label in audio_label for _ in range(3)]

        return train_data, train_label

    def dic_quick_check(self, dic_choose, dic_aban, name_check):
        """Exit the process when either filter dictionary is rejected.

        Both dictionaries are validated with ``name_check.check_dic``; on
        failure a message is printed and ``sys.exit()`` is called.
        """
        # Check the format of the dic_choose
        if not name_check.check_dic(dic_choose):
            print("The format of the dic_choose is wrong!")
            sys.exit()
        # Check the format of the dic_aban
        if not name_check.check_dic(dic_aban):
            print("The format of the dic_aban is wrong!")
            sys.exit()


class Trainer_csv(Trainer):
    """Trainer that loads pre-computed MFCC features and labels from CSV files.

    Instead of decoding audio, the constructor walks ``args.csv_savePath``,
    picks the label files that match the requested dates and MFCC settings,
    loads each one together with its companion feature file, and stacks
    everything into ``self.wave_feature_all`` / ``self.wave_label_all``.
    The inherited ``*_train`` and ``save_model`` methods then work unchanged.
    """

    def __init__(self, args):
        """Load and stack every matching label/feature CSV pair.

        Attributes read from ``args``: ``dic_choose['date']``,
        ``dic_aban['drone_No']``, ``csv_savePath``, ``mfcc`` and
        ``output_name``.  ``args`` is deep-copied; ``self.args.output_name``
        is rewritten to carry the multiset tags found in the loaded files.

        If no file matches, both arrays stay ``None``.

        Raises:
            ValueError: if the selected files do not all have the same number
                of columns (previously such a mismatch silently discarded
                everything loaded so far).
        """
        self.args = copy.deepcopy(args)

        # Carefully treat with this variable
        self.name_check = idd.FileNameProcessing(name_set_csv)
        self.wave_feature_all = None
        self.wave_label_all = None

        name_set_csv_list = ['prefix', 'date', 'num_filter', 'num_cep',
                             'winlen', 'winstep', 'multiset', 'suffix']
        dic_choose_csv = {key: [] for key in name_set_csv_list}
        dic_aban_csv = {key: [] for key in name_set_csv_list}
        # Do not change this.
        dic_choose_csv['prefix'] = ['_label_']
        dic_choose_csv['date'] = self.args.dic_choose['date']

        self.dic_quick_check(dic_choose_csv, dic_aban_csv, self.name_check)

        # For adding drone set to the name
        self.name_multiset = {tag: False for tag in name_set_csv['multiset']}

        # Collect the per-file arrays and stack them once at the end; this
        # avoids re-copying the accumulated data for every file and needs no
        # special case for the first file.
        feature_parts = []
        label_parts = []
        for root, _dirs, files in os.walk(self.args.csv_savePath):
            for file_name in files:
                name, _ = self.name_check.info_detector(file_name, "suffix")
                selected = (
                    self.name_check.check_file_choose(name, dic_choose_csv)
                    and not self.name_check.check_file_aban(name, dic_aban_csv)
                    and self.fileName_check(file_name)
                )
                if not selected:
                    print("abandon %s" % file_name)
                    continue

                # Add drone set to the name
                for i_drone_set in name_set_csv['multiset']:
                    if self.name_check.check_file_choose(
                            name, {'multiset': [i_drone_set]}):
                        self.name_multiset[i_drone_set] = True

                # The feature file has the same name as the label file with
                # the first 'label' replaced by 'MFCC'.
                # NOTE(review): read_csv is used with its default header
                # handling, so the first row of each file is consumed as the
                # header - confirm the CSV files are written with one.
                feature_name = file_name.replace('label', 'MFCC', 1)
                wave_label = np.array(
                    pd.read_csv(os.path.join(root, file_name)))
                wave_feature = np.array(
                    pd.read_csv(os.path.join(root, feature_name)))
                # Here should design a drone select function in the future
                wave_label, wave_feature = self.drone_selection(wave_label,
                                                                wave_feature)

                feature_parts.append(wave_feature)
                label_parts.append(wave_label)
                print("Choose %s" % file_name)

        if feature_parts:
            self.wave_feature_all = np.vstack(feature_parts)
            self.wave_label_all = np.vstack(label_parts)
        print('wave_feature_all = ', np.shape(self.wave_feature_all))
        print('wave_label_all = ', np.shape(self.wave_label_all))

        # Change the output model name: append every multiset tag that was
        # seen (with its first underscore removed) before the '.m' suffix.
        name_extra = ''.join(tag.replace('_', '', 1)
                             for tag, found in self.name_multiset.items()
                             if found)
        base_name = self.args.output_name
        # Strip only a trailing '.m'; a '.m' elsewhere in the name is kept.
        if base_name.endswith('.m'):
            base_name = base_name[:-len('.m')]
        self.args.output_name = base_name + name_extra + '.m'

    def drone_selection(self, wave_label, wave_feature):
        """Drop the rows that belong to an abandoned drone.

        Drone names in ``args.dic_aban['drone_No']`` are translated to
        numeric labels through their position in
        ``name_set_drone['drone_No']``.  Every row of ``wave_label`` that
        equals one of those labels is removed, together with the row at the
        same position in ``wave_feature``.

        Returns:
            ``(wave_label, wave_feature)`` with the matching rows removed.

        Raises:
            KeyError: if an abandoned name is not in
                ``name_set_drone['drone_No']``.
        """
        drone_names = name_set_drone['drone_No']
        label_dic = dict(zip(drone_names, np.arange(len(drone_names))))
        aban_labels = [label_dic[name]
                       for name in self.args.dic_aban['drone_No']]

        # Vectorised membership test instead of a Python loop over every row.
        hit = np.isin(wave_label, aban_labels)
        if hit.ndim > 1:
            hit = hit.any(axis=tuple(range(1, hit.ndim)))
        rows_to_delete = np.flatnonzero(hit)

        wave_label = np.delete(wave_label, rows_to_delete, axis=0)
        wave_feature = np.delete(wave_feature, rows_to_delete, axis=0)
        return wave_label, wave_feature

    def fileName_check(self, name):
        """Check that the MFCC settings in a file name match ``args.mfcc``.

        The name is expected to contain the tokens ``_<value>nf_``,
        ``_<value>wl_`` and ``_<value>ws_``; their numeric values are
        compared with ``args.mfcc['num_filter']``, ``['winlen']`` and
        ``['winstep']`` respectively.

        Returns:
            ``True`` when all three values match.  ``False`` (after printing
            the mismatch) when one differs or when a token is missing from
            ``name``.
        """
        checks = (('num_filter', '_.{1,4}nf_'),
                  ('winlen', '_.{1,4}wl_'),
                  ('winstep', '_.{1,4}ws_'))
        # Check the if MFCC parameters are valid
        for key, pattern in checks:
            expected = float(self.args.mfcc[key])
            token = re.search(pattern, name)
            number = re.search(r'\d+\.?\d*', token.group()) if token else None
            # A missing token is reported as a mismatch instead of crashing.
            found = float(number.group()) if number else None
            if found != expected:
                print('This MFCC parameter is not right')
                print('fileName = ', found)
                print('train.py = ', expected)
                return False

        return True

class Trainer_pkl(Trainer_csv):
    """Trainer that reads its features from one pickled pandas dataset."""

    def __init__(self, args):
        """Select the training rows and MFCC columns from the pickle file.

        Attributes read from ``args``: ``pkl_savePath``, ``pkl_fileName``,
        ``mfcc`` (``mfcc_d1_switch``, ``mfcc_d2_switch``, ``num_cep``) and
        ``dic_choose`` (``date``, ``drone_No``).

        Calls ``sys.exit()`` when ``fileName_check`` rejects
        ``args.pkl_fileName``.
        """
        self.args = copy.deepcopy(args)

        # The MFCC settings encoded in the file name must match args.mfcc.
        if not self.fileName_check(self.args.pkl_fileName):
            sys.exit()

        mfcc_cfg = self.args.mfcc
        # The base coefficients are always used; the two derivative groups
        # are added only when their switch is on.
        mfcc_list = ['0_mfcc']
        for group, switch in (('1_mfcc', 'mfcc_d1_switch'),
                              ('2_mfcc', 'mfcc_d2_switch')):
            if mfcc_cfg[switch]:
                mfcc_list.append(group)

        # Drone name -> numeric label (position in name_set_drone['drone_No'])
        drone_names = name_set_drone['drone_No']
        self.label_dic = dict(zip(drone_names, np.arange(len(drone_names))))

        # Load the dataset
        # NOTE(review): unpickling can execute arbitrary code - only load
        # files from a trusted source.
        dataset = pd.read_pickle(
            self.args.pkl_savePath+'/'+self.args.pkl_fileName)

        choose = self.args.dic_choose
        # Rows: the 'train' entries for the chosen dates and drones (the
        # third index level is left unrestricted).
        # Columns: the selected MFCC groups, labels 1 .. num_cep.
        row_sel = pd.IndexSlice['train', choose["date"], :, choose["drone_No"]]
        col_sel = pd.IndexSlice[mfcc_list, 1:mfcc_cfg['num_cep']]
        selected = dataset.loc[row_sel, col_sel]

        # The last index level of every row is the drone name.
        labels = [self.label_dic[row_key[-1]] for row_key in selected.index]

        self.wave_label_all = np.array(labels).reshape(-1, 1)
        self.wave_feature_all = np.array(selected)