# vitanet/code/utils/inf_metrics.py
import numpy as np
from matplotlib import pyplot as plt


def get_peaks(sig):
    """Return indices of local minima and maxima of a 1-D NumPy array.

    A sample is compared against the 5-sample window centred on it
    (two neighbours on each side).  It is reported as a maximum when it is
    >= every sample in that window and as a minimum when it is <= every
    sample, so every sample of a flat run is reported as both.

    Only indices ``4 .. len(sig) - 4`` (inclusive) are examined; inputs
    shorter than 8 samples therefore yield two empty arrays.

    Args:
        sig: 1-D ``numpy.ndarray`` (it is indexed with an integer array).

    Returns:
        Tuple ``(min_peaks, max_peaks)`` of integer index arrays into ``sig``.
    """
    width = 5
    half = width // 2
    centre_col = width - half - 1

    # Row r of the index matrix selects sig[r + centre_col : r + centre_col + width].
    n_rows = len(sig) - width - half
    row_origin = np.arange(n_rows)[:, None]
    col_offset = np.arange(width)[None, :]
    windows = sig[row_origin + col_offset + centre_col]

    centre = windows[:, centre_col]
    is_minimum = centre <= windows.min(axis=1)
    is_maximum = centre >= windows.max(axis=1)

    # Row r's centre sample lives at index r + 2 * centre_col of the input.
    to_signal_index = 2 * centre_col
    min_peaks = np.flatnonzero(is_minimum) + to_signal_index
    max_peaks = np.flatnonzero(is_maximum) + to_signal_index

    return min_peaks, max_peaks


def get_features(sig):
    """Extract per-beat landmark indices from a 1-D waveform.

    Steps:
      1. Find local extrema with ``get_peaks`` and keep only minima whose
         value is < 0 and maxima whose value is > 0.
      2. Group them into ``[min, max, next_min]`` index triplets.
      3. Find local maxima of the second difference of ``sig``; for every
         triplet that has exactly two such maxima strictly between
         ``max + 1`` and ``next_min + 1``, record
         ``[max, first_of_those + 3]``.

    Args:
        sig: 1-D ``numpy.ndarray``.  The sign thresholds in step 1 presumably
            assume a zero-centred signal -- TODO confirm with the caller.

    Returns:
        Tuple ``(sd, triplets)``:
          * ``sd`` -- list of ``[peak_index, second_landmark_index]`` pairs
            (callers name these systolic / diastolic).
          * ``triplets`` -- list of ``[min, max, next_min]`` index triplets.
        Both lists are empty when no qualifying extrema exist.
    """

    def get_triplets(min_s, max_s):
        """Pair each maximum with the minima immediately before and after it.

        Pairing is positional (i-th maximum with the i-th / neighbouring
        minimum); any candidate whose indices are not strictly increasing is
        dropped.
        """
        # Nothing to pair -- and indexing [0] below would raise IndexError.
        if not min_s or not max_s:
            return []

        triplets = []
        if min_s[0] < max_s[0]:
            # Sequence starts with a minimum:
            # max_s[i] is expected between min_s[i] and min_s[i + 1].
            for idx in range(min(len(min_s) - 1, len(max_s))):
                left, peak, right = min_s[idx], max_s[idx], min_s[idx + 1]
                if left < peak < right:
                    triplets.append([left, peak, right])
        else:
            # Sequence starts with a maximum (which has no preceding minimum):
            # max_s[i] is expected between min_s[i - 1] and min_s[i].
            # The upper bound is len(min_s), not len(min_s) - 1, so the
            # triplet ending at the last minimum is not skipped.
            for idx in range(1, min(len(min_s), len(max_s))):
                left, peak, right = min_s[idx - 1], max_s[idx], min_s[idx]
                if left < peak < right:
                    triplets.append([left, peak, right])

        return triplets

    min_s, max_s = get_peaks(sig)
    min_s = [m for m in min_s if sig[m] < 0]
    max_s = [m for m in max_s if sig[m] > 0]
    triplets = get_triplets(min_s, max_s)

    # NOTE: ``[0] + ndarray`` is a broadcast element-wise add, NOT a list
    # prepend, so d1 has len(sig) - 1 samples and d2 has len(sig) - 2.
    # It is kept as-is because the index offsets used below operate on
    # exactly these arrays.
    d1 = [0] + np.diff(sig)
    d2 = [0] + np.diff(d1)
    _, max_d2 = get_peaks(d2)

    sd = []
    for _, peak, next_min in triplets:
        # d2 maxima lying strictly between the peak and the following minimum
        # (both bounds shifted by one sample).
        between = [md2 for md2 in max_d2 if peak + 1 < md2 < next_min + 1]
        if len(between) == 2:
            sd.append([peak, between[0] + 3])

    return sd, triplets


