# vitanet/common/dataio.py
import os
import sys
import pickle
import hashlib
import numpy as np
import tensorflow as tf

# Layout (feature name -> dtype/shape spec) shared by every record written in
# this process; filled by layout_consitency_check() and pickled by layout_save().
g_layout = None

def chake_dir(directory):
    """Create ``directory`` (including missing parents) if needed and return it.

    ``exist_ok=True`` removes the race between "does it exist?" and "create it"
    that the check-then-create pattern has when several processes share a root.
    """
    os.makedirs(directory, exist_ok=True)
    return directory

# consistency check // same layout across database
def layout_consitency_check(layout):
    """Check that ``layout`` matches the layout of previously written records.

    The first call stores ``layout`` in the module-level ``g_layout``; later
    calls compare against it and terminate the process (``sys.exit(-1)``)
    on a mismatch.

    Args:
        layout: dict mapping feature name -> layout spec, as returned by
            ``tfr_serialize``.
    """
    global g_layout
    # `is not None` (not truthiness) so that an empty first layout ({}) still
    # counts as "already set" and a later different layout is rejected.
    if g_layout is not None and g_layout != layout:
        print('ERROR: inconsistent layout\n', layout, '\n', g_layout)
        sys.exit(-1)
    g_layout = layout

def layout_save(opath='tfrs'):
    """Pickle the module-level layout ``g_layout`` to ``<opath>/layout.pkl``.

    Args:
        opath: existing output directory.
    """
    print('LAYOUT:', g_layout)
    # Context manager: flush and close deterministically instead of relying on
    # garbage collection of an anonymous file object.
    with open(os.path.join(opath, 'layout.pkl'), 'wb') as fh:
        pickle.dump(g_layout, fh)

def _bytes_feature(value):
    """Wrap a single bytes value in a tf.train.Feature holding a bytes_list."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)

def _float_feature(value):
    """Returns a float_list Feature holding a single float / double value.

    tf.train.FloatList stores 32-bit floats, so a Python float (double) is
    narrowed; tfr_serialize accordingly records such values as 'float32'.
    """
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    """Wrap one integer-like value (bool / enum / int / uint) in an int64_list Feature."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)


def tfr_serialize(**kwargs):
    """Convert keyword arguments into tf.train.Feature objects plus a layout.

    Supported value types (checked in this order):
      * int (including bool)  -> int64_list,  layout {'dtype': 'int64'}
      * float                 -> float_list,  layout {'dtype': 'float32'}
      * bytes                 -> bytes_list,  layout {'dtype': 'bytes'}
      * str (encoded)         -> bytes_list,  layout {'dtype': 'str'}
      * np.ndarray (raw bytes)-> bytes_list,  layout {'shape': ..., 'dtype': ...}
    Any other type prints an error and exits the process.

    Returns:
        (features, layout): dicts keyed by argument name.
    """
    features = {}
    layout = {}
    for name, value in kwargs.items():
        if isinstance(value, int):
            feature, spec = _int64_feature(value), {'dtype': 'int64'}
        elif isinstance(value, float):
            feature, spec = _float_feature(value), {'dtype': 'float32'}
        elif isinstance(value, bytes):
            feature, spec = _bytes_feature(value), {'dtype': 'bytes'}
        elif isinstance(value, str):
            feature, spec = _bytes_feature(value.encode()), {'dtype': 'str'}
        elif isinstance(value, np.ndarray):
            # arrays travel as raw bytes; shape/dtype are kept to rebuild them
            feature = _bytes_feature(value.tobytes())
            spec = {'shape': value.shape, 'dtype': str(value.dtype)}
            print('    rec>', name, 'layout:', spec)
        else:
            print('ERROR: unknown type', type(value), name, value)
            sys.exit(-1)
        features[name] = feature
        layout[name] = spec
    print('    keys>', list(features.keys()))
    return features, layout

