"""VitaNet: train / evaluate a radar-phase -> PPG waveform regression model.

The script is a single click command.  All helpers are closures nested inside
``VitaNet`` so that they can read the command-line parameters directly.
"""
import os
import click
import random
import datetime
import sys
sys.path.insert(0, '../common')  # must precede the project-local imports below
from tools import *
from dataio import data_pipe
from utils import inf_metrics

import numpy as np
import tensorflow as tf

from os import listdir
from os.path import join, dirname
from glob import glob
from pathlib import Path
from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import tensorflow.keras.layers as L
from tensorflow.keras import Model
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K


@click.command()
@click.option("--cmd", type=str, default='train', required=False)
@click.option("--train-split-ratio", type=float, default=0.90, required=True)
@click.option("--model-path", type=str, required=True)
@click.option("--tfr-path", type=str, required=True)
@click.option('--cache', type=str, default='/data1/homes.local/superluca/cache/vitanet.dp', required=False, help='file_path|ram|disabled')
## PARM> train-epochs
@click.option("--batch-size", type=int, default=32, required=True)
@click.option("--train-epochs", type=int, default=300, required=True)
@click.option("--lr-drop-epoch", type=int, default=200, required=True)
@click.option("--isync-anneal-epoch", type=int, default=100, required=True)
# @click.option("--tar-sz", type=int, default=192, required=True)
# @click.option("--inp-sz", type=int, default=256, required=True)
# @click.option("--tar-sz", type=int, default=128, required=True)
# @click.option("--inp-sz", type=int, default=128, required=True)
@click.option("--tar-sz", type=int, default=140, required=True)
@click.option("--inp-sz", type=int, default=128, required=True)
def VitaNet(cmd: str,
            model_path: str = 'model',
            tfr_path: str = 'tfrs',
            cache = None,
            train_split_ratio: float = 0.80,
            batch_size: int = 32,
            train_epochs: int = 300,
            lr_drop_epoch: int = 200,
            isync_anneal_epoch: int = 100,
            fps: int = 20,
            tfr_win_sz: int = 400,
            tar_sz: int = 320,
            inp_sz : int = 256,
            stretch_sz: int = 80,
            # jitter_sz: int = 40
            # jitter_sz: int = 80
            jitter_sz: int = 180
            ):
    """Command entry point: dispatch ``cmd`` to train / inference / debug / stats.

    Parameters that have a ``@click.option`` always receive the click default,
    so the Python defaults in this signature only apply to the parameters
    without an option (``fps``, ``tfr_win_sz``, ``stretch_sz``, ``jitter_sz``).

    Raises:
        click.UsageError: if ``tar_sz`` exceeds the smallest window that
            augmentation can produce, or if ``cmd`` is not a known command.
    """

    def get_fnames(split_sessions='train_val_dirs'):
        """Return ``(train_fnames, val_fnames, layout_fn)`` for ``tfr_path``.

        ``split_sessions`` selects how the train/val split is made:
          * ``'sessions'``       -- split the session directories by ``train_split_ratio``
          * ``'flat'``           -- split the flat list of all ``*.tfr`` files
          * ``'train_val_dirs'`` -- use the ``train/`` and ``val/`` sub-directories

        Raises:
            ValueError: for an unknown ``split_sessions`` value.
            FileNotFoundError: if no training file is found.
        """
        # NOTE: the session-level branch used to be keyed 'flat' as well, which
        # made the file-level 'flat' branch below unreachable.
        if split_sessions == 'sessions':
            sessions = [str(kk) for kk in glob(tfr_path+'/*')]
            train_test_split = int(len(sessions) * train_split_ratio)
            train_sessions = sessions[:train_test_split]
            val_sessions = sessions[train_test_split:]
            print('sessions>', 'train>', train_sessions, '\nval>', val_sessions)
            train_fnames = [kk for ss in train_sessions for kk in glob(ss+'/*.tfr')]
            val_fnames = [kk for ss in val_sessions for kk in glob(ss+'/*.tfr')]
        elif split_sessions == 'flat':
            fnames = [str(kk) for kk in Path(tfr_path).rglob('*.tfr')]
            train_test_split = int(len(fnames) * train_split_ratio)
            train_fnames = fnames[:train_test_split]
            val_fnames = fnames[train_test_split:]
        elif split_sessions == 'train_val_dirs':
            train_fnames = [str(kk) for kk in glob(tfr_path+'/train/*/*.tfr')]
            val_fnames = [str(kk) for kk in glob(tfr_path+'/val/*/*.tfr')]
        else:
            raise ValueError(f'unknown split_sessions mode: {split_sessions!r}')
        if not train_fnames:
            raise FileNotFoundError(f'no training .tfr files found under {tfr_path!r} (mode {split_sessions!r})')
        random.shuffle(train_fnames)
        # the layout file is looked up next to the first (shuffled) training file
        layout_fn = join(dirname(train_fnames[0]), 'layout.pkl')
        print('train>', train_fnames, '\n', 'val>', val_fnames)
        return train_fnames, val_fnames, layout_fn

    def normalize(xx, axis):
        """Min/max-normalize ``xx`` over ``axis`` to approximately [-1, +1]."""
        # ## N(0,1) norm
        # xx -= tf.reduce_mean(xx,axis=axis,keepdims=True)
        # xx /= (tf.math.reduce_std(xx,axis=axis,keepdims=True) + 1e-2)
        ## max/min norm to [-1,+1] range
        hi = tf.reduce_max(xx, axis=axis, keepdims=True)
        lo = tf.reduce_min(xx, axis=axis, keepdims=True)
        dr = hi - lo
        zero = (hi + lo)/2
        xx = 2 * (xx - zero) / (dr + 1e-4)  ## about~ [-1,+1] range (1e-4 avoids div-by-zero on flat signals)
        # tf.print('normalize>', zero, dr, tf.reduce_max(xx,axis=axis,keepdims=True), tf.reduce_min(xx,axis=axis,keepdims=True))
        return xx

    ## get the right data-keys
    def get_data(rec):
        """Select input/target tensors and a validity flag from one record.

        Returns ``(inp, tar, valid, vspr, id_)``.  ``valid`` is true when both
        vital-sign-present scores are above 1.0.
        """
        # Record schema (as documented by the original author):
        # {'phase': {'shape': (400, 12, 64, 2), 'dtype': 'float32'},
        #  'phase_raw': {'shape': (400, 12, 64), 'dtype': 'float32'},
        #  'mag_raw': {'shape': (400, 12, 64), 'dtype': 'float32'},
        #  'ppg': {'shape': (400, 2, 2), 'dtype': 'float32'},
        #  'ppg_raw': {'shape': (400, 2), 'dtype': 'float32'},
        #  'vs_present': {'shape': (64,), 'dtype': 'float32'},
        #  'rec_id': {'dtype': 'str'}}
        # inp = rec['phase_raw'] ## raw-phase ==> WORKS BUT NOT BETTER THAN H|F>filt
        # inp = tf.stack( [rec['phase_raw'],rec['mag_raw']], 3) ## raw-phase + raw-mag ==> similar to above
        inp = rec['phase'][...,0:2]  ## R,H phases - 64 bins
        tar = rec['ppg'][:,0,0:2]  ## R,H phases - of ir-ppg-channel
        ## select best bin with vs-present
        vspr = rec['vs_present_rad'][:,0]  ## BR - 64 bins
        vspp = rec['vs_present_ppg'][0][1]  ## Red channel, HR
        id_ = rec['rec_id']
        # the comparisons already yield boolean tensors; no tf.cond needed
        valid_1 = tf.math.reduce_max(vspr) > 1.0
        valid_2 = vspp > 1.0
        valid = tf.math.logical_and(valid_1, valid_2)
        # inp = inp[:,:,tf.math.argmax(vspr)]
        # inp = inp[:,:,tf.math.argmax(vspr),1:2] ## argmax-model of H/channel
        inp = inp[...,1]  ## fullcube-model of H/channel -- integer index drops the last axis: (400, 12, 64)
        return ( inp, tar, valid, vspr, id_ )

    ## augmentation
    def augment_dataset(ds, randomize=True):
        """Map the crop / normalize (and, if ``randomize``, augmentation) step over ``ds``."""
        # PARM> to augment or not to augment? that is the problem (says Yann:)
        # return ds
        def augment(inp, tar, vspr, id_):
            """Augment one sample and crop it to ``inp_sz`` / ``tar_sz`` time steps."""
            if randomize:
                ## inp=(400, 12, 64)
                ## permute antenna order
                inp = tf.transpose(inp, perm=[1, 0, 2])  ## (12, 400, 64)
                inp = tf.random.shuffle(inp)  ## shuffled along dimension 0 ==> a=12
                inp = tf.transpose(inp, perm=[1, 0, 2])  ## (400, 12, 64)
                rand = tf.random.uniform([3])
                # rand = tf.constant([0.,0.5,0.])
                ## stretch
                stretch = tfr_win_sz + (2*rand[0]-1)*stretch_sz
                stretch = tf.cast(stretch, tf.int32)
                inp = tf.image.resize(inp, [stretch,12])
                # tf.image.resize needs a 3-D/4-D tensor; a rank-2 (T, 2) target
                # gets a trailing channel axis here, removed again by tf.squeeze.
                # NOTE(review): the guard keeps this a no-op if the data pipe
                # already delivers a rank-3 target -- confirm against data_pipe.
                if tar.shape.rank == 2:
                    tar = tf.expand_dims(tar, 2)
                tar = tf.image.resize(tar, [stretch,2])
                tar = tf.squeeze(tar)
                ## jitter
                jitter = (2*rand[1]-1)*jitter_sz
                jitter = tf.cast(jitter, tf.int32)
                inp = tf.cond(jitter>=0, lambda: inp[jitter:], lambda: inp[:jitter])
                tar = tf.cond(jitter>=0, lambda: tar[jitter:], lambda: tar[:jitter])
                ## 180deg input phase rotation
                inp = tf.cond(rand[2]<0.5, lambda: inp, lambda: -inp)
                ## set new-size
                new_size = stretch-tf.math.abs(jitter)
                # tf.print('augment>', stretch, jitter, new_size)
            else:
                new_size = tfr_win_sz
            ## center-crop target to larger size than inp to allow for iSyncLoss
            ## this has to respect the max_tar_sz formula from stretch/jitter!
            tar = tar[ (new_size-tar_sz)//2 : (new_size+tar_sz)//2 ]
            ## center-crop the input to inp_sz
            inp = inp[ (new_size-inp_sz)//2 : (new_size+inp_sz)//2 ]
            ## manually set the static shape ~ the dynamic crops above leave it
            ## unknown, which breaks downstream shape inference
            tar = tf.reshape(tar, [tar_sz,2])
            # inp = tf.reshape(inp, [inp_sz,12,2])
            inp = tf.reshape(inp, [inp_sz,12,64])  ## fullcube-model
            # inp = tf.reshape(inp, [inp_sz,12,1]) ## post-cache select H/channel
            # inp = inp[...,1:2]
            ## dynrange normalization to [-1,+1]
            inp = normalize(inp, [0,1])  ## R|H dependent norm ==> (1,2) dr stats
                                         ## antennas of low-dr have lower weight ==> BIT BETTER
            # inp = normalize(inp, 0) ## antenna dependent norm // R|H dependent norm ==> (12,2) dr stats
            #                         ## all antennas count the same even low-dr ones ==> BIT WORSE
            # tar = normalize(tar, 0) ## R|H dependent norm ==> (1,2) dr stats
            # inp = inp / 0.015666926 ## global fixed norm on data-set stats
            # tar = tar / 1556.5098
            return inp, tar, vspr, id_
        ds = ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        return ds

    def iSyncLoss(tar, pre, return_crops=False):
        """Synchronization-independent L1 loss between target and prediction.

        The (longer) target is framed into every window of the prediction
        length, each window is also offered sign-flipped, and the loss is the
        mean of the ``k`` best-matching windows (``k`` anneals from 10 down to
        1 as ``epoch_var`` grows).

        Returns the per-sample loss, or with ``return_crops=True`` the tuple
        ``(loss, best_target_window, time_offset, idx)`` where ``idx`` indexes
        the concatenated ``[windows, -windows]`` candidate axis.
        """
        def f_derivative(x):
            """First difference along axis 2."""
            d = x[:,:,1:] - x[:,:,:-1]
            return d

        def s_derivative(x):
            """Second difference along axis 2."""
            d1 = x[:,:,1:] - x[:,:,:-1]
            d2 = d1[:,:,1:] - d1[:,:,:-1]
            return d2

        def gauss(x, istd=2.0):
            return tf.math.exp( - tf.math.square(x*istd) )

        def peak_mask(xx):
            """Masks of local maxima/minima over a 5-wide window (currently unused below)."""
            pp = tf.constant([[0,0],[0,0],[2,2]])  # padding to recover same input shape
            yy = tf.signal.frame(xx, frame_length=5, frame_step=1, axis=2)  # slidewin of size 5
            ## peak/max(min) if - center elm (index 2) is larger(smaller) than all other elm of win-5
            pmax = tf.gather(yy,[2],axis=3) > tf.reduce_max(tf.gather(yy,[0,1,3,4],axis=3),axis=3,keepdims=True)
            pmax = tf.pad(tf.squeeze(pmax),pp)
            ##
            pmin = tf.gather(yy,[2],axis=3) < tf.reduce_min(tf.gather(yy,[0,1,3,4],axis=3),axis=3,keepdims=True)
            pmin = tf.pad(tf.squeeze(pmin),pp)
            pmask = tf.cast( tf.math.logical_or(pmax, pmin), tf.float32)
            pmask += 1e-2  ## small loss for non-peaks (1% of peak loss)
            return tf.expand_dims(pmax,-1), tf.expand_dims(pmin,-1), tf.expand_dims(pmask,-1)

        ##
        pre_sz = pre.shape[1]  # crop and align tar to the prediction output size
        tar = tar[...,1:2]  # H phase ~ select
        ## slide the target
        tar_slide = tf.signal.frame(tar, frame_length=pre_sz, frame_step=1, axis=1)  ## use with l1/2 loss
        # tar_slide = tf.signal.frame(tar, frame_length=pre_sz+1, frame_step=1, axis=1) ## use with peak-weighted
        tar_slide = tf.concat( [tar_slide,-tar_slide] , 1)  ## Sheen180 ~ correct for 180deg phase uncertainty
        ## dr stats ==> slide-window dependent!
        tar_slide = normalize(tar_slide, [2])  ## [32, 130, 128, 1] => stats: [32, 130, 1, 1]
        ## Errors
        err = tf.math.abs(tar_slide - tf.expand_dims(pre,1) )  ## L1-loss
        # err = tf.math.square(tar_slide - tf.expand_dims(pre,1) ) ## L2-loss
        err_d1 = tf.math.abs(f_derivative(tar_slide) - f_derivative(tf.expand_dims(pre,1)))
        err_d2 = tf.math.abs(s_derivative(tar_slide) - s_derivative(tf.expand_dims(pre,1)))
        ## Peaks
        # pmax, pmin, pmask = peak_mask(tar_slide)
        # tf.print(pmax.shape, pmin.shape, 'peaks>', tf.where(pmask != 0).shape[0])
        # tf.print(pmax.shape, pmin.shape, 'peaks>', tf.where(pmax).shape[0], tf.where(pmin).shape[0])
        # err *= pmask
        # # pmax, pmin, pmask = peak_mask(f_derivative(tar_slide))
        # err_d1 *= pmask
        # # pmax, pmin, pmask = peak_mask(s_derivative(tar_slide))
        # err_d2 *= pmask
        ## Error Slide ~ cumsum
        err_slide = tf.reduce_sum(err, [2,3])
        err_slide += tf.reduce_sum(err_d1, [2,3])
        err_slide += tf.reduce_sum(err_d2, [2,3])
        ##DEBUG> isync-shapes> [32, 192, 1] [32, 128, 1] [32, 130, 128, 1] [32, 130, 128, 1] [32, 130]
        print('isync-shapes>', tar.shape.as_list(), pre.shape.as_list(), tar_slide.shape.as_list(), err.shape.as_list(), err_slide.shape.as_list())
        # return tf.reduce_mean(pre)
        ## best#1 loss
        # sync_independent_loss = tf.reduce_min(err_slide, -1)
        ## PARM> best#k losses ~ with annealing
        ##TODO>>> add some kk noise by random select multiple alignments even after annealed!
        # kk = tf.math.maximum(1, 10-epoch_var//10) ## 100 epochs to kk=1
        # kk = tf.math.maximum(1, 10-epoch_var//2) ## 20 epochs to kk=1
        anneal_step = max(1, isync_anneal_epoch//10)  # guard: isync_anneal_epoch < 10 would divide by zero
        top_k = tf.math.maximum(1, 10-epoch_var//anneal_step)
        loss_sort = tf.sort(err_slide, -1)[:,:top_k]
        sync_independent_loss = tf.reduce_mean(loss_sort, -1)
        ##
        if return_crops:
            idx = tf.math.argmin(err_slide, -1)
            ## DEBUG>
            # idx_sort = tf.argsort(err_slide, -1)
            # idx = idx_sort[:,2] ## 2nd best crop!
            # print(f'sync> idx {idx[0]} idx_sort {idx_sort[0]} loss_sort {loss_sort[0]} err_slide {err_slide[0]}')
            # tar_sync = tf.stack([ tar[bb,idx[bb]:idx[bb]+pre_sz] for bb in range(tar.shape[0]) ])
            tar_sync = tf.stack([ tar_slide[bb,idx[bb]] for bb in range(tar.shape[0]) ])
            # The candidate axis is [shift 0..n-1 | negated shift 0..n-1] (a concat,
            # not an interleave), so the time shift is idx modulo n, centred on n//2.
            n_shifts = err_slide.shape[1]//2
            offset = idx % n_shifts - n_shifts//2
            return sync_independent_loss, tar_sync, offset, idx
            # return sync_independent_loss, tar_sync, idx - err_slide.shape[1]//2  ## without Sheen180
        else:
            return sync_independent_loss

    def dynRangeLoss(tar, pre):
        """L1 loss between the target's dynamic range (max-min over axis 1) and ``pre``."""
        dr = lambda xx: tf.reduce_max(xx,axis=[1]) - tf.reduce_min(xx,axis=[1])
        # tf.print('dr>', tar.shape, pre.shape, dr(tar[...,1:2]), dr(tar[...,1:2]).shape, pre)
        return tf.reduce_mean( tf.math.abs( dr(tar[...,1:2]) - pre ) )  ## L1-loss

    def get_model_base(fx=2):
        """Build the plain convolutional encoder->decoder model (no attention)."""
        input_shape = (inp_sz, 12, 1)
        # input_shape = (inp_sz, 12, 64)
        # input_shape = (inp_sz, 12, 2)
        # output_shape = (out_sz, 2)
        model_input = L.Input(input_shape)
        xx = model_input
        ## Convo Enc->Dec
        ## enc(t)
        for kk in range(4):
            rx = L.AveragePooling2D(pool_size=(3,1), strides=(2,1), padding='same')(xx)  ## residual
            xx = L.Convolution2D(filters=32*fx, kernel_size=(3,1), strides=(2,1), padding='same', activation='gelu')(xx)
            xx = L.Dropout(0.2)(xx)
            if kk>0: xx = xx + rx  ## residual (skipped on the first layer: channel counts differ)
            xx = L.BatchNormalization()(xx)
        ## enc(t,a)
        # for kk in range(3): ## inp=256
        for kk in range(2):  ## inp=128
            rx = L.AveragePooling2D(pool_size=(3,3), strides=(2,2), padding='same')(xx)  ## residual
            xx = L.Convolution2D(filters=64*fx, kernel_size=(3,3), strides=(2,2), padding='same', activation='gelu')(xx)
            xx = L.Dropout(0.2)(xx)
            if kk>0: xx = xx + rx  ## residual
            xx = L.BatchNormalization()(xx)
        ## Pool
        ## Embedding=128dim -- from signal of tfr_win_sz dim -- compression of ~3.6x
        # xx = L.AveragePooling2D(pool_size=(2, 2), strides=(1, 1), padding='valid')(xx)
        # xx = L.Convolution2D(filters=256, kernel_size=(2,2), strides=(1,1), padding='valid', activation='gelu')(xx) ## inp=256
        xx = L.Convolution2D(filters=256, kernel_size=(2,3), strides=(1,1), padding='valid', activation='gelu')(xx)  ## inp=128
        xx = L.BatchNormalization()(xx)
        print('embedding-shape>', xx.shape[1:], 'size>', np.prod(np.array(xx.shape[1:].as_list())) )
        # ## DynRange model
        # dr = L.Dense(1, activation='linear')(xx)
        # model= Model([model_input], dr)
        # print(model.summary())
        # return model
        ## FC-mixer layer
        # xx = L.Dense(128, activation='tanh')(xx)
        # xx = L.BatchNormalization()(xx)
        ## dec(t,a) --> but don't upsample a(ntennas)
        # for kk in range(8): ## inp=256 out=256
        for kk in range(7):  ## inp=256 out=128
        # for kk in range(6): ## inp=128 out=64
            rx = L.UpSampling2D(size=(2,1))(xx)  ## up-residual
            xx = L.Convolution2DTranspose(filters=64*fx, kernel_size=(3,1), strides=(2,1), padding='same', activation='gelu')(xx)
            xx = L.Dropout(0.2)(xx)
            if kk>0: xx = xx + rx  ## residual
            xx = L.BatchNormalization()(xx)
        # ## dec(t)
        # for kk in range(4):
        #     xx = L.Convolution2DTranspose(filters=32, kernel_size=(3,1), strides=(2,1), padding='same', activation='gelu')(xx)
        #     xx = L.BatchNormalization()(xx)
        ## up-sample yields larger tensor because input
        ## not needed if pow2 input dim[0]
        # xx = L.CenterCrop(output_shape[0],1)(xx)
        ## FC-out & shape
        # xx = L.Dense(2, activation='linear')(xx)
        xx = L.Dense(128, activation='gelu')(xx)
        xx = L.Dropout(0.2)(xx)
        xx = L.BatchNormalization()(xx)
        xx = L.Dense(1, activation='linear')(xx)
        # xx = L.Reshape(output_shape)(xx)
        xx = xx[:,:,0,:]
        ## Output normalization
        # xx = L.Normalization()(xx) ## weird, doesn't really normalize well or at all ...
        # xx = normalize(xx,[1])
        model= Model([model_input], xx)
        print(model.summary())
        return model

    def encoder_decoder(xx, fx):
        """Shared encoder->decoder trunk used by the attention models.

        Downsamples axis 1 six times by 2, global-average-pools to a single
        embedding, then upsamples axis 1 seven times by 2.
        """
        ## enc(t)
        for kk in range(4):
            rx = L.AveragePooling2D(pool_size=(3,1), strides=(2,1), padding='same')(xx)  ## residual
            xx = L.Convolution2D(filters=32*fx, kernel_size=(3,1), strides=(2,1), padding='same', activation='gelu')(xx)
            xx = L.Dropout(0.2)(xx)
            if kk>0: xx = xx + rx  ## residual
            xx = L.BatchNormalization()(xx)
        ## enc(t,a)
        for kk in range(2):  ## inp=128
            rx = L.AveragePooling2D(pool_size=(3,1), strides=(2,1), padding='same')(xx)  ## residual
            xx = L.Convolution2D(filters=64*fx, kernel_size=(3,1), strides=(2,1), padding='same', activation='gelu')(xx)
            xx = L.Dropout(0.2)(xx)
            if kk>0: xx = xx + rx  ## residual
            xx = L.BatchNormalization()(xx)
        ## Pool
        # xx = L.Convolution2D(filters=256, kernel_size=(2,attention_heads), strides=(1,1), padding='valid', activation='gelu')(xx) ## inp=128
        xx = L.GlobalAveragePooling2D(keepdims=True)(xx)
        xx = L.Dense(256, activation='gelu')(xx)
        xx = L.BatchNormalization()(xx)
        print('embedding-shape>', xx.shape[1:], 'size>', np.prod(np.array(xx.shape[1:].as_list())) )
        for kk in range(7):  ## inp=128 out=128
            rx = L.UpSampling2D(size=(2,1))(xx)  ## up-residual
            xx = L.Convolution2DTranspose(filters=64*fx, kernel_size=(3,1), strides=(2,1), padding='same', activation='gelu')(xx)
            xx = L.Dropout(0.2)(xx)
            if kk>0: xx = xx + rx  ## residual
            xx = L.BatchNormalization()(xx)
        # xx = L.Dense(128, activation='gelu')(xx)
        # xx = L.Dropout(0.2)(xx)
        # xx = L.BatchNormalization()(xx)
        xx = L.Dense(1, activation='linear')(xx)
        xx = xx[:,:,0,:]
        return xx

    # def get_model_flat_attention(fx=2, attention_heads=8, scale=1.0): ## attentional model on flat A*B => 12*64=768
    def get_model_flat_attention(fx=2, attention_heads=48, scale=1.0):
        """Attention over the flattened antenna*bin axis (12*64=768), then the trunk."""
        input_shape = (inp_sz, 12, 64)
        model_input = L.Input(input_shape)
        ## flat 64*12 => multi-head attention
        xx = L.Reshape([inp_sz, 12*64, 1])(model_input)  ## (None, 128, 768, 1)
        ## attention-encoder
        for kk in range(3):
            xx = L.Convolution2D(filters=32, kernel_size=(16,1), strides=(2,1), padding='valid', activation='gelu')(xx)
            xx = L.BatchNormalization()(xx)
        xx = tf.reduce_mean(xx, 1)
        ## attention-projector
        xx = L.Dense(attention_heads, activation='linear')(xx)
        multi_head_attention_weights = L.Softmax(axis=1, name='attention_weights')(xx) * scale
        xx = tf.matmul( L.Reshape([inp_sz, 12*64])(model_input) , multi_head_attention_weights )
        xx = L.Reshape([inp_sz, attention_heads, 1])(xx)
        ## encoder-decoder trunk
        xx = encoder_decoder(xx, fx)
        ## out
        model= Model([model_input], xx)
        print(model.summary())
        return model

    def get_model_bin_attention_flat_ant(fx=2, attention_heads=8, scale=1.0):
        """Attention over the 64 bins; attended bins and antennas are then flattened together."""
        input_shape = (inp_sz, 12, 64)
        model_input = L.Input(input_shape)
        ## bin attention on 64 => multi-head attention
        xx = L.Permute([1,3,2])(model_input)  ## (None, 128, 64, 12)
        ## attention-encoder
        for kk in range(3):
            xx = L.Convolution2D(filters=32, kernel_size=(16,1), strides=(2,1), padding='valid', activation='gelu')(xx)
            xx = L.BatchNormalization()(xx)
        xx = tf.reduce_mean(xx, 1)
        ## attention-projector
        xx = L.Dense(attention_heads, activation='linear')(xx)
        multi_head_attention_weights = L.Softmax(axis=1, name='attention_weights')(xx) * scale
        xx = tf.matmul( L.Reshape([inp_sz*12,64])(model_input), multi_head_attention_weights )
        xx = L.Reshape([inp_sz, 12, attention_heads])(xx)
        xx = L.Permute([1,3,2])(xx)  ## (None, 128, 8, 12)
        ## (None, 128, 8, 12) => (None, 128, 8*12, 1) :: treat attn-bins and antennas the same way
        xx = L.Reshape([inp_sz, attention_heads*12, 1])(xx)  ## doesn't exploit intra-antenna structure - such as AOA/BeamForm
        ## encoder-decoder trunk
        xx = encoder_decoder(xx, fx)
        ## out
        model = Model([model_input], [xx])
        print(model.summary())
        return model

    def get_model_bin_attention(fx=2, attention_heads=8, scale=1.0):
        """Attention over the 64 bins after averaging the input over the antennas."""
        input_shape = (inp_sz, 12, 64)
        model_input = L.Input(input_shape)
        ## bin attention on 64 => multi-head attention
        xx = L.Permute([1,3,2])(model_input)  ## (None, 128, 64, 12)
        ##
        # reduced_input = L.Dense(1, activation='linear')(xx)
        reduced_input = tf.reduce_mean(xx, 3)
        reduced_input = L.Reshape([inp_sz, 64, 1])(reduced_input)
        xx = reduced_input
        ## attention-encoder
        for kk in range(3):
            xx = L.Convolution2D(filters=32, kernel_size=(16,1), strides=(2,1), padding='valid', activation='gelu')(xx)
            xx = L.BatchNormalization()(xx)
        xx = tf.reduce_mean(xx, 1)
        ## attention-projector
        xx = L.Dense(attention_heads, activation='linear')(xx)
        multi_head_attention_weights = L.Softmax(axis=1, name='attention_weights')(xx) * scale
        xx = tf.matmul( L.Reshape([inp_sz,64])(reduced_input), multi_head_attention_weights )
        xx = L.Reshape([inp_sz, attention_heads, 1])(xx)
        # xx = L.Permute([1,3,2])(xx) ## (None, 128, 8, 1)
        # ## (None, 128, 8, 1) => (None, 128, 8*12, 1) :: treat attn-bins and antennas the same way
        # xx = L.Reshape([inp_sz, attention_heads, 1])(xx) ## doesn't exploit intra-antenna structure - such as AOA/BeamForm
        ## encoder-decoder trunk
        xx = encoder_decoder(xx, fx)
        ## out
        model = Model([model_input], [xx])
        print(model.summary())
        return model

    def get_model():
        """Return the model variant currently selected for training."""
        # return get_model_flat_attention()
        return get_model_bin_attention_flat_ant()
        # return get_model_bin_attention()

    def train():
        """Build, fit and save the model at ``model_path``."""
        def scheduler(epoch, lr):
            """Learning-rate schedule; also publishes the epoch to ``epoch_var`` for iSyncLoss."""
            K.set_value(epoch_var, epoch)
            if epoch < 5:
                return 1e-4
            elif epoch < lr_drop_epoch:
                return 1e-3
            else:
                return 2*1e-4

        model = get_model()
        optimizer = tf.keras.optimizers.RMSprop(learning_rate=1e-4)
        # optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
        # model.compile(optimizer=optimizer, loss='mse', metrics=['mse', 'mae'])
        print(model)
        model.compile(optimizer=optimizer, loss=iSyncLoss)
        # model.compile(optimizer=optimizer, loss=dynRangeLoss)
        # NOTE(review): checkpoint_cb is built but not passed to model.fit() below,
        # so no intermediate checkpoints are written -- confirm whether intended.
        checkpoint_cb = ModelCheckpoint(filepath=model_path,
                                        save_weights_only=False,
                                        save_freq='epoch',
                                        # monitor='val_mae',
                                        save_best_only=True,
                                        verbose=1)
        train_generator, val_generator = data_pipe(get_fnames, get_data, augment_dataset, batch_size, cache=cache)
        tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=join(model_path,'tb'), histogram_freq=1)
        scheduler_cb = tf.keras.callbacks.LearningRateScheduler(scheduler)
        model.fit(train_generator,
                  validation_data=val_generator,
                  epochs=train_epochs,
                  callbacks=[tensorboard_cb, scheduler_cb])
        model.save(model_path)

    def inference():
        """Load the saved model, evaluate it on the validation set and write a PDF report."""
        metrics_obj = inf_metrics.Metrics()

        def inference_ds(ds, fname, max_samples=None):
            """Run the model over ``ds`` (at most ``max_samples`` batches) and report metrics."""
            def add_figure(kk, ii, curr_metrics):
                """Add a diagnostic page for sample ``ii`` of batch ``kk`` (first sample, every 4th batch)."""
                add_flag, high_error = False, False
                # was `kk%4`, which selected every batch EXCEPT each 4th one
                if kk % 4 == 0 and ii == 0:  ## add first sample of batch every 4 batches
                    add_flag = True
                #elif abs(pre_hr-ppg_hr)>20: ## or if large error
                #    high_error = True
                if not add_flag:
                    return
                ##
                sign = -1 if idx[ii] < n_shifts else 1
                plt.figure()
                plt.subplot(3,2,1)
                plt.plot(sign * tar_sync[ii, :], label='Ground-Truth (PPG/HR)')
                plt.plot(sign * pre[ii, :], label='Prediction (VitaNet/HR)')
                sd_ppg = sum(curr_metrics['sd_ppg'], [])  # flatten list of lists
                sd_pre = sum(curr_metrics['sd_vita'], [])
                tar_y = [sign * tar_sync[ii, sd_idx] for sd_idx in sd_ppg]
                pre_y = [sign * pre[ii, sd_idx] for sd_idx in sd_pre]
                plt.plot(sd_ppg, tar_y, 'r*')
                plt.plot(sd_pre, pre_y, 'g*')
                plt.title(f'isync: {offset_index[ii]} : {idx[ii]}')
                plt.legend(fontsize=7)
                plt.subplot(3,2,3)
                plt.plot(sign * (tar_sync[ii, 1:] - tar_sync[ii, :-1]), label='Ground-Truth (PPG/HR)')
                plt.plot(sign * (pre[ii, 1:] - pre[ii, :-1]), label='Prediction (VitaNet/HR)')
                plt.ylabel('First Derivative')
                plt.legend(fontsize=7)
                plt.subplot(3,2,5)
                ppg_diff = tar_sync[ii, 1:] - tar_sync[ii, :-1]
                pre_diff = pre[ii, 1:] - pre[ii, :-1]
                plt.plot(sign * (ppg_diff[1:] - ppg_diff[:-1]), label='Ground-Truth (PPG/HR)')
                plt.plot(sign * (pre_diff[1:] - pre_diff[:-1]), label='Prediction (VitaNet/HR)')
                plt.ylabel('Second Derivative')
                print(curr_metrics)
                plt.xlabel(f'VN-HR-Err: {curr_metrics["hr_vita"]}, Rad-HR-Err: {curr_metrics["hr_rad_multi"]}, '
                           f'S-D: {curr_metrics["sys_dia_diff_error"]}')
                #obj = plt.xlabel(f'HRs: P: {ppg_hr}, VN: {pre_hr}, Rad: {rad_hr} {"HIGH-ERROR!" if high_error else ""}')
                #if high_error: plt.setp(obj, color='r')
                plt.legend(fontsize=7)
                ## Radar plots
                max_bin = np.argmax(vspr[ii,:])
                rad_br_sig = np.mean(inputs[ii, :, :, max_bin], axis=1)
                rad_hr_sig = np.mean(inputs[ii, :, :, max_bin], axis=1)
                plt.subplot(3,2,2)
                plt.plot(rad_br_sig, label='Radar BR')
                plt.gca().yaxis.tick_right()
                plt.legend(fontsize=7)
                plt.subplot(3,2,4)
                plt.plot(atw[0], label='Attention')
                plt.gca().yaxis.tick_right()
                plt.legend(fontsize=7)
                plt.subplot(3,2,6)
                plt.plot(np.abs(np.fft.fft(rad_hr_sig))[:30], label='Rad HR FFT')
                plt.gca().yaxis.tick_right()
                plt.legend(fontsize=7)
                ##
                pdf.savefig()
                plt.close()

            ##
            chake_dir(f'{model_path}/inference')
            with PdfPages(f'{model_path}/inference/{fname}.pdf') as pdf:
                # for kk, element in enumerate(ds.as_numpy_iterator()):
                for kk, element in enumerate(ds):
                    if max_samples and kk==max_samples:
                        break
                    # print(f'inf> {fname} :: {kk}', end='\r')
                    inputs, tar, vspr, id_ = element
                    pre, atw = model_with_attn.predict(inputs)
                    pre = tf.convert_to_tensor(pre)
                    loss, tar_sync, offset_index, idx = iSyncLoss(tar, pre, return_crops=True)
                    # Number of sliding positions of the prediction inside the target;
                    # idx < n_shifts selects an un-negated window, idx >= n_shifts a
                    # negated one (was hard-coded as 13 = 140 - 128 + 1).
                    n_shifts = tar.shape[1] - pre.shape[1] + 1
                    print('tar_sync>', tar_sync.shape, np.mean(tar_sync, 1)[0:4], np.std(tar_sync, 1)[0:4], np.amin(tar_sync, 1)[0:4], np.amax(tar_sync, 1)[0:4])
                    print('pre>', pre.shape, np.mean(pre, 1)[0:4], np.std(pre, 1)[0:4], np.amin(pre, 1)[0:4], np.amax(pre, 1)[0:4])
                    # Calculate metrics
                    for ii in range(pre.shape[0]):
                        # Flip if necessary
                        if idx[ii] < n_shifts:
                            curr_metrics = metrics_obj.evaluate(-tar_sync[ii, :], -pre[ii, :], inputs[ii, :, :, :], vspr[ii, :], atw[ii, :])
                        else:
                            curr_metrics = metrics_obj.evaluate(tar_sync[ii, :], pre[ii, :], inputs[ii, :, :, :], vspr[ii, :], atw[ii, :])
                        #add_figure(kk, ii, curr_metrics)
                ## summary page with the mean of every accumulated error metric
                fig = plt.figure()
                ax = fig.add_subplot(1,1,1)
                top = 0.95
                for key in metrics_obj.errs:
                    mean_err = round(np.mean(metrics_obj.errs[key]), 2)
                    print(key, mean_err, len(metrics_obj.errs[key]), metrics_obj.errs[key][-5:])
                    plt.text(0.1, top, f'{key}: {mean_err}',
                             horizontalalignment='left',
                             verticalalignment='center',
                             transform=ax.transAxes,
                             fontsize=20)
                    top -= 0.1
                plt.axis('off')
                pdf.savefig()
                plt.close()

        ##
        model = load_model(model_path, custom_objects={"iSyncLoss": iSyncLoss})
        model_with_attn=tf.keras.models.Model(inputs=model.input,outputs=[model.output, model.get_layer('attention_weights').output])
        # print(model.summary())
        train_generator, val_generator = data_pipe(get_fnames, get_data, augment_dataset,
                                                   retain_all=True, randomize_train=False, cache=cache)
        #inference_ds(train_generator, 'train', max_samples=200)
        inference_ds(val_generator, 'val', max_samples=200)

    def debug():
        """Iterate the datasets twice, printing checksums and sample counts."""
        def plots(sample):
            """Plot two radar channels of the first sample and an FFT (currently not called)."""
            # print('plot-range>', tf.reduce_max(sample[0], [0,1]), tf.reduce_min(sample[0], [0,1]), sample.shape )
            rad_br_sig = sample[0, :, 0, 0]
            rad_hr_sig = sample[0, :, 0, 1]
            plt.subplot(3,1,1)
            plt.plot(rad_br_sig, label='Radar BR')
            # plt.gca().yaxis.tick_right()
            # plt.xticks([])
            plt.legend(fontsize=7)
            plt.subplot(3,1,2)
            plt.plot(rad_hr_sig, label='Radar HR')
            # plt.gca().yaxis.tick_right()
            # plt.xticks([])
            plt.legend(fontsize=7)
            plt.subplot(3,1,3)
            plt.plot(np.abs(np.fft.fft(rad_hr_sig))[:50], label='Rad HR FFT')
            # plt.gca().yaxis.tick_right()
            pdf.savefig()
            plt.close()

        train_generator, val_generator = data_pipe(get_fnames, get_data, augment_dataset,
                                                   randomize_train=False, batch_size=1, cache=cache)
                                                   # randomize_train=True, batch_size=1, cache=cache)
        with PdfPages('debug.pdf') as pdf:
            for kk,sample in enumerate(train_generator):
                print(f'chksum {kk} {np.mean(sample[0])}')
                # if kk%100 == 0:
                #     plots(sample[0])
                # kk +=1
                # if kk==10: return
                # print(f'rad shape: >> {sample[0].shape}')
                # print(f'ppg shape: >> {sample[1].shape}')
            print('EPOCH')
            # count samples explicitly: the enumerate index is count-1 and is
            # unbound when the dataset is empty
            n_train = 0
            for kk,sample in enumerate(train_generator):
                print(f'chksum {kk} {np.mean(sample[0])}')
                n_train += 1
            print('train-samples>', n_train)
            n_val = 0
            for sample in val_generator:
                n_val += 1
            print('val-samples>', n_val)

    def stats():
        """Print shape, mean and std of the inputs and targets of both datasets."""
        train_generator, val_generator = data_pipe(get_fnames, get_data, augment_dataset, cache=cache)

        def cal_stats(gen):
            ## only with small dataset -- should accumulate with suff-stats
            xx, yy = [], []
            for train_sample in gen:
                xx.append( train_sample[0].numpy() )
                yy.append( train_sample[1].numpy() )
            xx, yy = np.concatenate(xx,0), np.concatenate(yy,0)
            print(f'rad shape: >> {xx.shape}', np.mean(xx), np.std(xx))
            print(f'ppg shape: >> {yy.shape}', np.mean(yy), np.std(yy))

        cal_stats(train_generator)
        cal_stats(val_generator)

    ## main()
    ## set shapes by augment() and isync()
    ## center-crop to get always the same shape for the target
    ## smallest window augment() can produce = largest admissible target size
    max_tar_sz = tfr_win_sz-stretch_sz-jitter_sz
    isync_sz = tar_sz-inp_sz
    print(f'sizes> tfr {tfr_win_sz} max_tar_sz {max_tar_sz} tar_sz {tar_sz} inp_sz {inp_sz} isync_sz {isync_sz}')
    print(f'augment> stretch +-{stretch_sz/fps}s jitter +-{jitter_sz/fps}s')
    print(f'iSyncLoss> time +-{(tar_sz-inp_sz)/2/fps}s')
    if tar_sz>max_tar_sz:
        # fail with a non-zero exit status instead of printing and returning 0
        raise click.UsageError(
            f'target size {tar_sz} is larger than max_tar_sz {max_tar_sz} allowed by augment(), cannot proceed ...')
    ##
    ## reproducibility
    tf.random.set_seed(42)
    random.seed(42)
    np.random.seed(42)
    # current epoch, written by the LR scheduler and read by iSyncLoss for annealing
    epoch_var = tf.Variable(0, trainable=False)
    print('cmd>', cmd)
    if cmd == 'train':
        train()
    elif cmd == 'inference':
        inference()
    elif cmd == 'debug':
        debug()
    elif cmd == 'stats':
        stats()
    else:
        raise click.UsageError(f'unknown --cmd {cmd!r}; expected train|inference|debug|stats')


if __name__ == '__main__':
    VitaNet()