import math import torch import numpy as np from utils import CHART_TYPE_MAP, PAD_IDX from .base import ( BaseDataset, get_text_window, stack_dict, stack_tensor_dict, shift_right_vec, shift_tokens_right ) ''' Offset_cont mode predict == > [x0,x1] [x1-x0,x2-x1], [x2-x1,y2-y1] ... scale == > [mean(d1), std(d1), mean(d1), std(x1)] #Multi choice learners ''' class PmcDataDataset(BaseDataset): def __init__(self, data, tokenizer=None, chart_tasks=None, scale_mode='log10', scale_eps=1.00001, scale_floor=-1.0, window_size=16, widen_rate=0.0, max_widen_len=1, min_sent_len=10, max_source_len=1024, max_target_len=256, max_source_len2=8, max_target_len2=8, pad_to_max_len=True, ignore_pad_token_for_loss=True, norm_mode='minmax', discrete_input=None, tasks=['categorical','series_name','axis','caption','data'], sep_token='', **kwargs): super().__init__(tokenizer, window_size, widen_rate, max_widen_len, min_sent_len, max_source_len, max_target_len, pad_to_max_len, ignore_pad_token_for_loss) # Filter the pmc dataset for certain tasks or charts self.chart_tasks = ['task6'] if chart_tasks is None else chart_tasks self.active_charts = [] self.tokenizer = tokenizer self.max_source_len2 = max_source_len if max_source_len2 is None else max_source_len2 self.max_target_len2 = max_target_len if max_target_len2 is None else max_target_len2 self.discrete_input = discrete_input self.sep_token = sep_token self.tasks = tasks self._data = data self.data = self.filter_data(data) assert scale_mode in ['log10', 'log'], "scale mode not implemented" self.scale_mode = math.log10 if scale_mode == 'log10' else math.log self.scale_eps = scale_eps self.scale_floor = scale_floor self.box_plot_keys = ['min', 'first_quartile', 'median', 'third_quartile', 'max'] self.node_map = {'pad': PAD_IDX, 'point': 1, 'eos': 0} self.del_list = [] assert norm_mode in ['minmax','offset'] self.norm_mode = norm_mode def filter_data(self, data): task_req = [] if isinstance(self.chart_tasks, list): assert 'task6' in self.chart_tasks task_req = self.chart_tasks data = [d for d in data if all(t in d and d[t] != None for t in task_req)] if len(self.active_charts): data = [d for d in data if d['task1']['output']['chart_type'].lower().strip() in self.active_charts] return data def __len__(self): return len(self.data) def get_data_with_idx(self, index): if isinstance(index, list): return [self.data[i % len(self.data)] for i in index] elif isinstance(index, int): return self.data[index % len(self.data)] else: raise ValueError("Invalid index given: {}".format(index)) def __getitem__(self, index): while True: d = self.data[index % len(self.data)] #Get context fig_id = d['fig_id'] context_start = d['fig_index'][fig_id][0] all_text = d['all_text'] captions = d['captions'] #Remove caption from text all_text = self.preprocess_all_text(all_text) caption_label, all_text, context_start = self.get_caption_label( captions, all_text, context_start) context = get_text_window( all_text, context_start, tgt_token=self.tgt_token, window_size=self.window_size, widen_rate=self.widen_rate, max_widen_len=self.max_widen_len, min_sent_len=self.min_sent_len) ###################### # Prepare inputs ###################### input_text = '' if 'captions' in self.discrete_input: input_text += caption_label if 'context' in self.discrete_input: input_text += context outputs = {} context_inputs = None if self.tokenizer is not None: context_inputs, _ = self._tokenize(self.tokenizer, input_text, max_source_len=self.max_source_len, max_target_len=self.max_target_len, is_target=False) outputs['context'] = context_inputs outputs['chart_type'] = self.get_chart_type(d) outputs['chart_data'] = self.preprocess_data_series(d) if self.enforce_data_to_chart_type(outputs): break else: index += 1 return index, outputs def collate_fn(self, list_batch): list_idx = [b[0] for b in list_batch] list_batch = [b[1] for b in list_batch] collated_batch = stack_dict(list_batch) if self.tokenizer is not None: collated_batch['context'] = self.batch_tokens(collated_batch['context']) collated_batch['chart_data'] = self.collate_data(collated_batch['chart_data']) return list_idx, collated_batch def batch_tokens(self, list_batch): dict_of_tokens = stack_dict(list_batch) for key in list(dict_of_tokens.keys()): dict_of_tokens[key] = torch.cat(dict_of_tokens[key], dim=0) return dict_of_tokens def get_chart_type(self, d): '''Returns an index''' ct = d['task1']['output']['chart_type'].lower().strip() return CHART_TYPE_MAP.index(ct) def preprocess_data_series(self, d): output = {} output['unnorm_series'] = [] output['norm_series'] = [] output['unnorm_scale'] = {} output['chart_type'] = {} output['data_type'] = {} chart_type = d['task1']['output']['chart_type'].lower().strip() output['chart_type'] = chart_type #For normalising the data for s in ['y', 'x']: output['unnorm_scale'][s] = {} if self.norm_mode == 'offset': output['unnorm_scale'][s]['all'] = [] #Used to compute mean and scales elif self.norm_mode == 'minmax': output['unnorm_scale'][s]['min'] = [float('inf')] * len(d['task6']['output']['data series']) output['unnorm_scale'][s]['max'] = [-float('inf')] * len(d['task6']['output']['data series']) else: raise NotImplementedError() #First loop to collect data from json. data is organised for scaling later first_v = {} #This goes through the series for ds_idx, series_data in enumerate(d['task6']['output']['data series']): unnorm_series = {} unnorm_series['name'] = None unnorm_series['data'] = [] if 'unnamed' not in series_data['name']: unnorm_series['name'] = series_data['name'] prev_v = {} #This goes through each data point for pt_idx, data in enumerate(series_data['data']): series_keys = list(data.keys()) #ignore series that do not have 'x', 'y' or 'y2' if ('x' not in series_keys) or \ ('y' not in series_keys and \ 'y2' not in series_keys): continue #Replace y2 with y where possible y2_replacement_flag = False if 'y2' in series_keys and 'y' not in series_keys and 'x' in series_keys: y2_replacement_flag = True pt_store = {} #This goes through each property of each data point for _, (k, v) in enumerate(data.items()): #Replace y2 with y where possible if 'y2-' in k: continue if y2_replacement_flag and k == 'y2': k = 'y' data_type = type(v) output['data_type'][k] = data_type #Remove spaces if data_type == str: pt_store[k] = v.replace('\n',' ') #Store the data AND update max, min for each series elif isinstance(v, (float, int)): # Assign keys to each grouping if k in ['y', 'y2', 'first_quartile', 'min', 'max', 'third_quartile', 'median']: scale_key = 'y' elif 'x' == k: scale_key = 'x' else: raise ValueError("Unsupported axis key: {}".format(k)) #Use offset mode (mean,std) or not min,range if self.norm_mode == 'minmax': if v < -self.scale_eps[0]: v = 0 pt_store[k] = v output['unnorm_scale'][scale_key]['min'][ds_idx] = min(output['unnorm_scale'][scale_key]['min'][ds_idx], v) output['unnorm_scale'][scale_key]['max'][ds_idx] = max(output['unnorm_scale'][scale_key]['max'][ds_idx], v) elif self.norm_mode == 'offset': pt_store[k] = v output['unnorm_scale'][scale_key]['all'].append(v) else: raise NotImplementedError(self.norm_mode) if len(pt_store) > 1: unnorm_series['data'].append(pt_store) keys = list(pt_store.keys()) assert True if 'y' in keys else 'y2' not in keys, "{}, {}, {}".format(series_keys, keys, y2_replacement_flag) output['unnorm_series'].append(unnorm_series) ########################################## #Second loop to calculate statistics on collected data and then normalise data. e.g. min max or mean std output['norm_scale'] = {} #Compute the mean and std if self.norm_mode == 'offset': for s in list(output['unnorm_scale'].keys()): all_data = np.array(output['unnorm_scale'][s]['all']) if len(all_data): d_mean = np.mean(all_data) d_std = np.std(all_data) #Clip means lower than eps d_mean = max(d_mean, -self.scale_eps[0] + 0.0001) scale_mean = self.scale_mode(d_mean + self.scale_eps[0]) scale_std = self.scale_mode(d_std + self.scale_eps[1]) output['norm_scale'][s] = { 'mean': d_mean, 'std': d_std, 'scale_mean': scale_mean, 'scale_std': scale_std, } elif self.norm_mode == 'minmax': for s in ['y', 'x']: series_count = len(output['unnorm_scale'][s]['min']) min_container, max_container = [],[] for ds_idx in range(series_count): if output['unnorm_scale'][s]['min'][ds_idx] < float('inf') and \ output['unnorm_scale'][s]['max'][ds_idx] > -float('inf'): # and \ #output['unnorm_scale'][s]['min'][ds_idx] != output['unnorm_scale'][s]['max'][ds_idx]: if s not in output['norm_scale']: output['norm_scale'][s] = {} for n in ['min','max','range']: output['norm_scale'][s][n] = [None] * series_count output['norm_scale'][s]['min'][ds_idx] = self.scale_mode(output['unnorm_scale'][s]['min'][ds_idx] + self.scale_eps[0]) output['norm_scale'][s]['max'][ds_idx] = self.scale_mode(output['unnorm_scale'][s]['max'][ds_idx] + self.scale_eps[1]) min_container.append(output['unnorm_scale'][s]['min'][ds_idx]) max_container.append(output['unnorm_scale'][s]['max'][ds_idx]) #Calculate the scale before the floor to prevent out of domain errors scale_range = output['unnorm_scale'][s]['max'][ds_idx] - output['unnorm_scale'][s]['min'][ds_idx] output['norm_scale'][s]['range'][ds_idx] = self.scale_mode(scale_range + self.scale_eps[1]) output['norm_scale'][s]['min'][ds_idx] = max(output['norm_scale'][s]['min'][ds_idx], self.scale_floor[0]) output['norm_scale'][s]['range'][ds_idx] = max(output['norm_scale'][s]['range'][ds_idx], self.scale_floor[1]) #Check if any is none, then just average the rest for ds_idx in range(series_count): if s in output['norm_scale']: if output['norm_scale'][s]['min'][ds_idx] is None: avg_min = np.array(min_container).mean() avg_max = np.array(max_container).mean() avg_range = avg_max - avg_min #output['norm_scale'][s]['max'][ds_idx] = self.scale_mode(avg_max + self.scale_eps[0]) output['norm_scale'][s]['min'][ds_idx] = max(self.scale_mode(avg_min + self.scale_eps[0]), self.scale_floor[0]) output['norm_scale'][s]['range'][ds_idx] = max(self.scale_mode(avg_range + self.scale_eps[1]), self.scale_floor[1]) for ds_idx, series in enumerate(output['unnorm_series']): norm_series = {} norm_series['name'] = series['name'] norm_series['data'] = [] for pt_idx, x_data in enumerate(series['data']): norm_x = {} for _, (k, v) in enumerate(x_data.items()): if isinstance(v, (float, int)): if k in ['y', 'y2', 'first_quartile', 'min', 'max', 'third_quartile', 'median']: scale_key = 'y' elif 'x' == k: scale_key = 'x' else: raise NotImplementedError(k) if self.norm_mode == 'minmax': min_val = output['unnorm_scale'][scale_key]['min'][ds_idx] minmax = output['unnorm_scale'][scale_key]['max'][ds_idx] - output['unnorm_scale'][scale_key]['min'][ds_idx] v = (v - min_val) / minmax if minmax > 0 else 0 elif self.norm_mode == 'offset': # if ds_idx == 0 and pt_idx == 0 and k in ['min', 'x', 'y', 'y2']: # v = 0.0 # else: mean = output['norm_scale'][scale_key]['mean'] std = output['norm_scale'][scale_key]['std'] v = (v - mean) / std if std > 0 else 0.0 norm_x[k] = v if len(norm_x): norm_series['data'].append(norm_x) if len(norm_series['data']): output['norm_series'].append(norm_series) return output def enforce_data_to_chart_type(self, outputs): flag = False if len(outputs['chart_data']['norm_series']) == 0 or len(outputs['chart_data']['norm_series'][0]['data']) == 0: flag = False elif len(outputs['chart_data']['norm_scale']) == 0: flag = False elif outputs['chart_data']['chart_type'] == 'vertical box': flag = True elif outputs['chart_data']['chart_type'] in ['line', 'scatter'] and len(outputs['chart_data']['norm_scale']) == 2: flag = all(v==float for k, v in outputs['chart_data']['data_type'].items()) elif outputs['chart_data']['chart_type'] in ['vertical bar', 'horizontal bar'] and len(outputs['chart_data']['norm_scale']) == 1: flag = True if outputs['chart_data']['data_type']['x'] == str else False else: flag = False #raise ValueError("chart type not recognized: {}".format(outputs['data']['chart_type'])) return flag def collate_captions(self, list_batch): collated_batch = stack_dict(list_batch) for key in list(collated_batch.keys()): if isinstance(collated_batch[key], torch.Tensor): collated_batch[key] = torch.stack(collated_batch[key], dim=0) return collated_batch def tokenize_text_batch(self, padded_batch_txt_tgt, depth=3): all_txt_inputs, all_txt_labels = {},{} for padded_series_txt_tgt in padded_batch_txt_tgt: length = len(padded_series_txt_tgt) collector = [] for txt_raw in padded_series_txt_tgt: if depth == 2: collector += [txt_raw] if len(collector) < length: continue else: txt_raw = collector txt_inputs, txt_labels = self._tokenize( self.tokenizer, txt_raw, max_source_len=self.max_source_len2, max_target_len=self.max_target_len2, is_target=True, shift_right=True) for k, v in txt_inputs.items(): if k not in all_txt_inputs: all_txt_inputs[k] = [] all_txt_inputs[k] += [v] for k, v in txt_labels.items(): if k not in all_txt_labels: all_txt_labels[k] = [] all_txt_labels[k] += [v] #all_txt_inputs = stack_tensor_dict(all_txt_inputs) #all_txt_labels = stack_tensor_dict(all_txt_labels) ################################################ #Pad and batch #def pad_n_batch() label = all_txt_labels['input_ids'] col_counts = [l.size(0) for l in label] max_token_len = label[0].size(1) padded_labels = [] for lbl in label: pad_len = max(col_counts) - lbl.size(0) if pad_len > 0: pad = torch.ones((pad_len, max_token_len), dtype=torch.int32) * self.tokenizer.pad_token_id lbl = torch.cat([lbl, pad], dim=0) padded_labels.append(lbl) label = torch.stack(padded_labels, dim=0) label[label == self.tokenizer.pad_token_id] = PAD_IDX all_txt_labels['input_ids'] = label return all_txt_inputs, all_txt_labels def collate_data(self, batch_data): batch_chart_type = [] batch_node_type = [] batch_node_mask = [] batch_reg_targets = [] batch_reg_mask = [] batch_text_targets = [] batch_text_mask = [] max_node_len = 0 max_series_len = 0 batch_name_targets = [] batch_name_mask = [] batch_scale_tgt = [] max_series_len = max(len(d['norm_series']) for d in batch_data) for data in batch_data: series_node_type = [] series_node_mask = [] series_reg_tgt = [] series_reg_mask = [] series_txt_tgt = [] series_txt_mask = [] series_name_tgt = [] series_name_mask = [] series_scale_tgt = [] chart_type = data['chart_type'] batch_chart_type.append(chart_type) for s_idx, series in enumerate(data['norm_series']): node_type, node_mask = [], [] reg_targets, reg_mask = [], [] txt_tgt, txt_mask = [], [] #TARGET 1 Scales scale_flag = True scale_tgt = [] for s in ['x', 'y']: if s in data['norm_scale']: if self.norm_mode == 'minmax': scale_tensor = torch.tensor([data['norm_scale'][s]['min'][s_idx], data['norm_scale'][s]['range'][s_idx]], dtype=torch.float32) elif self.norm_mode == 'offset': scale_tensor = torch.tensor([data['norm_scale'][s]['scale_mean'], data['norm_scale'][s]['scale_std']], dtype=torch.float32) scale_tgt += scale_tensor if None in scale_tgt: scale_flag = False break #Check scales are correct. If not remove the series if not scale_flag: continue if len(scale_tgt): scale_tgt = torch.stack(scale_tgt, dim=-1).view(1, -1) series_scale_tgt += scale_tgt #TARGET 2: Series name = 'asdf' if series['name'] is not None: series_name_tgt.append(series['name']) series_name_mask.append(1) else: series_name_tgt.append('') series_name_mask.append(0) #TARGET 3: Sequence of points ( node type) prev_pt = None for pidx, point in enumerate(series['data']): #TARGET 4: Regression data if data['chart_type'] == 'vertical box': # Original: [min, first-quartile, median, third-quartile, max] # Prediction: [min_val, first_to_min, median_to_min, third_to_first, max_to_third] # Prediction head must contain relu final layer min_val = point['min'] first_to_min = point['first_quartile'] - point['min'] median_to_min = point['median'] - point['min'] third_to_first = point['third_quartile'] - point['first_quartile'] max_to_third = point['max'] - point['third_quartile'] reg_tgt = [min_val, first_to_min, median_to_min, third_to_first, max_to_third] elif data['chart_type'] in ['vertical bar', 'horizontal bar']: if pidx == 0 or self.norm_mode == 'minmax': reg_tgt = [point[k] for k in ['y']] else: reg_tgt = [point[k] - prev_pt[k] for k in ['y']] elif data['chart_type'] in ['line', 'scatter']: if pidx == 0 or self.norm_mode == 'minmax': reg_tgt = [point[k] for k in ['x', 'y']] else: reg_tgt = [point[k] - prev_pt[k] for k in ['x', 'y']] else: raise NotImplementedError("Invalid chart given") #save for offsetting prev_pt = point node_type.append(self.node_map['point']) node_mask.append(1) #TARGET 5: TEXT data if data['chart_type'] not in ['line', 'scatter'] and \ point.get('x') is not None and \ len(point.get('x')) and \ isinstance(point.get('x'), str): txt_tgt.append(point['x']) txt_mask.append(1) else: txt_tgt.append('') txt_mask.append(0) reg_targets.append(reg_tgt) reg_mask.append(1) node_type.append(self.node_map['eos']) node_mask.append(1) reg_len = len(reg_targets[-1]) reg_targets.append([0.] * reg_len) reg_mask.append(0) series_node_type += [node_type] series_node_mask += [node_mask] series_reg_tgt += [reg_targets] series_reg_mask += [reg_mask] series_txt_tgt += [txt_tgt] series_txt_mask += [txt_mask] #Ensure all are the same length cur_node_len = len(node_type) for idx, l in enumerate([node_type, node_mask, reg_targets, reg_mask]): assert len(l) == cur_node_len, "l={} idx={} cur_node_len={}".format(l, idx, cur_node_len) if cur_node_len > max_node_len: max_node_len = cur_node_len batch_node_type += [series_node_type] batch_node_mask += [series_node_mask] batch_reg_targets += [series_reg_tgt] batch_reg_mask += [series_reg_mask] batch_text_targets += [series_txt_tgt] batch_text_mask += [series_txt_mask] batch_name_targets += [series_name_tgt] batch_name_mask += [series_name_mask] #Stacks by series if self.norm_mode == 'minmax': batch_scale_tgt += [torch.stack(series_scale_tgt, dim=0)] elif self.norm_mode == 'offset': batch_scale_tgt += [series_scale_tgt[0]] padded_batch_node_type = [] padded_batch_node_mask = [] padded_batch_reg_tgt = [] padded_batch_reg_mask = [] padded_batch_txt_tgt = [] padded_batch_scale_tgt = [] padded_batch_scale_mask = [] #Padding to ensure whole batch is same length for series_node_type, series_node_mask, series_reg_tgt, series_reg_mask, series_txt_tgt, series_txt_mask, series_scale_tgt in zip( batch_node_type, batch_node_mask, batch_reg_targets, batch_reg_mask , batch_text_targets, batch_text_mask, batch_scale_tgt ): padded_series_node_type = [] padded_series_node_mask = [] padded_series_reg_tgt = [] padded_series_reg_mask = [] padded_series_txt_tgt = [] padded_series_txt_mask = [] #max_token_len = 0 ### Pad by column for idx, (node_type, node_mask, reg_tgt, reg_mask, txt_tgt, txt_mask) in enumerate(zip( series_node_type, series_node_mask, series_reg_tgt, \ series_reg_mask, series_txt_tgt, series_txt_mask )): pad_node_len = max_node_len - len(node_type) assert sum(reg_mask) > 0, "must be more than zero" if pad_node_len > 0: node_type += [self.node_map['pad']] * pad_node_len node_mask += [0] * pad_node_len mask_pad = [0] * pad_node_len reg_mask += mask_pad txt_mask += mask_pad reg_len = len(reg_tgt[-1]) reg_pad = [[0.] * reg_len for _ in range(pad_node_len)] reg_tgt += reg_pad padded_series_node_type += [node_type] padded_series_node_mask += [node_mask] padded_series_reg_tgt += [reg_tgt] padded_series_reg_mask += [reg_mask] if idx == 0: padded_series_txt_tgt += [txt_tgt] padded_series_txt_mask += [txt_mask] #categorical text data is always the same, except the series name # txt_tgt = ['text1', 'text2', '', ''] # txt_tgt_ids = [[2,3,4,0], [1,2,0,0], [1,0,0,0], [1,0,0,0]] #max_token_len = max(max_token_len, len(txt_tgt_ids[0])) pad_series_len = max_series_len - len(padded_series_node_type) pad_node_len = max(len(p) for p in padded_series_node_type) reg_len = len(padded_series_reg_tgt[0][0]) padded_series_reg_mask = torch.tensor(padded_series_reg_mask, dtype=torch.int32) padded_series_node_mask = torch.tensor(padded_series_node_mask, dtype=torch.int32) assert padded_series_reg_mask.sum() > 0, "mask is zero" for _ in range(pad_series_len): padded_series_node_type += [[self.node_map['eos']] + [self.node_map['pad']] * (pad_node_len-1)] reg_pad = [[0.] * reg_len for _ in range(pad_node_len)] padded_series_reg_tgt += [reg_pad] if pad_series_len > 0 and pad_node_len > 0: pad_mask = torch.zeros((pad_series_len, pad_node_len), dtype=torch.long) padded_series_reg_mask = torch.cat([padded_series_reg_mask, pad_mask], dim=0) padded_series_node_mask = torch.cat([padded_series_node_mask, pad_mask], dim=0) ##################################### # SCALE PADDING #Pad the scales and make a mask series_scale_mask = None if self.norm_mode == 'minmax': series_len, scale_dim = series_scale_tgt.shape series_scale_mask = torch.ones((series_len), dtype=torch.int32) pad_len = max_series_len - series_len if pad_len > 0: pad_scale_tgt = torch.zeros((pad_len, scale_dim), dtype=torch.int32) series_scale_tgt = torch.cat([series_scale_tgt, pad_scale_tgt], dim=0) pad_scale_mask = torch.zeros((pad_len), dtype=torch.int32) series_scale_mask = torch.cat([series_scale_mask, pad_scale_mask], dim=0) padded_series_reg_tgt = torch.tensor(padded_series_reg_tgt, dtype=torch.float32) padded_batch_node_type += [padded_series_node_type] padded_batch_node_mask += [padded_series_node_mask] padded_batch_reg_tgt += [padded_series_reg_tgt] padded_batch_reg_mask += [padded_series_reg_mask] padded_batch_txt_tgt += [padded_series_txt_tgt] padded_batch_scale_tgt += [series_scale_tgt] padded_batch_scale_mask += [series_scale_mask] #Shift right all sequence to sequence problems (text, name) if self.tokenizer is not None: txt_inputs, txt_labels = self.tokenize_text_batch(padded_batch_txt_tgt) name_inputs, name_labels = self.tokenize_text_batch(batch_name_targets, depth=2) #Stack non sequence problems inputs = {} inputs['chart_type'] = batch_chart_type #### Continuous data (Can be different dimension depending on chart type. cannot stack these.) inputs['scale'] = {} inputs['scale']['inputs_embeds'] = padded_batch_scale_tgt #batch_scale_tgt #torch.stack(batch_scale_tgt, dim=0) inputs['scale']['attention_mask'] = padded_batch_scale_mask #batch_scale_tgt #torch.stack(batch_scale_tgt, dim=0) if self.norm_mode == 'minmax': inputs['scale']['decoder_inputs_embeds'] = [shift_right_vec(inp, start_values=0.0) for inp in inputs['scale']['inputs_embeds']] inputs['scale']['decoder_attention_mask'] = [shift_right_vec(inp, start_values=0.0) for inp in inputs['scale']['attention_mask']] inputs['continuous'] = {} inputs['continuous']['inputs_embeds'] = padded_batch_reg_tgt # torch.stack(padded_batch_reg_tgt, dim=0) attn_mask = torch.stack(padded_batch_reg_mask, dim=0) assert attn_mask.sum() > 0, "must be atleast reg :{}".format(attn_mask) inputs['continuous']['attention_mask'] = attn_mask inputs['continuous']['decoder_inputs_embeds'] = [shift_right_vec(inp, start_values=0.0) for inp in inputs['continuous']['inputs_embeds']] inputs['continuous']['decoder_attention_mask'] = [shift_right_vec(inp, start_values=0.0) for inp in inputs['continuous']['attention_mask']] #### Discrete data if self.tokenizer is not None: inputs['categorical'] = txt_inputs inputs['categorical']['label'] = txt_labels inputs['categorical']['raw'] = padded_batch_txt_tgt inputs['series_name'] = name_inputs inputs['series_name']['label'] = name_labels inputs['series_name']['raw'] = batch_name_targets inputs['node'] = {} inputs['node']['input_ids'] = torch.tensor(padded_batch_node_type, dtype=torch.int32) inputs['node']['attention_mask'] = torch.stack(padded_batch_node_mask, dim=0) #torch.tensor(padded_batch_node_mask, dtype=torch.long) #inputs['node']['decoder_input_ids'] = torch.stack([shift_tokens_right(inp, self.node_map['pad'], dim=-1) for inp in inputs['node']['input_ids']], dim=0) inputs['labels'] = {} inputs['labels']['col'] = inputs['node']['input_ids'][:,0,:] inputs['labels']['row'] = inputs['node']['input_ids'][:,:,0] return inputs