# vitanet / code / vitanet.py
import os
import click
import random
import datetime

import sys
sys.path.insert(0, '../common')
from tools import *
from dataio import data_pipe
from utils import inf_metrics 

import numpy as np
import tensorflow as tf

from os import listdir
from os.path import join, dirname
from glob import glob
from pathlib import Path

from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import tensorflow.keras.layers as L
from tensorflow.keras import Model
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K


# Command-line entry point. click parses the options below and passes every one of
# them to VitaNet as a keyword argument, so when run as a script the click defaults
# (not the Python defaults in the signature) are the values that apply.
@click.command()
# --cmd: action dispatched at the end of VitaNet ('train', 'inference', 'debug', 'stats').
@click.option("--cmd", type=str, default='train', required=False)
# NOTE(review): several options set both required=True and a default; click fills in
# the default before its required check, so they can still be omitted -- confirm intent.
# --train-split-ratio: fraction used by the ratio-based split modes of get_fnames.
@click.option("--train-split-ratio", type=float, default=0.90, required=True)
@click.option("--model-path", type=str, required=True)
@click.option("--tfr-path", type=str, required=True)
@click.option('--cache', type=str, default='/data1/homes.local/superluca/cache/vitanet.dp', required=False, help='file_path|ram|disabled')
## PARM> train-epochs
@click.option("--batch-size", type=int, default=32, required=True)
@click.option("--train-epochs", type=int, default=300, required=True)
# --lr-drop-epoch: epoch from which train()'s scheduler uses 2e-4 instead of 1e-3.
@click.option("--lr-drop-epoch", type=int, default=200, required=True)
# --isync-anneal-epoch: iSyncLoss averages its best-k window errors, with k going from
# 10 down to 1, one step every isync_anneal_epoch//10 epochs.
@click.option("--isync-anneal-epoch", type=int, default=100, required=True)
# earlier size settings:
# @click.option("--tar-sz", type=int, default=192, required=True)
# @click.option("--inp-sz", type=int, default=256, required=True)
# @click.option("--tar-sz", type=int, default=128, required=True)
# @click.option("--inp-sz", type=int, default=128, required=True)
# --tar-sz: target crop length in samples (longer than the input so iSyncLoss can slide it);
#           it must not exceed tfr_win_sz - stretch_sz - jitter_sz (checked at the end of VitaNet).
# --inp-sz: input crop length in samples.
@click.option("--tar-sz", type=int, default=140, required=True)
@click.option("--inp-sz", type=int, default=128, required=True)