def make_tfr_parser(layout_fn='tfrs/layout.pkl'):
    """Build a function that parses one serialized tf.train.Example into a dict.

    The layout pickle (as written by ``layout_save``) maps each feature name to
    ``{'dtype': ...}`` for scalars or ``{'shape': ..., 'dtype': ...}`` for numpy
    arrays stored as raw bytes.

    Args:
        layout_fn: path to the pickled layout. Only load trusted files:
            ``pickle.load`` can execute arbitrary code.

    Returns:
        ``tfr_parser(record) -> dict`` suitable for ``tf.data.Dataset.map``.
        Scalar features are returned as typed by ``parse_single_example``;
        array features are decoded with ``tf.io.decode_raw`` and reshaped.

    Raises:
        ValueError: if a scalar feature has a dtype with no parse rule.
    """
    # decode_raw target types for array features, keyed by numpy dtype name.
    type_map = {'int8': tf.int8, 'uint8': tf.uint8, 'int16': tf.int16, 'uint16': tf.uint16,
                'int32': tf.int32, 'int64': tf.int64, 'int': tf.int32,
                'float16': tf.float16, 'float32': tf.float32, 'float64': tf.float64}
    # FixedLenFeature types for scalar features, keyed by the dtype tags that
    # tfr_serialize emits ('bytes' is emitted for bytes values and must parse).
    scalar_map = {'int64': tf.int64, 'float32': tf.float32, 'str': tf.string, 'bytes': tf.string}

    with open(layout_fn, "rb") as fh:
        layout = pickle.load(fh)

    keys_to_features = {}
    for key, val in layout.items():
        if 'shape' in val:
            # arrays are stored as a single raw byte string
            keys_to_features[key] = tf.io.FixedLenFeature([], tf.string)
        elif val['dtype'] in scalar_map:
            keys_to_features[key] = tf.io.FixedLenFeature([], scalar_map[val['dtype']])
        else:
            raise ValueError(f'cannot parse layout: {val}')

    def tfr_parser(record):
        odic = tf.io.parse_single_example(record, keys_to_features)
        for key, val in layout.items():
            if 'shape' not in val:
                continue  # scalars are already typed by parse_single_example
            if val['dtype'] in type_map:
                odic[key] = tf.io.decode_raw(odic[key], type_map[val['dtype']])
                odic[key] = tf.reshape(odic[key], val['shape'])
            else:
                # left as the raw byte string
                print('WARNING> type not found in layout', key, val)
        return odic
    return tfr_parser

def get_next_batch(layout_fn, tfr_list, batch_size=1):
    """Return the ``get_next()`` output of a shuffled, batched TFRecord dataset.

    Args:
        layout_fn: path to the pickled layout (see make_tfr_parser).
        tfr_list: TFRecord file name or list of file names.
        batch_size: records per batch.
    """
    tfr_parser = make_tfr_parser(layout_fn)
    dataset = tf.data.TFRecordDataset(tfr_list)
    # Dataset.map no longer accepts num_threads/output_buffer_size; the
    # tf.data equivalents are num_parallel_calls plus an explicit prefetch.
    dataset = dataset.map(tfr_parser, num_parallel_calls=2)
    dataset = dataset.prefetch(10*batch_size)
    dataset = dataset.shuffle(buffer_size=100)
    dataset = dataset.batch(batch_size)
    # Dataset.make_one_shot_iterator() was removed in TF2; the compat.v1
    # function keeps the graph-mode behaviour (a get_next op evaluated on each
    # sess.run) and also works when executing eagerly.
    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
    return iterator.get_next()

## new fully contained i/o classes
class TfrReader:
    """TFRecord reader built on tf.data.

    Parses records with the layout in ``layout_fn``, optionally shuffles and
    class-rebalances them, then batches and repeats. The resulting dataset is
    returned by ``items()``.
    """
    def __init__(self, layout_fn, tfr_list, batch_size=1, repeat=1, shuffle=False, resample=False, class_num=2):
        """
        Args:
            layout_fn: path to the pickled layout (see make_tfr_parser).
            tfr_list: TFRecord file name or list of file names.
            batch_size: records per batch.
            repeat: count passed to Dataset.repeat.
            shuffle: shuffle records with a buffer of 100*batch_size.
            resample: rejection-resample records toward a uniform distribution
                over ``class_num`` classes (uses each record's 'class_id').
            class_num: number of classes used when ``resample`` is True.
        """
        print('TfrReader> started')
        tfr_parser = make_tfr_parser(layout_fn)
        dataset = tf.data.TFRecordDataset(tfr_list)
        dataset = dataset.map(tfr_parser)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=100*batch_size)
        if resample:
            # The transformation lives in tf.data.experimental (there is no
            # tf.data.rejection_resample). It yields (class, record) pairs, so
            # the trailing map keeps only the record.
            # NOTE(review): get_lab_id_tensor is not defined in this module —
            # confirm it is provided elsewhere, otherwise resample=True raises
            # NameError when the dataset is built.
            dataset = dataset.apply(tf.data.experimental.rejection_resample(
                        class_func=lambda x: get_lab_id_tensor(x['class_id']),
                        target_dist=[1.0/class_num] * class_num,
                        initial_dist=None)).map(lambda _cls, rec: rec)
        dataset = dataset.batch(batch_size)
        self.dataset = dataset.repeat(repeat)

    def items(self):
        """Return the batched, repeated tf.data.Dataset."""
        return self.dataset


