import os
import sys
import pickle
import hashlib

import numpy as np
import tensorflow as tf

g_layout = None


def check_dir(directory):
    if not os.path.exists(directory):
        os.makedirs(directory)
    return directory


## consistency check: every record must use the same layout across the database
def layout_consistency_check(layout):
    global g_layout
    if g_layout and g_layout != layout:
        print('ERROR: inconsistent layout\n', layout, '\n', g_layout)
        sys.exit(-1)
    g_layout = layout


def layout_save(opath='tfrs'):
    global g_layout
    print('LAYOUT:', g_layout)
    pickle.dump(g_layout, open(opath + '/layout.pkl', 'wb'))


def _bytes_feature(value):
    """Returns a bytes_list from a bytes string."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def tfr_serialize(**kwargs):
    """Turn keyword arguments into a tf.train.Feature dict plus a layout dict."""
    dd = {}
    layout = {}
    for k, v in kwargs.items():
        if isinstance(v, int):
            dd[k] = _int64_feature(v)
            layout[k] = {'dtype': 'int64'}
        elif isinstance(v, float):
            dd[k] = _float_feature(v)
            layout[k] = {'dtype': 'float32'}
        elif isinstance(v, bytes):
            dd[k] = _bytes_feature(v)
            layout[k] = {'dtype': 'bytes'}
        elif isinstance(v, str):
            dd[k] = _bytes_feature(v.encode())
            layout[k] = {'dtype': 'str'}
        elif isinstance(v, np.ndarray):
            dd[k] = _bytes_feature(v.tobytes())
            layout[k] = {'shape': v.shape, 'dtype': str(v.dtype)}
            print(' rec>', k, 'layout:', layout[k])
        else:
            print('ERROR: unknown type', type(v), k, v)
            sys.exit(-1)
    print(' keys>', list(dd.keys()))
    return dd, layout
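## Usage sketch (illustrative only; the field names 'class_id', 'score', 'name'
## and 'data' are hypothetical examples, not fields this module requires):
## tfr_serialize() turns one record into a feature dict plus its layout, which
## can then be wrapped into a tf.train.Example for writing.
def _demo_serialize():
    rec, layout = tfr_serialize(class_id=1,
                                score=0.5,
                                name='sample-000',
                                data=np.zeros((4, 8), dtype=np.float32))
    example = tf.train.Example(features=tf.train.Features(feature=rec))
    return example.SerializeToString(), layout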
def make_tfr_parser(layout_fn='tfrs/layout.pkl'):
    """Build a record parser from the pickled layout written by layout_save()."""
    type_map = {'int8': tf.int8, 'uint8': tf.uint8,
                'int16': tf.int16, 'int32': tf.int32, 'int64': tf.int64,
                'int': tf.int32, 'str': tf.string, 'bytes': tf.string,
                'float32': tf.float32, 'float64': tf.float64}
    layout = pickle.load(open(layout_fn, 'rb'))
    keys_to_features = {}
    for key, val in layout.items():
        if 'shape' not in val:
            if val['dtype'] == 'int64':
                keys_to_features[key] = tf.io.FixedLenFeature([], tf.int64)
            elif val['dtype'] == 'float32':
                keys_to_features[key] = tf.io.FixedLenFeature([], tf.float32)
            elif val['dtype'] in ('str', 'bytes'):
                keys_to_features[key] = tf.io.FixedLenFeature([], tf.string)
            else:
                raise ValueError(f'cannot parse layout: {val}')
        else:  ## arrays are stored as raw bytes and decoded in the parser below
            keys_to_features[key] = tf.io.FixedLenFeature([], tf.string)

    def tfr_parser(record):
        odic = tf.io.parse_single_example(record, keys_to_features)
        for key, val in layout.items():
            # if key == 'times': continue  ## HACK for hnm; will break things eventually
            if val['dtype'] in type_map:
                if 'shape' in val:
                    odic[key] = tf.io.decode_raw(odic[key], type_map[val['dtype']])
                    odic[key] = tf.reshape(odic[key], val['shape'])
            else:
                print('WARNING> dtype not found in type_map:', key, val)
        return odic

    return tfr_parser


def get_next_batch(layout_fn, tfr_list, batch_size=1):
    """Legacy helper: build a shuffled, batched dataset and return an iterator over it."""
    tfr_parser = make_tfr_parser(layout_fn)
    dataset = tf.data.TFRecordDataset(tfr_list)
    dataset = dataset.map(tfr_parser, num_parallel_calls=2)
    dataset = dataset.shuffle(buffer_size=100)
    dataset = dataset.batch(batch_size)
    # dataset = dataset.repeat(1)
    return iter(dataset)


## new fully contained i/o classes

class TfrReader:
    def __init__(self, layout_fn, tfr_list, batch_size=1, repeat=1,
                 shuffle=False, resample=False, class_num=2):
        print('TfrReader> started')
        tfr_parser = make_tfr_parser(layout_fn)
        dataset = tf.data.TFRecordDataset(tfr_list)
        dataset = dataset.map(tfr_parser)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=100 * batch_size)
        if resample:
            ## rebalance classes towards a uniform target distribution;
            ## get_lab_id_tensor must be provided by the importing module
            dataset = dataset.apply(tf.data.experimental.rejection_resample(
                class_func=lambda x: get_lab_id_tensor(x['class_id']),
                target_dist=[1.0 / class_num for _ in range(class_num)],
                initial_dist=None)).map(lambda a, b: b)
        dataset = dataset.batch(batch_size)
        self.dataset = dataset.repeat(repeat)  ## repeat=-1 repeats indefinitely

    def items(self):
        return self.dataset


class TfrWriter:
    def __init__(self, opath='tfrs_fg', data_set='fg', tfr_size=64):
        print('TfrWriter> started', opath)
        self.opath = opath
        self.data_set = data_set
        self.tfr_size = tfr_size  ## records per .tfr shard
        self.writer = None
        self.rec_nn = 0           ## records in the current shard
        self.tfr_nn = 0           ## shards opened so far
        check_dir(self.opath)

    def __del__(self):
        print('TfrWriter> stopped')
        if self.writer:
            self.writer.close()
        layout_save(self.opath)

    def write_record(self, rec, layout):
        if self.rec_nn == self.tfr_size:  ## current shard is full: roll over
            print('TfrWriter> close')
            self.writer.close()
            self.rec_nn = 0
            self.writer = None
        if self.writer is None:
            fo = '{}/{}-{:03d}.tfr'.format(self.opath, self.data_set, self.tfr_nn)
            self.writer = tf.io.TFRecordWriter(fo)
            print('TfrWriter> open', fo)
            self.tfr_nn += 1
        example = tf.train.Example(features=tf.train.Features(feature=rec))
        self.writer.write(example.SerializeToString())
        layout_consistency_check(layout)
        self.rec_nn += 1
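## Usage sketch (illustrative): write a few records with TfrWriter, then read
## them back with TfrReader. The output directory and the 'idx'/'data' fields
## are hypothetical. Note that TfrWriter relies on __del__ to close the last
## shard and to save layout.pkl, so the writer must be released before reading.
def _demo_roundtrip(opath='tfrs_demo'):
    writer = TfrWriter(opath=opath, data_set='demo', tfr_size=4)
    for ii in range(8):
        rec, layout = tfr_serialize(idx=ii, data=np.full((2, 3), ii, dtype=np.float32))
        writer.write_record(rec, layout)
    del writer  ## triggers close() and layout_save()

    tfrs = sorted(os.path.join(opath, fn) for fn in os.listdir(opath) if fn.endswith('.tfr'))
    reader = TfrReader(os.path.join(opath, 'layout.pkl'), tfrs, batch_size=4)
    for batch in reader.items():  ## eager iteration over batched record dicts
        print('batch>', batch['idx'].numpy(), batch['data'].shape)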
## main dataset creator function!
## cache: None | 'disabled' -> no caching
##        'ram'             -> cache in memory (the whole dataset must fit in RAM,
##                             otherwise it will crash during the first epoch)
##        <cache_file_path> -> on-disk cache at that path
def data_pipe(get_fnames_func, get_data_func, augment_func,
              batch_size=32, nshuffle=128, retain_all=False,
              randomize_train=True, cache=None):

    def load_dataset(fnames, layout_fn, shuffle=False, rand_augment=True, cache=cache):
        ds = tf.data.TFRecordDataset(fnames,
                                     buffer_size=1 << 24,  ## 16 MiB read buffer
                                     num_parallel_reads=tf.data.AUTOTUNE)
        ds = ds.map(make_tfr_parser(layout_fn),
                    num_parallel_calls=tf.data.AUTOTUNE)   ## parse into record dict
        ds = ds.map(get_data_func,
                    num_parallel_calls=tf.data.AUTOTUNE)   ## get (inp, tar, valid, vspr, id_)
        ds = ds.filter(lambda inp, tar, valid, vspr, id_: valid)              ## remove invalid samples
        ds = ds.map(lambda inp, tar, valid, vspr, id_: (inp, tar, vspr, id_)) ## drop the valid flag
        if cache and cache != 'disabled':
            if cache == 'ram':
                ds = ds.cache()
            else:
                ## the cache needs to be dataset-specific (fnames = train | val)
                hexhash = hex(int(hashlib.sha1(str(fnames).encode('utf-8')).hexdigest(), 16) % (10 ** 8))
                cache_path = f'{cache}.{hexhash}'
                print('cache>', cache_path,
                      'found' if os.path.exists(cache_path + '.index') else 'building!')
                ds = ds.cache(cache_path)
        if shuffle:
            ds = ds.shuffle(buffer_size=nshuffle)
        ds = augment_func(ds, rand_augment)
        if not retain_all:
            ds = ds.map(lambda inp, tar, vspr, id_: (inp, tar))  ## only retain (inp, tar)
        ds = ds.batch(batch_size)
        ds = ds.prefetch(tf.data.AUTOTUNE)
        print('ds>', ds)
        return ds

    train_fnames, val_fnames, layout_fn = get_fnames_func()
    train_ds = load_dataset(train_fnames, layout_fn,
                            shuffle=randomize_train, rand_augment=randomize_train, cache=cache)
    val_ds = load_dataset(val_fnames, layout_fn,
                          shuffle=False, rand_augment=False, cache=cache)
    return train_ds, val_ds


if __name__ == '__main__':
    import argparse
    import matplotlib.pyplot as plt
    from os.path import dirname

    from prepro import PrePro

    parser = argparse.ArgumentParser()
    parser.add_argument('-tf', '--tfr', default=None, help='tf-record to inspect')
    args = parser.parse_known_args()[0]

    prpr = PrePro()

    def plot(xx):
        def reduce_(xx):
            xx = np.abs(xx)
            xx = np.mean(xx, 3)
            xx = xx[:, :, :, 0] + xx[:, :, :, 1]
            xx = np.mean(xx, 2)
            print('plot>', xx.shape)
            return xx[:, 0:1]

        f, (ax1, ax2) = plt.subplots(2, 1)
        if xx.shape[-1] == 63:  ## raw-prepro path
            xx, yy = prpr(xx)   ## assumes PrePro is callable on the raw array in eager mode
            ax1.plot(reduce_(xx))
            ax2.plot(reduce_(yy))
        else:
            ax1.plot(reduce_(xx))
        plt.show()

    rr = TfrReader(dirname(args.tfr) + '/layout.pkl', args.tfr)
    for batch in rr.items():  ## eager iteration; ends when the tf-record is exhausted
        nn = batch['sel'].shape[0]
        for kk in range(nn):
            for ky, vl in batch.items():
                if ky == 'data':
                    plot(vl[kk].numpy())
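## Usage sketch (illustrative): minimal callbacks for data_pipe(). The glob
## patterns and the 'data'/'label'/'name' record fields are assumptions about
## a concrete dataset, not requirements of this module.
def _demo_data_pipe():
    import glob

    def get_fnames():
        return (sorted(glob.glob('tfrs/train-*.tfr')),
                sorted(glob.glob('tfrs/val-*.tfr')),
                'tfrs/layout.pkl')

    def get_data(rec):
        ## map one parsed record dict to the (inp, tar, valid, vspr, id_) tuple
        return rec['data'], rec['label'], tf.constant(True), tf.constant(0.0), rec['name']

    def augment(ds, rand_augment):
        return ds  ## identity; plug random crops / noise here when rand_augment is True

    return data_pipe(get_fnames, get_data, augment, batch_size=16, cache='disabled')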