class Metrics():
    """Accumulates per-sample evaluation errors across calls to ``evaluate``.

    ``self.errs`` maps an error name to the list of values appended so far.
    """

    # Every signal is zero-padded to this many samples before the FFT.
    _FFT_LEN = 1200
    # Half-open range of FFT bins searched for the dominant peak.
    # NOTE(review): the bin index is used directly as the rate estimate, which
    # presumably means 1200 samples span one minute -- confirm sampling rate.
    _BIN_LO = 45
    _BIN_HI = 110

    def __init__(self):
        self.errs = {'hr_rad_error': [], 'hr_rad_multi_error': [],
                     'hr_rad_atw_error': [], 'hr_vita_error': [],
                     'sys_dia_diff_error': [], 'sys_dia_diff_mean': [],
                     'l1': [], 'l1_1': [], 'l1_2': []}

    @classmethod
    def _peak_bin(cls, signal):
        """Return the strongest FFT bin of ``signal`` within the search range.

        Raises:
            ValueError: if ``signal`` is longer than ``_FFT_LEN`` samples
                (the padding length would be negative).
        """
        padded = np.concatenate([signal, np.zeros(cls._FFT_LEN - len(signal))])
        magnitude = np.abs(np.fft.fft(padded))
        return np.argmax(magnitude[cls._BIN_LO:cls._BIN_HI]) + cls._BIN_LO

    @classmethod
    def _mean_radar_peak_bin(cls, rad, range_bins):
        """Average ``_peak_bin`` over every channel and every listed range bin."""
        peaks = [cls._peak_bin(rad[:, channel, range_bin])
                 for channel in range(rad.shape[1])
                 for range_bin in range_bins]
        return int(round(np.mean(peaks)))

    def evaluate(self, ppg, pre, rad, vspr, atw):
        """Compute error metrics for one sample and record them in ``self.errs``.

        Args:
            ppg: reference signal; object supporting ``[:, 0]`` and
                ``.numpy()`` (presumably a torch tensor -- TODO confirm).
            pre: predicted signal, same interface as ``ppg``.
            rad: 3-D object supporting ``.numpy()``, indexed here as
                ``rad[:, channel, range_bin]``; axis 0 is the one the FFT
                runs over.
            vspr: 1-D sequence with one score per range bin; the arg-max and
                all entries > 1 select the bins used for the radar estimates.
            atw: 2-D array; summing over axis 1 gives one weight per range
                bin.

        Returns:
            dict containing every key of ``self.errs`` (left at 0) plus
            ``'hr_vita'``, ``'hr_rad'``, ``'hr_rad_multi'``, ``'hr_rad_atw'``
            (absolute bin differences against the reference) and
            ``'sd_ppg'`` / ``'sd_vita'`` (the landmark pairs returned by
            ``get_features`` for reference and prediction).

        Side effects:
            Appends to the lists in ``self.errs`` and prints the landmark
            pairs and their error to stdout.
        """
        metrics = {key: 0 for key in self.errs}

        ppg, pre, rad = ppg[:, 0].numpy(), pre[:, 0].numpy(), rad[:, :, :].numpy()
        max_rad_bin = np.argmax(vspr)

        ### L1 errors
        # Mean absolute error of the signal and of its first and second
        # differences, each normalised by the reference's value range.
        ppg_d1, pre_d1 = np.diff(ppg), np.diff(pre)
        ppg_d2, pre_d2 = np.diff(ppg_d1), np.diff(pre_d1)
        l1 = np.mean(np.abs(pre - ppg) / (ppg.max() - ppg.min()))
        l1_1 = np.mean(np.abs(pre_d1 - ppg_d1) / (ppg_d1.max() - ppg_d1.min()))
        l1_2 = np.mean(np.abs(pre_d2 - ppg_d2) / (ppg_d2.max() - ppg_d2.min()))
        self.errs['l1'].append(round(l1, 2))
        self.errs['l1_1'].append(round(l1_1, 2))
        self.errs['l1_2'].append(round(l1_2, 2))

        ### HR Metrics
        # Each signal is padded according to its own length.
        ppg_hr = self._peak_bin(ppg)
        pre_hr = self._peak_bin(pre)
        metrics['hr_vita'] = abs(ppg_hr - pre_hr)
        self.errs['hr_vita_error'].append(abs(ppg_hr - pre_hr))

        # Single bin radar error
        rad_hr = self._mean_radar_peak_bin(rad, [max_rad_bin])
        metrics['hr_rad'] = abs(ppg_hr - rad_hr)
        self.errs['hr_rad_error'].append(abs(ppg_hr - rad_hr))

        # Multi bin radar error
        multi_bins = [idx for idx, v in enumerate(vspr) if v > 1]
        if not multi_bins:
            # No bin passes the threshold: fall back to the strongest bin
            # instead of averaging an empty list (which yields NaN and makes
            # int(round(...)) raise).
            multi_bins = [max_rad_bin]
        rad_hr = self._mean_radar_peak_bin(rad, multi_bins)
        metrics['hr_rad_multi'] = abs(ppg_hr - rad_hr)
        self.errs['hr_rad_multi_error'].append(abs(ppg_hr - rad_hr))

        # Attention window radar error
        atw_per_bin = np.sum(atw, axis=1)
        # Up to 8 highest-weighted bins whose weight exceeds 0.5 ...
        atw_bins = [b for b in np.argsort(atw_per_bin)[::-1][:8]
                    if atw_per_bin[b] > 0.5]
        # ... or just the single highest-weighted bin when none qualifies.
        if len(atw_bins) == 0:
            atw_bins = [np.argmax(atw_per_bin)]
        # Use the selected bins (previously every bin in range(64) was used
        # and the selection above was ignored).
        rad_hr = self._mean_radar_peak_bin(rad, atw_bins)
        metrics['hr_rad_atw'] = abs(ppg_hr - rad_hr)
        self.errs['hr_rad_atw_error'].append(abs(ppg_hr - rad_hr))

        ## Diastolic/Systolic
        ppg_sd, _ = get_features(ppg)
        pre_sd, _ = get_features(pre)
        print(ppg_sd, pre_sd)

        if len(ppg_sd) > 0 and len(pre_sd) > 0:
            ppg_sd_diff = [sd[1] - sd[0] for sd in ppg_sd]
            pre_sd_diff = [sd[1] - sd[0] for sd in pre_sd]
            print(ppg_sd_diff, pre_sd_diff)

            # Error is zero when the mean predicted interval falls inside the
            # range of reference intervals, otherwise the distance to the
            # nearer end of that range.
            pre_mean = np.mean(pre_sd_diff)
            lowest, highest = min(ppg_sd_diff), max(ppg_sd_diff)
            if lowest <= pre_mean <= highest:
                error = 0
            else:
                error = min(abs(pre_mean - lowest), abs(pre_mean - highest))

            print(f'error: {error}')
            self.errs['sys_dia_diff_error'].append(error)
            self.errs['sys_dia_diff_mean'].append(np.mean(ppg_sd_diff))

        metrics['sd_ppg'] = ppg_sd
        metrics['sd_vita'] = pre_sd

        return metrics