class TfrWriter:
    """Write records to sharded TFRecord files of at most ``tfr_size`` records.

    Shards are named ``<opath>/<data_set>-NNN.tfr``. Closing the writer
    (``close()``, leaving a ``with`` block, or garbage collection) closes the
    current shard and pickles the shared layout to ``<opath>/layout.pkl``.
    """
    def __init__(self, opath='tfrs_fg', data_set='fg', tfr_size=64):
        print('TfrWriter> started', opath)
        self.opath = opath
        self.data_set = data_set
        self.tfr_size = tfr_size
        self.writer = None
        self.rec_nn = 0   # records written to the current shard
        self.tfr_nn = 0   # shards opened so far (index of the next shard)
        chake_dir(self.opath)
        # set last: if construction failed above, close() does nothing
        self._closed = False

    def close(self):
        """Close the open shard and save the layout. Safe to call repeatedly."""
        if getattr(self, '_closed', True):
            return
        self._closed = True
        print('TfrWriter> stopped')
        if self.writer is not None:
            self.writer.close()
            self.writer = None
        layout_save(self.opath)

    def __del__(self):
        # fallback for callers that never call close() explicitly
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def write_record(self, rec, layout):
        """Serialize ``rec`` as a tf.train.Example into the current shard.

        Args:
            rec: dict of feature name -> tf.train.Feature (see tfr_serialize).
            layout: layout dict for ``rec``; must equal the layout of all
                previously written records (the process exits otherwise).

        Raises:
            ValueError: if the writer has already been closed.
        """
        if self._closed:
            raise ValueError('TfrWriter> write_record called after close()')
        # validate first so an inconsistent record never reaches disk
        layout_consitency_check(layout)
        if self.writer is not None and self.rec_nn >= self.tfr_size:
            # current shard is full: roll over to a new file
            print('TfrWriter> close')
            self.writer.close()
            self.rec_nn = 0
            self.writer = None
        if self.writer is None:
            fo = '{}/{}-{:03d}.tfr'.format(self.opath, self.data_set, self.tfr_nn)
            self.writer = tf.io.TFRecordWriter(fo)
            print('TfrWriter> open', fo)
            self.tfr_nn += 1
        example = tf.train.Example(features=tf.train.Features(feature=rec))
        self.writer.write(example.SerializeToString())
        self.rec_nn += 1