# fps, tfr_win_sz, stretch_sz and jitter_sz have no CLI option, so only these defaults apply.
# fps converts sample counts to seconds in the size printouts; tfr_win_sz is the window
# length (in samples) that augment_dataset stretches/crops; stretch_sz and jitter_sz bound
# the +- time-stretch and crop-shift augmentation.
def VitaNet(cmd: str,  model_path: str = 'model', tfr_path: str = 'tfrs', cache = None,
                train_split_ratio: float = 0.80,
                batch_size: int = 32,
                train_epochs: int = 300,
                lr_drop_epoch: int = 200,
                isync_anneal_epoch: int = 100,
                fps: int = 20,
                tfr_win_sz: int = 400,
                tar_sz: int = 320,
                inp_sz : int = 256,
                stretch_sz: int = 80,
                # jitter_sz: int = 40
                # jitter_sz: int = 80
                jitter_sz: int = 180
            ):

    def get_fnames(split_sessions='train_val_dirs'):
        """Collect the training and validation ``.tfr`` file lists under ``tfr_path``.

        Args:
            split_sessions: split strategy.
                * ``'train_val_dirs'`` (default): ``<tfr_path>/train/*/*.tfr`` and
                  ``<tfr_path>/val/*/*.tfr``.
                * ``'flat'`` or ``'sessions'``: every entry of ``<tfr_path>/*`` is a
                  session; the first ``train_split_ratio`` fraction of sessions is
                  used for training and the rest for validation.
                * ``'files'``: all ``*.tfr`` files found recursively under
                  ``tfr_path`` (sorted) are split by ``train_split_ratio``.

        Returns:
            ``(train_fnames, val_fnames, layout_fn)``. ``train_fnames`` is shuffled
            with the ``random`` module; ``layout_fn`` is ``layout.pkl`` in the
            directory of the first (shuffled) training file.

        Raises:
            ValueError: unknown ``split_sessions`` or no training files found.
        """
        if split_sessions in ('flat', 'sessions'):
            ## NOTE: 'flat' has always selected this per-session split (a second 'flat'
            ## branch below it was unreachable); kept as-is for backward compatibility.
            ## NOTE(review): glob order is filesystem-dependent, so which sessions land
            ## in val is not guaranteed to be reproducible across machines.
            sessions = [str(kk) for kk in glob(tfr_path + '/*')]
            train_test_split = int(len(sessions) * train_split_ratio)
            train_sessions = sessions[:train_test_split]
            val_sessions = sessions[train_test_split:]
            print('sessions>', 'train>', train_sessions, '\nval>', val_sessions)
            train_fnames = [kk for ss in train_sessions for kk in glob(ss + '/*.tfr')]
            val_fnames = [kk for ss in val_sessions for kk in glob(ss + '/*.tfr')]
        elif split_sessions == 'files':
            ## file-level split (previously unreachable under the duplicated 'flat' key)
            fnames = sorted(str(kk) for kk in Path(tfr_path).rglob('*.tfr'))
            train_test_split = int(len(fnames) * train_split_ratio)
            train_fnames = fnames[:train_test_split]
            val_fnames = fnames[train_test_split:]
        elif split_sessions == 'train_val_dirs':
            ## sorted so that the seeded shuffle below is reproducible
            train_fnames = sorted(str(kk) for kk in glob(tfr_path + '/train/*/*.tfr'))
            val_fnames = sorted(str(kk) for kk in glob(tfr_path + '/val/*/*.tfr'))
        else:
            raise ValueError(f"get_fnames> unknown split_sessions {split_sessions!r}; "
                             f"expected 'train_val_dirs', 'flat', 'sessions' or 'files'")
        if not train_fnames:
            raise ValueError(f'get_fnames> no training .tfr files found under {tfr_path!r} '
                             f'(split_sessions={split_sessions!r})')
        random.shuffle(train_fnames)
        layout_fn = join(dirname(train_fnames[0]), 'layout.pkl')
        print('train>', train_fnames, '\n', 'val>', val_fnames)
        return train_fnames, val_fnames, layout_fn

    def normalize(xx, axis):
        """Min/max-normalize ``xx`` to roughly [-1, +1] along ``axis``.

        The midpoint of the [min, max] range along ``axis`` is mapped to 0 and the
        range is scaled to width 2; ``1e-4`` keeps flat (zero-range) slices finite.
        (An N(0,1) variant -- subtract mean, divide by std + 1e-2 -- was tried
        before.)
        """
        hi = tf.reduce_max(xx, axis=axis, keepdims=True)
        lo = tf.reduce_min(xx, axis=axis, keepdims=True)
        span = hi - lo
        center = (hi + lo) / 2
        return 2 * (xx - center) / (span + 1e-4)

    ## get the right data-keys
    def get_data(rec):
        """Select model input, target, validity flag and metadata from one parsed record.

        Args:
            rec: dict of tensors. Keys read: 'phase', 'ppg', 'vs_present_rad',
                'vs_present_ppg', 'rec_id'. An older layout note listed 'phase' as
                (400, 12, 64, 2) and 'ppg' as (400, 2, 2). NOTE(review): that note
                named a single 'vs_present' key, so it is stale -- confirm the shapes
                against the current TFRecord writer.

        Returns:
            ``(inp, tar, valid, vspr, id_)`` where
            inp   -- ``rec['phase'][..., 1]`` (the "H" channel, last axis dropped;
                     (400, 12, 64) for the layout above);
            tar   -- ``rec['ppg'][:, 0, 0:2]``;
            valid -- bool tensor, True when ``max(vspr) > 1.0`` and
                     ``rec['vs_present_ppg'][0][1] > 1.0``;
            vspr  -- ``rec['vs_present_rad'][:, 0]``;
            id_   -- ``rec['rec_id']``.
        """
        ## Alternatives tried: raw phase alone, raw phase + raw magnitude -- both
        ## worked but were no better than the filtered phase used here.
        inp = rec['phase'][..., 1]          ## H channel, all bins ("fullcube" model)
        tar = rec['ppg'][:, 0, 0:2]         ## R,H phases - of ir-ppg-channel
        vspr = rec['vs_present_rad'][:, 0]  ## BR - 64 bins
        vspp = rec['vs_present_ppg'][0][1]  ## Red channel, HR
        id_ = rec['rec_id']
        ## the comparisons already produce bool tensors, so no tf.cond is needed
        valid = tf.math.logical_and(tf.math.reduce_max(vspr) > 1.0, vspp > 1.0)
        return (inp, tar, valid, vspr, id_)

    ## augmentation
    def augment_dataset(ds, randomize=True):
        """Map augmentation, center-cropping and normalization over ``ds``.

        Each element of ``ds`` is ``(inp, tar, vspr, id_)``. ``inp`` is time-major
        with 12 antennas on axis 1 (the transposes below assume rank 3, e.g.
        ``(tfr_win_sz, 12, 64)`` as produced by ``get_data``); ``tar`` is
        ``(tfr_win_sz, 2)``.

        With ``randomize=True`` each element gets, in order: a random antenna
        permutation, a random time stretch to ``tfr_win_sz`` +- ``stretch_sz``
        samples, a random shift ("jitter") that drops up to ``jitter_sz`` samples
        from the front or the back, and a sign flip of ``inp`` (180deg phase
        rotation) with probability 1/2. In both modes ``tar`` is then
        center-cropped to ``tar_sz`` samples, ``inp`` to ``inp_sz`` samples, and
        ``inp`` is min/max normalized over (time, antenna).

        Returns:
            The mapped dataset yielding ``(inp, tar, vspr, id_)`` with static shapes
            ``(inp_sz, 12, 64)`` and ``(tar_sz, 2)``.
        """
        def augment(inp, tar, vspr, id_):
            if randomize:
                ## permute antenna order: tf.random.shuffle acts on axis 0
                inp = tf.transpose(inp, perm=[1, 0, 2])  ## (12, T, 64)
                inp = tf.random.shuffle(inp)
                inp = tf.transpose(inp, perm=[1, 0, 2])  ## (T, 12, 64)
                rand = tf.random.uniform([3])
                ## stretch: resample the time axis to tfr_win_sz +- stretch_sz samples
                stretch = tfr_win_sz + (2*rand[0]-1)*stretch_sz
                stretch = tf.cast(stretch, tf.int32)
                inp = tf.image.resize(inp, [stretch, 12])
                tar = tf.expand_dims(tar, 2)              ## (T, 2, 1) "image"
                tar = tf.image.resize(tar, [stretch, 2])
                tar = tf.squeeze(tar, axis=-1)            ## drop only the channel axis
                ## jitter: drop |jitter| samples from the front (>= 0) or the back (< 0)
                jitter = (2*rand[1]-1)*jitter_sz
                jitter = tf.cast(jitter, tf.int32)
                inp = tf.cond(jitter >= 0, lambda: inp[jitter:], lambda: inp[:jitter])
                tar = tf.cond(jitter >= 0, lambda: tar[jitter:], lambda: tar[:jitter])
                ## 180deg input phase rotation
                inp = tf.cond(rand[2] < 0.5, lambda: inp, lambda: -inp)
                new_size = stretch - tf.math.abs(jitter)
            else:
                new_size = tfr_win_sz
            ## center-crop the target to tar_sz (longer than the input) so iSyncLoss can
            ## slide it; tar_sz must be <= tfr_win_sz - stretch_sz - jitter_sz so the crop
            ## always fits the smallest augmented window.
            tar = tar[(new_size-tar_sz)//2 : (new_size+tar_sz)//2]
            ## center-crop the input to inp_sz
            inp = inp[(new_size-inp_sz)//2 : (new_size+inp_sz)//2]
            ## set static shapes explicitly: the dynamic slices above leave the time
            ## dimension as None, which breaks downstream shape inference
            tar = tf.reshape(tar, [tar_sz, 2])
            inp = tf.reshape(inp, [inp_sz, 12, 64])  ## fullcube-model
            ## dynamic-range normalization to ~[-1,+1] over (time, antenna), separately
            ## per bin: low-range antennas get a lower weight (a bit better than
            ## normalizing each antenna on its own)
            inp = normalize(inp, [0, 1])
            return inp, tar, vspr, id_
        return ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)


    def iSyncLoss(tar, pre, return_crops=False):
        """Synchronization-independent loss between a longer target and a prediction.

        Every length-``pre.shape[1]`` window (step 1) of the target's H channel
        ``tar[..., 1:2]`` is compared with ``pre``, both as-is and negated
        ("Sheen180", to absorb the 180deg phase ambiguity). Each window is min/max
        normalized over time first. A window's error is the summed L1 error of the
        signal plus that of its first and second time differences. The loss is the
        mean of the ``kk`` smallest window errors, where ``kk`` anneals from 10 down
        to 1 as ``epoch_var`` grows (one step every ``isync_anneal_epoch // 10``
        epochs, at least 1).

        Args:
            tar: (B, T_tar, 2) target with T_tar >= T_pre.
            pre: (B, T_pre, 1) prediction with a static T_pre.
            return_crops: also return the best-aligned target windows.

        Returns:
            ``loss`` of shape (B,). With ``return_crops=True`` the tuple
            ``(loss, tar_sync, offset, idx)``, where ``idx`` (B,) is the best window
            index in ``[0, 2 * n_shifts)`` (first half: +target, second half:
            -target), ``tar_sync`` (B, T_pre, 1) is that normalized window, and
            ``offset`` is its time shift relative to the centered window.
        """
        def f_derivative(x):
            ## first difference along the time axis (axis 2 of the framed tensors)
            return x[:, :, 1:] - x[:, :, :-1]

        def s_derivative(x):
            ## second difference along the time axis
            return f_derivative(f_derivative(x))

        def peak_mask(xx):
            ## currently unused -- kept for the peak-weighted loss experiment
            pp = tf.constant([[0,0],[0,0],[2,2]]) # padding to recover same input shape
            yy = tf.signal.frame(xx, frame_length=5, frame_step=1, axis=2) # slidewin of size 5
            ## peak/max(min) if - center elm (idx=2) is larger(smaller) than all other elm of win-5
            pmax = tf.gather(yy,[2],axis=3) > tf.reduce_max(tf.gather(yy,[0,1,3,4],axis=3),axis=3,keepdims=True)
            pmax = tf.pad(tf.squeeze(pmax),pp)
            pmin = tf.gather(yy,[2],axis=3) < tf.reduce_min(tf.gather(yy,[0,1,3,4],axis=3),axis=3,keepdims=True)
            pmin = tf.pad(tf.squeeze(pmin),pp)
            pmask = tf.cast( tf.math.logical_or(pmax, pmin), tf.float32)
            pmask += 1e-2 ## small loss for non-peaks (1% of peak loss)
            return tf.expand_dims(pmax,-1), tf.expand_dims(pmin,-1), tf.expand_dims(pmask,-1)

        pre_sz = pre.shape[1]  # crop and align tar to the prediction output size
        tar = tar[..., 1:2]    # H phase ~ select
        ## all length-pre_sz windows of the target => (B, n_shifts, pre_sz, 1)
        tar_slide = tf.signal.frame(tar, frame_length=pre_sz, frame_step=1, axis=1)
        ## Sheen180: append the negated windows => (B, 2*n_shifts, pre_sz, 1)
        tar_slide = tf.concat([tar_slide, -tar_slide], 1)
        ## dynamic-range stats per window (axis 2 = time)
        tar_slide = normalize(tar_slide, [2])
        pre_b = tf.expand_dims(pre, 1)  ## (B, 1, pre_sz, 1), broadcasts over windows
        err = tf.math.abs(tar_slide - pre_b)  ## L1-loss
        err_d1 = tf.math.abs(f_derivative(tar_slide) - f_derivative(pre_b))
        err_d2 = tf.math.abs(s_derivative(tar_slide) - s_derivative(pre_b))
        ## peak weighting (disabled): err *= peak_mask(tar_slide)[2], etc.
        err_slide = (tf.reduce_sum(err, [2, 3]) + tf.reduce_sum(err_d1, [2, 3])
                     + tf.reduce_sum(err_d2, [2, 3]))  ## (B, 2*n_shifts)
        ##DEBUG> isync-shapes> [32, 192, 1] [32, 128, 1] [32, 130, 128, 1] [32, 130, 128, 1] [32, 130]
        print('isync-shapes>', tar.shape.as_list(), pre.shape.as_list(), tar_slide.shape.as_list(),
                               err.shape.as_list(), err_slide.shape.as_list())
        ## PARM> best-k losses with annealing; the step is clamped to >= 1 so that
        ## isync_anneal_epoch < 10 does not divide by zero
        ##TODO>>> add some kk noise by random select multiple alignments even after annealed!
        anneal_step = max(1, isync_anneal_epoch // 10)
        kk = tf.math.maximum(1, 10 - epoch_var // anneal_step)
        loss_sort = tf.sort(err_slide, -1)[:, :kk]
        sync_independent_loss = tf.reduce_mean(loss_sort, -1)
        if not return_crops:
            return sync_independent_loss
        idx = tf.math.argmin(err_slide, -1)
        ## best window per sample: tar_slide[b, idx[b]] => (B, pre_sz, 1)
        tar_sync = tf.gather(tar_slide, idx, axis=1, batch_dims=1)
        ## windows [0, n_shifts) are +target and [n_shifts, 2*n_shifts) are -target,
        ## so the time shift is idx % n_shifts, centered on n_shifts // 2
        n_shifts = err_slide.shape[1] // 2
        offset = idx % n_shifts - n_shifts // 2
        return sync_independent_loss, tar_sync, offset, idx


    def dynRangeLoss(tar, pre):
        """Mean L1 error between ``pre`` and the dynamic range of the target's H channel.

        The dynamic range is ``max - min`` of ``tar[..., 1:2]`` over axis 1
        (time, assuming ``tar`` is (B, T, 2) -- confirm); ``pre`` is broadcast
        against that result.
        """
        h_channel = tar[..., 1:2]
        true_range = tf.reduce_max(h_channel, axis=[1]) - tf.reduce_min(h_channel, axis=[1])
        abs_err = tf.math.abs(true_range - pre)  ## L1-loss
        return tf.reduce_mean(abs_err)


    def get_model_base(fx=2): ## encoder->decoder model
        """Build the plain convolutional encoder->decoder model (no attention).

        Not selected by get_model(). Its input is ``(inp_sz, 12, 1)``, whereas
        ``augment_dataset`` currently produces ``(inp_sz, 12, 64)``.
        NOTE(review): it needs an input reduction before use with the current
        pipeline.

        The stage counts are tuned for ``inp_sz=128``:
        * four (3,1)/stride-2 stages and two (3,3)/stride-2 stages reduce
          (128, 12) to (2, 3);
        * a (2,3) valid conv then gives a (1, 1, 256) embedding;
        * seven transposed-conv stages upsample time to 128.
        The commented ``inp=256`` lines are the 256-sample variant.

        Args:
            fx: channel-width multiplier.

        Returns:
            Uncompiled ``tf.keras.Model`` mapping ``(inp_sz, 12, 1)`` to ``(T_out, 1)``.
        """
        input_shape = (inp_sz, 12, 1)

        model_input = L.Input(input_shape)
        xx = model_input

        ## enc(t): stride 2 along time only
        for kk in range(4):
            rx = L.AveragePooling2D(pool_size=(3,1), strides=(2,1), padding='same')(xx) ## residual
            xx = L.Convolution2D(filters=32*fx, kernel_size=(3,1), strides=(2,1), padding='same', activation='gelu')(xx)
            xx = L.Dropout(0.2)(xx)
            if kk>0: xx = xx + rx ## residual (skipped on the first layer, where the channel count changes)
            xx = L.BatchNormalization()(xx)
        ## enc(t,a): stride 2 along time and antennas
        # for kk in range(3): ## inp=256
        for kk in range(2): ## inp=128
            rx = L.AveragePooling2D(pool_size=(3,3), strides=(2,2), padding='same')(xx) ## residual
            xx = L.Convolution2D(filters=64*fx, kernel_size=(3,3), strides=(2,2), padding='same', activation='gelu')(xx)
            xx = L.Dropout(0.2)(xx)
            if kk>0: xx = xx + rx ## residual
            xx = L.BatchNormalization()(xx)
        ## Embedding: 256 channels at a single (time, antenna) position
        # xx = L.Convolution2D(filters=256, kernel_size=(2,2), strides=(1,1), padding='valid', activation='gelu')(xx) ## inp=256
        xx = L.Convolution2D(filters=256, kernel_size=(2,3), strides=(1,1), padding='valid', activation='gelu')(xx)  ## inp=128
        xx = L.BatchNormalization()(xx)
        print('embedding-shape>', xx.shape[1:], 'size>', np.prod(np.array(xx.shape[1:].as_list())) )

        ## dec(t,a) --> but don't upsample a(ntennas)
        # for kk in range(8): ## inp=256 out=256
        for kk in range(7): ## out=128
            rx = L.UpSampling2D(size=(2,1))(xx) ## up-residual
            xx = L.Convolution2DTranspose(filters=64*fx, kernel_size=(3,1), strides=(2,1), padding='same', activation='gelu')(xx)
            xx = L.Dropout(0.2)(xx)
            if kk>0: xx = xx + rx ## residual
            xx = L.BatchNormalization()(xx)
        ## FC-out & shape
        xx = L.Dense(128, activation='gelu')(xx)
        xx = L.Dropout(0.2)(xx)
        xx = L.BatchNormalization()(xx)
        xx = L.Dense(1, activation='linear')(xx)
        xx = xx[:,:,0,:]  ## drop the size-1 antenna axis => (T_out, 1)

        model = Model([model_input], xx)
        model.summary()  ## summary() prints by itself and returns None
        return model


    def encoder_decoder(xx, fx):
        """Shared conv encoder -> pooled embedding -> conv decoder trunk.

        Four 32*fx-filter stages and then two 64*fx-filter stages each halve the
        time axis (axis 1) with a strided (3,1) conv; axis 2 is not strided. A
        global average pool plus Dense(256) then gives a (1, 1, 256) embedding.
        Seven transposed-conv stages double the time axis, so the output always
        has 2**7 = 128 time steps regardless of the input length. In every stage
        group, all layers except the first add a pooled / upsampled residual.

        Args:
            xx: rank-4 Keras tensor ``(batch, time, heads, channels)``.
            fx: channel-width multiplier.

        Returns:
            Keras tensor ``(batch, 128, 1)``.
        """
        def down_stage(tt, n_filters, n_layers):
            ## layer creation order (pool, conv, dropout, add, BN) matches the original
            for layer_no in range(n_layers):
                skip = L.AveragePooling2D(pool_size=(3,1), strides=(2,1), padding='same')(tt)
                tt = L.Convolution2D(filters=n_filters, kernel_size=(3,1), strides=(2,1),
                                     padding='same', activation='gelu')(tt)
                tt = L.Dropout(0.2)(tt)
                if layer_no > 0:
                    tt = tt + skip
                tt = L.BatchNormalization()(tt)
            return tt

        def up_stage(tt, n_filters, n_layers):
            for layer_no in range(n_layers):
                skip = L.UpSampling2D(size=(2,1))(tt)
                tt = L.Convolution2DTranspose(filters=n_filters, kernel_size=(3,1), strides=(2,1),
                                              padding='same', activation='gelu')(tt)
                tt = L.Dropout(0.2)(tt)
                if layer_no > 0:
                    tt = tt + skip
                tt = L.BatchNormalization()(tt)
            return tt

        xx = down_stage(xx, 32*fx, 4)   ## enc(t)
        xx = down_stage(xx, 64*fx, 2)   ## enc(t), wider (inp=128)
        ## Pool to a single embedding vector
        xx = L.GlobalAveragePooling2D(keepdims=True)(xx)
        xx = L.Dense(256, activation='gelu')(xx)
        xx = L.BatchNormalization()(xx)
        print('embedding-shape>', xx.shape[1:], 'size>', np.prod(np.array(xx.shape[1:].as_list())) )
        xx = up_stage(xx, 64*fx, 7)     ## dec(t): 1 -> 128 time steps
        xx = L.Dense(1, activation='linear')(xx)
        return xx[:, :, 0, :]


    ## earlier default: attention_heads=8
    def get_model_flat_attention(fx=2, attention_heads=48, scale=1.0): ## attentional model on flat A*B => 12*64=768
        """Build the attention model over all 12*64 = 768 antenna-bin signals.

        A small conv encoder over time, averaged over time, summarizes each of the
        768 flattened signals. A Dense projector plus a softmax over those 768
        positions gives ``attention_heads`` weight vectors (each summing to
        ``scale``). The input is projected onto each head and the resulting
        ``attention_heads`` series feed ``encoder_decoder``.

        Args:
            fx: channel-width multiplier of the encoder-decoder trunk.
            attention_heads: number of attention heads.
            scale: multiplier applied to the softmax weights.

        Returns:
            Uncompiled model ``(inp_sz, 12, 64) -> (128, 1)``. Its softmax layer is
            named 'attention_weights', which inference() reads.
        """
        input_shape = (inp_sz, 12, 64)
        model_input = L.Input(input_shape)
        ## flatten antennas x bins into one axis
        xx = L.Reshape([inp_sz, 12*64, 1])(model_input)  ## (None, inp_sz, 768, 1)
        ## attention-encoder: valid (16,1) convs along time, then average over time
        for kk in range(3):
            xx = L.Convolution2D(filters=32, kernel_size=(16,1), strides=(2,1), padding='valid', activation='gelu')(xx)
            xx = L.BatchNormalization()(xx)
        xx = tf.reduce_mean(xx, 1)  ## (None, 768, 32)
        ## attention-projector: softmax over the 768 positions, per head
        xx = L.Dense(attention_heads, activation='linear')(xx)
        multi_head_attention_weights = L.Softmax(axis=1, name='attention_weights')(xx) * scale
        ## (None, inp_sz, 768) @ (None, 768, heads) => (None, inp_sz, heads)
        xx = tf.matmul( L.Reshape([inp_sz, 12*64])(model_input) , multi_head_attention_weights )
        xx = L.Reshape([inp_sz, attention_heads, 1])(xx)
        ## encoder-decoder trunk
        xx = encoder_decoder(xx, fx)
        model = Model([model_input], xx)
        model.summary()  ## summary() prints by itself and returns None
        return model


    def get_model_bin_attention_flat_ant(fx=2, attention_heads=8, scale=1.0): ## attentional model bins B=64
        """Build the per-bin attention model with antennas flattened into the head axis.

        A small conv encoder over time, averaged over time, summarizes each of the
        64 bins (antennas act as input channels). A Dense projector plus a softmax
        over the bin axis gives ``attention_heads`` weight vectors, each summing to
        ``scale`` over the 64 bins. Every antenna's 64-bin signal is projected onto
        every head, and the resulting ``attention_heads * 12`` series are fed to
        ``encoder_decoder`` as one axis. Heads and antennas are treated alike, so
        no explicit inter-antenna structure (e.g. AoA / beamforming) is exploited.

        Args:
            fx: channel-width multiplier of the encoder-decoder trunk.
            attention_heads: number of bin-attention heads.
            scale: multiplier applied to the softmax weights.

        Returns:
            Uncompiled model ``(inp_sz, 12, 64) -> (128, 1)``. Its softmax layer is
            named 'attention_weights', which inference() reads.
        """
        input_shape = (inp_sz, 12, 64)
        model_input = L.Input(input_shape)
        ## bins on axis 2, antennas as channels
        xx = L.Permute([1,3,2])(model_input)  ## (None, inp_sz, 64, 12)
        ## attention-encoder: valid (16,1) convs along time, then average over time
        for kk in range(3):
            xx = L.Convolution2D(filters=32, kernel_size=(16,1), strides=(2,1), padding='valid', activation='gelu')(xx)
            xx = L.BatchNormalization()(xx)
        xx = tf.reduce_mean(xx, 1)  ## (None, 64, 32)
        ## attention-projector: softmax over the 64 bins, per head
        xx = L.Dense(attention_heads, activation='linear')(xx)
        multi_head_attention_weights = L.Softmax(axis=1, name='attention_weights')(xx) * scale
        ## (None, inp_sz*12, 64) @ (None, 64, heads) => (None, inp_sz*12, heads)
        xx = tf.matmul( L.Reshape([inp_sz*12,64])(model_input), multi_head_attention_weights )
        xx = L.Reshape([inp_sz, 12, attention_heads])(xx)
        xx = L.Permute([1,3,2])(xx)  ## (None, inp_sz, heads, 12)
        ## (None, inp_sz, heads, 12) => (None, inp_sz, heads*12, 1): attention bins and
        ## antennas are treated the same way (doesn't exploit intra-antenna structure)
        xx = L.Reshape([inp_sz, attention_heads*12, 1])(xx)
        ## encoder-decoder trunk
        xx = encoder_decoder(xx, fx)
        model = Model([model_input], [xx])
        model.summary()  ## summary() prints by itself and returns None
        return model


    def get_model_bin_attention(fx=2, attention_heads=8, scale=1.0): ## attentional model bins B=64 / input-reduce-antennas
        """Build the per-bin attention model on the antenna-averaged input.

        The 12 antennas are first averaged, leaving a ``(inp_sz, 64)`` bin signal.
        A small conv encoder over time, averaged over time, summarizes each bin. A
        Dense projector plus a softmax over the bins gives ``attention_heads``
        weight vectors (each summing to ``scale``). The averaged signal projected
        onto each head feeds ``encoder_decoder``.

        Args:
            fx: channel-width multiplier of the encoder-decoder trunk.
            attention_heads: number of bin-attention heads.
            scale: multiplier applied to the softmax weights.

        Returns:
            Uncompiled model ``(inp_sz, 12, 64) -> (128, 1)``. Its softmax layer is
            named 'attention_weights', which inference() reads.
        """
        input_shape = (inp_sz, 12, 64)
        model_input = L.Input(input_shape)
        xx = L.Permute([1,3,2])(model_input)  ## (None, inp_sz, 64, 12)
        ## reduce antennas by averaging (a learned Dense(1) reduction was an alternative)
        reduced_input = tf.reduce_mean(xx, 3)
        reduced_input = L.Reshape([inp_sz, 64, 1])(reduced_input)
        xx = reduced_input
        ## attention-encoder: valid (16,1) convs along time, then average over time
        for kk in range(3):
            xx = L.Convolution2D(filters=32, kernel_size=(16,1), strides=(2,1), padding='valid', activation='gelu')(xx)
            xx = L.BatchNormalization()(xx)
        xx = tf.reduce_mean(xx, 1)  ## (None, 64, 32)
        ## attention-projector: softmax over the 64 bins, per head
        xx = L.Dense(attention_heads, activation='linear')(xx)
        multi_head_attention_weights = L.Softmax(axis=1, name='attention_weights')(xx) * scale
        ## (None, inp_sz, 64) @ (None, 64, heads) => (None, inp_sz, heads)
        xx = tf.matmul( L.Reshape([inp_sz,64])(reduced_input), multi_head_attention_weights )
        xx = L.Reshape([inp_sz, attention_heads, 1])(xx)
        ## encoder-decoder trunk
        xx = encoder_decoder(xx, fx)
        model = Model([model_input], [xx])
        model.summary()  ## summary() prints by itself and returns None
        return model


    def get_model(variant='bin_attention_flat_ant'):
        """Build the model used by train(); defaults to the per-bin attention model.

        Args:
            variant: 'bin_attention_flat_ant' (default), 'flat_attention' or
                'bin_attention'. Each builder is called with its default arguments.

        Raises:
            ValueError: unknown ``variant``.
        """
        builders = {
            'bin_attention_flat_ant': get_model_bin_attention_flat_ant,
            'flat_attention': get_model_flat_attention,
            'bin_attention': get_model_bin_attention,
        }
        if variant not in builders:
            raise ValueError(f'get_model> unknown variant {variant!r}; expected one of {sorted(builders)}')
        return builders[variant]()


    def train():
        """Train the get_model() network with iSyncLoss and save it to ``model_path``.

        Learning-rate schedule: 1e-4 for the first 5 epochs (warm-up), 1e-3 until
        ``lr_drop_epoch``, then 2e-4. The scheduler also writes the current epoch
        into ``epoch_var``, which drives the best-k annealing inside iSyncLoss.
        The best model so far (lowest ``val_loss``) is checkpointed to
        ``model_path`` during training. The final model is saved to the same path
        at the end, which overwrites that checkpoint. TensorBoard logs go to
        ``<model_path>/tb``.
        """
        def scheduler(epoch, lr):
            K.set_value(epoch_var, epoch)  ## feeds the iSyncLoss annealing
            if epoch < 5:
                return 1e-4
            if epoch < lr_drop_epoch:
                return 1e-3
            return 2*1e-4

        model = get_model()
        optimizer = tf.keras.optimizers.RMSprop(learning_rate=1e-4)
        model.compile(optimizer=optimizer, loss=iSyncLoss)

        ## this callback used to be built but never passed to fit(), so nothing was
        ## checkpointed during training
        checkpoint_cb = ModelCheckpoint(filepath=model_path,
                monitor='val_loss',
                save_weights_only=False,
                save_freq='epoch',
                save_best_only=True,
                verbose=1)

        train_generator, val_generator = data_pipe(get_fnames, get_data, augment_dataset, batch_size, cache=cache)

        tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir=join(model_path,'tb'), histogram_freq=1)

        scheduler_cb = tf.keras.callbacks.LearningRateScheduler(scheduler)

        model.fit(train_generator,
                  validation_data=val_generator,
                  epochs=train_epochs,
                  callbacks=[tensorboard_cb, scheduler_cb, checkpoint_cb])

        model.save(model_path)


    def inference():
        """Evaluate the saved model on the validation split and write a PDF report.

        Loads ``model_path`` (iSyncLoss passed as a custom object) and adds the
        'attention_weights' layer output as a second model output. It then runs
        up to 200 validation batches, feeds every sample to
        ``inf_metrics.Metrics.evaluate`` and writes
        ``<model_path>/inference/val.pdf``. The last page of that PDF lists the
        mean of each error series in ``metrics_obj.errs``.
        """
        metrics_obj = inf_metrics.Metrics()

        def inference_ds(ds, fname, max_samples=None):
            """Run ``model_with_attn`` over at most ``max_samples`` batches of ``ds``.

            Per-sample metrics accumulate in ``metrics_obj``; the report is written
            to ``<model_path>/inference/<fname>.pdf``.
            """
            def add_figure(kk, ii, curr_metrics):
                """Add a 3x2 diagnostic page for sample ``ii`` of batch ``kk``.

                Only the first sample of every 4th batch is plotted. The per-batch
                arrays are read from the enclosing loop.
                """
                ## first sample of every 4th batch only (the old `kk%4` test selected
                ## the other 3 batches out of 4)
                if kk % 4 != 0 or ii != 0:
                    return
                ## (idea: also plot when the HR error is large, e.g. > 20)

                ## same polarity convention as the metrics loop below
                sign = -1 if idx[ii] < n_shifts else 1

                plt.figure()
                plt.subplot(3,2,1)
                plt.plot(sign * tar_sync[ii, :], label='Ground-Truth (PPG/HR)')
                plt.plot(sign * pre[ii, :], label='Prediction (VitaNet/HR)')
                sd_ppg = sum(curr_metrics['sd_ppg'], [])   ## flatten list of index lists
                sd_pre = sum(curr_metrics['sd_vita'], [])
                tar_y = [sign * tar_sync[ii, sd_idx] for sd_idx in sd_ppg]
                pre_y = [sign * pre[ii, sd_idx] for sd_idx in sd_pre]
                plt.plot(sd_ppg, tar_y, 'r*')
                plt.plot(sd_pre, pre_y, 'g*')
                plt.title(f'isync: {offset_index[ii]} : {idx[ii]}')
                plt.legend(fontsize=7)
                plt.subplot(3,2,3)
                plt.plot(sign * (tar_sync[ii, 1:] - tar_sync[ii, :-1]), label='Ground-Truth (PPG/HR)')
                plt.plot(sign * (pre[ii, 1:] - pre[ii, :-1]), label='Prediction (VitaNet/HR)')
                plt.ylabel('First Derivative')
                plt.legend(fontsize=7)
                plt.subplot(3,2,5)
                ppg_diff = tar_sync[ii, 1:] - tar_sync[ii, :-1]
                pre_diff = pre[ii, 1:] - pre[ii, :-1]
                plt.plot(sign * (ppg_diff[1:] - ppg_diff[:-1]), label='Ground-Truth (PPG/HR)')
                plt.plot(sign * (pre_diff[1:] - pre_diff[:-1]), label='Prediction (VitaNet/HR)')
                plt.ylabel('Second Derivative')
                print(curr_metrics)
                plt.xlabel(f'VN-HR-Err: {curr_metrics["hr_vita"]}, Rad-HR-Err: {curr_metrics["hr_rad_multi"]}, '
                           f'S-D: {curr_metrics["sys_dia_diff_error"]}')
                plt.legend(fontsize=7)

                ## Radar plots
                max_bin = np.argmax(vspr[ii,:])
                ## NOTE(review): both signals are the same antenna-average of the max bin
                ## (left over from the 2-channel R/H input) -- confirm what was intended
                rad_br_sig = np.mean(inputs[ii, :, :, max_bin], axis=1)
                rad_hr_sig = np.mean(inputs[ii, :, :, max_bin], axis=1)
                plt.subplot(3,2,2)
                plt.plot(rad_br_sig, label='Radar BR')
                plt.gca().yaxis.tick_right()
                plt.legend(fontsize=7)
                plt.subplot(3,2,4)
                plt.plot(atw[ii], label='Attention')  ## this sample's attention weights
                plt.gca().yaxis.tick_right()
                plt.legend(fontsize=7)
                plt.subplot(3,2,6)
                plt.plot(np.abs(np.fft.fft(rad_hr_sig))[:30], label='Rad HR FFT')
                plt.gca().yaxis.tick_right()
                plt.legend(fontsize=7)
                pdf.savefig()
                plt.close()

            chake_dir(f'{model_path}/inference')
            with PdfPages(f'{model_path}/inference/{fname}.pdf') as pdf:
                for kk, element in enumerate(ds):
                    if max_samples and kk==max_samples: break
                    inputs, tar, vspr, id_ = element
                    pre, atw = model_with_attn.predict(inputs)
                    pre = tf.convert_to_tensor(pre)
                    loss, tar_sync, offset_index, idx = iSyncLoss(tar, pre, return_crops=True)
                    ## number of +target windows in iSyncLoss (the -target windows follow);
                    ## replaces the hard-coded 13 (= 140 - 128 + 1 for the default sizes)
                    n_shifts = tar.shape[1] - pre.shape[1] + 1

                    print('tar_sync>', tar_sync.shape, np.mean(tar_sync, 1)[0:4], np.std(tar_sync, 1)[0:4],
                                        np.amin(tar_sync, 1)[0:4], np.amax(tar_sync, 1)[0:4])
                    print('pre>', pre.shape, np.mean(pre, 1)[0:4], np.std(pre, 1)[0:4],
                                    np.amin(pre, 1)[0:4], np.amax(pre, 1)[0:4])

                    # Calculate metrics
                    for ii in range(pre.shape[0]):
                        # Flip both signals when a +target window matched best
                        if idx[ii] < n_shifts:
                            curr_metrics = metrics_obj.evaluate(-tar_sync[ii, :], -pre[ii, :], inputs[ii, :, :, :],
                                                                vspr[ii, :], atw[ii, :])
                        else:
                            curr_metrics = metrics_obj.evaluate(tar_sync[ii, :], pre[ii, :], inputs[ii, :, :, :],
                                                                vspr[ii, :], atw[ii, :])
                        # add_figure(kk, ii, curr_metrics)  ## enable for per-sample plots

                ## summary page: mean of every accumulated error series
                fig = plt.figure()
                ax = fig.add_subplot(1,1,1)
                top = 0.95
                for key in metrics_obj.errs:
                    mean_err = round(np.mean(metrics_obj.errs[key]), 2)
                    print(key, mean_err, len(metrics_obj.errs[key]), metrics_obj.errs[key][-5:])
                    plt.text(0.1, top, f'{key}: {mean_err}', horizontalalignment='left',
                             verticalalignment='center', transform=ax.transAxes, fontsize=20)
                    top -= 0.1
                plt.axis('off')
                pdf.savefig()
                plt.close()

        model = load_model(model_path, custom_objects={"iSyncLoss": iSyncLoss})
        ## expose the attention weights as a second output
        model_with_attn = tf.keras.models.Model(inputs=model.input, outputs=[model.output, model.get_layer('attention_weights').output])
        train_generator, val_generator = data_pipe(get_fnames, get_data, augment_dataset, retain_all=True,
                                                   randomize_train=False, cache=cache)
        #inference_ds(train_generator, 'train', max_samples=200)
        inference_ds(val_generator, 'val', max_samples=200)


    def debug():
        """Walk the non-randomized training pipeline twice and count samples.

        Uses ``batch_size=1``. Prints a checksum (the mean of the input) for every
        training sample in two passes -- presumably to eyeball the pipeline's
        determinism across passes -- then prints the number of training and
        validation samples.
        """
        def plots(sample):
            ## one page: two input signals of antenna 0 and the FFT of the second
            ## (labels are from the earlier 2-channel BR/HR input)
            rad_br_sig = sample[0, :, 0, 0]
            rad_hr_sig = sample[0, :, 0, 1]
            plt.subplot(3,1,1)
            plt.plot(rad_br_sig, label='Radar BR')
            plt.legend(fontsize=7)
            plt.subplot(3,1,2)
            plt.plot(rad_hr_sig, label='Radar HR')
            plt.legend(fontsize=7)
            plt.subplot(3,1,3)
            plt.plot(np.abs(np.fft.fft(rad_hr_sig))[:50], label='Rad HR FFT')
            pdf.savefig()
            plt.close()

        train_generator, val_generator = data_pipe(get_fnames, get_data, augment_dataset,
                                                    randomize_train=False, batch_size=1, cache=cache)

        n_train = 0
        with PdfPages('debug.pdf') as pdf:
            for kk, sample in enumerate(train_generator):
                print(f'chksum {kk} {np.mean(sample[0])}')
                # if kk%100 == 0: plots(sample[0])
            print('EPOCH')
            for n_train, sample in enumerate(train_generator, start=1):
                print(f'chksum {n_train - 1} {np.mean(sample[0])}')

        ## counts (the old code printed the last 0-based index, i.e. count - 1, and
        ## raised NameError on an empty generator)
        print('train-samples>', n_train)
        n_val = sum(1 for _ in val_generator)
        print('val-samples>', n_val)

    def stats():
        """Print shape, mean and std of the radar inputs and PPG targets per split.

        Every batch is materialized in memory, which is only viable for small
        datasets; a streaming version would accumulate sufficient statistics.
        """
        train_generator, val_generator = data_pipe(get_fnames, get_data, augment_dataset, cache=cache)

        def cal_stats(gen):
            pairs = [(sample[0].numpy(), sample[1].numpy()) for sample in gen]
            rad = np.concatenate([inp for inp, _ in pairs], 0)
            ppg = np.concatenate([tar for _, tar in pairs], 0)
            print(f'rad shape: >> {rad.shape}', np.mean(rad), np.std(rad))
            print(f'ppg shape: >> {ppg.shape}', np.mean(ppg), np.std(ppg))

        for split in (train_generator, val_generator):
            cal_stats(split)

    ## main()
    ## set shapes by augment() and isync()
    ## center-crop to get always same shape ~ =320 for target
    max_tar_sz = tfr_win_sz-stretch_sz-jitter_sz
    isync_sz = tar_sz-inp_sz
    print(f'sizes> tfr {tfr_win_sz} max_tar_sz {max_tar_sz} tar_sz {tar_sz} inp_sz {inp_sz} isync_sz {isync_sz}')
    print(f'augment> stretch +-{stretch_sz/fps}s jitter +-{jitter_sz/fps}s')
    print(f'iSyncLoss> time +-{(tar_sz-inp_sz)/2/fps}s')
    if tar_sz>max_tar_sz:
        print('error> target size smaller than min_tar_size from augment(), cannot proceed ...')
        return
    ##
    ## reproducibility
    tf.random.set_seed(42)
    random.seed(42)
    np.random.seed(42)
    epoch_var = tf.Variable(0, trainable=False)

    print('cmd>', cmd)
    if cmd == 'train': train()
    elif cmd == 'inference': inference()
    elif cmd == 'debug': debug()
    elif cmd == 'stats': stats()


if __name__ == "__main__":
    # click parses sys.argv and invokes VitaNet with the parsed options
    VitaNet()