## main dataset creator function!
## cache: None|'disabled': disabled, '<cache_file_path>': disk cache,
##        'ram'/'RAM' (case-insensitive): cache in memory (the whole dataset
##        needs to fit in RAM or it will crash during the 1st epoch)
def data_pipe(get_fnames_func, get_data_func, augment_func, batch_size=32, nshuffle=128, retain_all=False, randomize_train=True, cache=None):
    """Build the training and validation tf.data pipelines.

    Args:
        get_fnames_func: callable returning ``(train_fnames, val_fnames, layout_fn)``.
        get_data_func: maps a parsed record dict to ``(inp, tar, valid, vspr, id_)``;
            samples whose ``valid`` is false are dropped.
        augment_func: ``augment_func(ds, rand_augment) -> ds``, applied after
            caching and shuffling.
        batch_size: batch size of both datasets.
        nshuffle: shuffle buffer size (used when shuffling the training set).
        retain_all: keep ``(inp, tar, vspr, id_)``; otherwise only ``(inp, tar)``.
        randomize_train: shuffle and randomly augment the training set.
        cache: caching mode, see the comment above.

    Returns:
        ``(train_ds, val_ds)``, both batched and prefetched.
    """
    def load_dataset(fnames, layout_fn, shuffle=False, rand_augment=True, cache=cache):
        """Build one batched, prefetched pipeline over the TFRecord files ``fnames``."""
        ds = tf.data.TFRecordDataset(fnames, buffer_size=1<<24, ## 16 MiB read buffer
                                     num_parallel_reads=tf.data.AUTOTUNE)
        ds = ds.map(make_tfr_parser(layout_fn), num_parallel_calls=tf.data.AUTOTUNE) ## parse into rec dict
        ds = ds.map(get_data_func, num_parallel_calls=tf.data.AUTOTUNE) ## get (inp,tar,valid,vspr,id_)
        ds = ds.filter(lambda inp, tar, valid, vspr, id_: valid) ## remove invalid samples
        ds = ds.map(lambda inp, tar, valid, vspr, id_: (inp, tar, vspr, id_)) ## drop valid flag (inp,tar,vspr,id_)

        if cache and cache != 'disabled':
            # case-insensitive so the documented 'RAM' spelling is honoured
            # instead of being treated as a disk cache path
            if str(cache).lower() == 'ram':
                ds = ds.cache()
            else: ## disk cache needs to be data-set specific (fnames = train|val)
                # keep this hash recipe unchanged: existing cache files are named with it
                hexhash = hex(int(hashlib.sha1(str(fnames).encode("utf-8")).hexdigest(), 16) % (10 ** 8))
                cache_path = f'{cache}.{hexhash}'
                print('cache>', cache_path, 'found' if os.path.exists(cache_path+'.index') else 'building!')
                ds = ds.cache(cache_path)
        if shuffle:
            ds = ds.shuffle(buffer_size=nshuffle)
        ds = augment_func(ds, rand_augment)
        if not retain_all:
            ds = ds.map(lambda inp, tar, vspr, id_: (inp, tar)) ## only retain (inp, target)
        ds = ds.batch(batch_size)
        ds = ds.prefetch(tf.data.AUTOTUNE)

        print('ds>', ds)
        return ds

    train_fnames, val_fnames, layout_fn = get_fnames_func()
    train_ds = load_dataset(train_fnames, layout_fn, shuffle=randomize_train, rand_augment=randomize_train, cache=cache)
    val_ds = load_dataset(val_fnames, layout_fn, shuffle=False, rand_augment=False, cache=cache)
    return train_ds, val_ds


if __name__ == '__main__':
    # Viewer for a tf-record: plots the 'data' field of every sample.
    # It drives a TF1-style Session (PrePro is called with `sess`), so switch
    # to graph mode before any graph or dataset is built.
    tf.compat.v1.disable_eager_execution()

    import argparse
    import matplotlib.pyplot as plt
    from os.path import join, dirname
    from prepro import PrePro

    parser = argparse.ArgumentParser()
    parser.add_argument('-tf','--tfr', required=True, help='tf-record')
    args = parser.parse_known_args()[0]

    prpr = PrePro()

    def plot(xx):
        def reduce(xx):
            xx = np.abs(xx)
            xx = np.mean(xx,3)
            xx = xx[:,:,:,0]+xx[:,:,:,1]
            xx = np.mean(xx,2)
            print('plot>', xx.shape)
            return xx[:,0:1]

        f, (ax1, ax2) = plt.subplots(2, 1)

        if xx.shape[-1]==63: ## raw-prepro
            # `sess` is the session opened below (module-level name)
            xx,yy = prpr(sess, xx)
            ax1.plot(reduce(xx))
            ax2.plot(reduce(yy))
        else:
            ax1.plot(reduce(xx))
        plt.show()

    # join() keeps a bare file name relative to cwd (dirname() is '' there)
    rr = TfrReader(join(dirname(args.tfr), 'layout.pkl'), args.tfr)
    # TfrReader exposes a tf.data.Dataset via items(); it has no init()/read_batch()
    next_batch = tf.compat.v1.data.make_one_shot_iterator(rr.items()).get_next()

    with tf.compat.v1.Session() as sess:
        tf.compat.v1.global_variables_initializer().run()
        tf.compat.v1.local_variables_initializer().run()

        while True:
            try:
                batch = sess.run(next_batch)
            except tf.errors.OutOfRangeError:
                break  # end of the tf-record(s)
            nn = batch['sel'].shape[0]
            if 'data' in batch:
                for kk in range(nn):
                    plot(batch['data'][kk])  # one plot per sample in the batch