# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from ..constant import SCALE_DIMS, REG_DIMS, UNIQ_CHART_HEADS, CHART_TO_HEAD_IDX
from .encoder import Encoder
from .decoder import Decoder
from .helper import (
    unflat_tensor,
    get_chart_type_dict,
    get_chart_type_from_idx,
    make_repeat_frame
)


class BaseDataModel(nn.Module):

    def __init__(self, enc_conv, enc_proj1, vq_layer1, dec_conv, dec_proj_col,
                 dec_proj_row, dec_tf_col, dec_tf_row, enc_tf=None,
                 use_mhd=False, mhd_layer=None, enc_proj2=None, enc_proj3=None,
                 vq_layer2=None, emb_dim1=8, emb_len1=4, emb_len2=None,
                 max_series_blocks=5, max_cont_blocks=32,
                 conditional_encoder=False, conditional_decoder=True,
                 scale_mode='log10', scale_eps=1.00001, scale_floor=-1.0,
                 norm_mode='minmax', cont_loss_fn='l1', scale_loss_fn='l1',
                 use_pos_embs=True, hypothese_bsz=2048, hypothese_count=2048,
                 device='cpu', fp16=False, debug=False):
        super().__init__()

        if use_mhd:
            assert enc_proj2 is not None and vq_layer2 is not None and mhd_layer is not None

        self.debug = debug
        self.device = device
        self.use_fp16 = fp16
        self.dtype_float = torch.float32
        self.dtype_int = torch.int32
        self.use_mhd = use_mhd
        self.norm_mode = norm_mode

        self.encoder = Encoder(
            enc_conv=enc_conv,
            enc_tf=enc_tf,
            out_dim=emb_dim1,
            max_series_blocks=max_series_blocks,
            max_cont_blocks=max_cont_blocks,
            use_pos_embs=use_pos_embs,
            norm_mode=norm_mode,
            chart_type_conditional=conditional_encoder,
            device=device,
            debug=debug
        )
        self.enc_proj1 = enc_proj1
        self.vq_layer1 = vq_layer1

        if self.use_mhd:
            self.mhd_layer = mhd_layer
            self.vq_layer2 = vq_layer2
            self.up_sample_d = nn.ConvTranspose1d(emb_len2, emb_len1, 1)
            self.enc_proj2 = enc_proj2
            self.enc_proj3 = enc_proj3

        self.hypothese_bsz = hypothese_bsz
        self.hypothese_count = hypothese_count

        self.decoder = Decoder(
            dec_conv=dec_conv,
            dec_proj_col=dec_proj_col,
            dec_proj_row=dec_proj_row,
            dec_tf_col=dec_tf_col,
            dec_tf_row=dec_tf_row,
            emb_dim1=emb_dim1,
            cont_loss_fn=cont_loss_fn,
            scale_loss_fn=scale_loss_fn,
            max_series_blocks=max_series_blocks,
            max_cont_blocks=max_cont_blocks,
            use_mhd=use_mhd,
            hypothese_bsz=hypothese_bsz,
            scale_floor=scale_floor,
            norm_mode=norm_mode,
            device=device,
            debug=debug
        )

        # Chart type embeddings
        self.ct_emb = nn.Embedding(len(UNIQ_CHART_HEADS), emb_dim1)

        # 2. Continuous data -> Scale/Data settings (self.data_embed)
        self.max_series_blocks = max_series_blocks
        self.max_cont_blocks = max_cont_blocks
        self.conditional_decoder = conditional_decoder

        # Positional embeddings
        self.use_pos_embs = use_pos_embs
        self.pos_emb_scale = None

        self.scale_exponent = 10 if scale_mode == 'log10' else math.e
        self.scale_eps = scale_eps
        self.scale_floor = scale_floor
        self.emb_dim1 = emb_dim1
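
    # A small illustrative sketch (values are assumptions, not from this repo)
    # of how the chart-type conditioning above is looked up: `CHART_TO_HEAD_IDX`
    # maps a chart-type string to a head index, and `self.ct_emb` holds one row
    # per head, so a batch of chart types becomes a [batch, 1, emb_dim1] tensor:
    #
    #     ct_emb = nn.Embedding(len(UNIQ_CHART_HEADS), 8)   # emb_dim1 = 8
    #     ct_idx = torch.tensor([[0], [2]])                 # head indices, shape [B, 1]
    #     cond = ct_emb(ct_idx)                             # -> [B, 1, 8]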

    def get_chart_type_emb(self, ct_idx=None, chart_type=None):
        if ct_idx is None and chart_type is None:
            raise ValueError("Need to provide either indices or chart type strings")
        if chart_type is not None:
            ct_idx = [CHART_TO_HEAD_IDX[ct] for ct in chart_type]
        ct_idx = torch.tensor(ct_idx, dtype=torch.long, device=self.device).view(-1, 1)
        return self.ct_emb(ct_idx)

    def run_encoder(self, inputs):
        return self.encoder(inputs)

    def quantize(self, encoder_output):
        outputs = {}
        loss = {}
        y1 = self.enc_proj1(encoder_output)
        c_base, _, cb_loss1, _ = self.vq_layer1(y1)
        loss['cb1'] = cb_loss1
        if self.use_mhd:
            y2 = self.enc_proj2(encoder_output)
            y2 = self.enc_proj3(y2)
            c_i, _, cb_loss2, _ = self.vq_layer2(y2)
            loss['cb2'] = cb_loss2
            outputs['c_base'] = c_base
            outputs['secondary_latents'] = c_i
            outputs['latent_tgt'] = y1
            outputs['hypothese_count'] = self.hypothese_count
        else:
            outputs['y_hat'] = c_base
        return outputs, loss

    def get_topk_batch(self, topk_idx, x):
        x_batch = []
        for idx, _x in zip(topk_idx, x):
            x_batch.append(_x[idx])
        return torch.stack(x_batch, dim=0)

    def run_mhd_layer(self, c_base, hypothese_count, latent_tgt=None,
                      secondary_latents=None, split='train', keep_len_limit=128):
        # Reshape for the mhd layer
        if secondary_latents is None:
            inp_latents = c_base
        else:
            inp_latents = secondary_latents

        if latent_tgt is not None:
            bsz = self.hypothese_bsz
        else:
            bsz = hypothese_count

        num_batches = int(np.ceil(hypothese_count / bsz))
        keep_length = int(np.ceil(bsz / num_batches))

        repeat_frame = make_repeat_frame(c_base, hypo_bsz=bsz)
        expanded_c = c_base.unsqueeze(self.mhd_layer.hypothese_dim).repeat(repeat_frame)

        batch_hypos = []
        diff_vec_container = []
        for _ in range(num_batches):
            diff_vector = self.mhd_layer(inp_latents, hypothese_count=bsz)
            # Flatten along batch
            if self.up_sample_d is not None:
                diff_vector = torch.flatten(diff_vector, start_dim=0, end_dim=1)
                diff_vector = self.up_sample_d(diff_vector)
                diff_vector = unflat_tensor(diff_vector, bsz)
            y_hat = expanded_c + diff_vector

            if latent_tgt is not None:
                dist_loss, topk_idx = self.mhd_layer.get_dist_loss(y_hat, latent_tgt, topk=keep_length)
                # Only keep the top-k hypotheses based on the distance loss
                keep_idx = topk_idx[:, :keep_length]
                topk_hypos = self.get_topk_batch(keep_idx, y_hat)
                batch_hypos.append(topk_hypos)
                diff_vector = self.get_topk_batch(keep_idx, diff_vector)
                diff_vec_container.append(diff_vector)

        wta_loss = None
        mhd_dict = {}
        mhd_dict['bsz'] = bsz
        mhd_dict['keep_length'] = keep_length

        if latent_tgt is not None:
            batch_hypos = torch.cat(batch_hypos, dim=1)[:, :bsz, :]
            diff_vector = torch.cat(diff_vec_container, dim=1)[:, :bsz, :]
            # Recombine into y_hat (final list)
            y_hat = batch_hypos
            # Get the winner out of the batched winners
            dist_loss, topk_idx = self.mhd_layer.get_dist_loss(y_hat, latent_tgt, topk=keep_length)
            wta_loss, wta_idx = self.mhd_layer.get_wta_loss(dist_loss, topk_idx)
            mhd_dict['yw_hat'] = self.mhd_layer.get_win_hypo(y_hat, wta_idx)
            # mhd_dict['yk_hat'] = self.mhd_layer.get_topk_hypo(y_hat, topk_idx)
            mhd_dict['wta_idx'] = wta_idx
            mhd_dict['topk_idx'] = topk_idx
            wta_loss = wta_loss.mean()

        mhd_dict['c_base'] = expanded_c
        mhd_dict['diff_vector'] = diff_vector
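
    # A minimal sketch (illustrative only; concrete shapes and a hypothesis
    # dimension of 1 are assumptions) of the hypothesis expansion performed in
    # `run_mhd_layer` above: the base latent [B, L, D] is repeated along a new
    # hypothesis dimension, K difference vectors are added, and `get_topk_batch`
    # gathers the per-sample hypotheses selected by the distance loss:
    #
    #     c_base = torch.randn(2, 4, 8)                        # [B, L, D]
    #     expanded = c_base.unsqueeze(1).repeat(1, 16, 1, 1)   # [B, K, L, D]
    #     y_hat = expanded + torch.randn(2, 16, 4, 8)          # K hypotheses per sample
    #     keep_idx = torch.randint(0, 16, (2, 8))              # [B, k] indices (e.g. top-k)
    #     kept = torch.stack([y[i] for i, y in zip(keep_idx, y_hat)], dim=0)
    #     # kept: [B, k, L, D]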
        mhd_dict['y_hat'] = y_hat
        return mhd_dict, wta_loss

    def run_decoder(self, y_hat, chart_type_dict, cond=None, labels=None, wta_idx=None):
        # "decoder_inputs_embeds" are the right-shifted inputs.
        # See dataset.continuous.py (lines 599)
        scale_embd, scale_mask, s_loss = self.encoder.preencode_scale(
            inputs_embeds=labels['chart_data']['scale']['decoder_inputs_embeds'],
            attention_mask=labels['chart_data']['scale']['decoder_attention_mask'],
            chart_type_dict=chart_type_dict,
            pad_len=self.max_series_blocks,
            pad_dim=1)
        cont_embd, cont_mask, c_loss = self.encoder.preencode_continuous(
            inputs_embeds=labels['chart_data']['continuous']['decoder_inputs_embeds'],
            attention_mask=labels['chart_data']['continuous']['decoder_attention_mask'],
            chart_type_dict=chart_type_dict,
            pad_len=self.max_cont_blocks,
            pad_dim=2,
            flatten_series=False)
        loss = s_loss + c_loss

        decoder_output, dec_loss, logs = self.decoder(
            y_hat, chart_type_dict, cond, labels, wta_idx,
            scale_embd, scale_mask, cont_embd, cont_mask
        )
        dec_loss['na2'] = loss
        return decoder_output, dec_loss, logs

    def reconstruct_tab_shape(self, row_logits, col_logits, temp=1.0, eval=False):
        row_probs = F.softmax(row_logits / temp, dim=-1)
        col_probs = F.softmax(col_logits / temp, dim=-1)
        row_sample = Categorical(probs=row_probs).sample()
        col_sample = Categorical(probs=col_probs).sample()
        bsz = row_sample.size(0)

        # Force first series to always be active for stability
        row_sample[:, 0] = 1
        col_sample[:, 0] = 1

        # Count from left until 0 is reached.
        row_count = torch.zeros(bsz, device=self.device, dtype=torch.long)
        row_stop = torch.ones(bsz, device=self.device, dtype=torch.long)
        col_count = torch.zeros(bsz, device=self.device, dtype=torch.long)
        col_stop = torch.ones(bsz, device=self.device, dtype=torch.long)
        search_len = max(row_sample.size(1), col_sample.size(1))
        for idx in range(search_len):
            if idx < row_sample.size(1):
                row = row_sample[:, idx]
                row_stop *= row
                row_count += (row_stop * row)
            if idx < col_sample.size(1):
                col = col_sample[:, idx]
                col_stop *= col
                col_count += (col_stop * col)
            if (row_stop.sum(-1) + col_stop.sum(-1)) == 0:
                break

        # Guarantee at least one row and one column
        row_count = torch.where(row_count == 0, torch.ones_like(row_count), row_count)
        col_count = torch.where(col_count == 0, torch.ones_like(col_count), col_count)

        samples = {'row': row_sample, 'col': col_sample}
        tab_shape = {'row': row_count, 'col': col_count}
        return tab_shape, samples

    def eval_loop(self, inputs, temp=1.0, split='eval', return_labels=True):
        assert temp > 0.0 and temp <= 1.0, "invalid temp given"
        samples, loss_dict, metric_logs = self.train_loop(inputs)
        return samples, loss_dict, metric_logs

    def batch_outputs(self, outputs, to_np=True):
        # Regroups outputs per batch element for easy indexing and converts
        # tensors to numpy arrays (or plain lists when to_np is False).
        # batched_outputs = [{} for _ in range(bsz)]  # [{}] * bsz
        batched_outputs = []
        for k, v in outputs.items():
            for bidx, vi in enumerate(v):
                if isinstance(vi, torch.Tensor):
                    vi = vi.detach().cpu()
                    vi = vi.numpy() if to_np else vi.tolist()
                if len(batched_outputs) <= bidx:
                    batched_outputs.append({})
                batched_outputs[bidx][k] = vi
        return batched_outputs

    def sample_codebook(self, inputs):
        encoder_output, _ = self.run_encoder(inputs)
        y1 = self.enc_proj1(encoder_output)
        cb_ind1, _ = self.vq_layer1.get_code_indices(y1)
        cb_ind1 = cb_ind1.view(y1.size(0), -1)
        if not self.use_mhd:
            return (cb_ind1, )
        else:
            y2 = self.enc_proj2(encoder_output)
            y2 = self.enc_proj3(y2)
            cb_ind2, _ = self.vq_layer2.get_code_indices(y2)
            cb_ind2 = cb_ind2.view(y2.size(0), -1)
            return (cb_ind1, cb_ind2, )
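
    # A hedged usage sketch (assumes a constructed `model` and an `inputs`
    # batch shaped like the training data): codebook indices sampled above can
    # be fed back through `reconstruct_from_indices` below to regenerate chart
    # data, optionally conditioned on chart-type indices.
    #
    #     indices = model.sample_codebook(inputs)      # (cb_ind1,) or (cb_ind1, cb_ind2)
    #     ct_idx = [CHART_TO_HEAD_IDX[ct] for ct in inputs['chart_data']['chart_type']]
    #     x_hat = model.reconstruct_from_indices(ct_idx, *indices)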

    def reconstruct_from_indices(self, ct_idx, cb_ind1, cb_ind2=None,
                                 hypo_count=4, hypo_bsz=4, temp=1.0, greedy=True):
        # Obtain embeddings from the vq codebooks
        c_base = self.vq_layer1(cb_ind1, is_indices=True)
        c_i = None
        if self.use_mhd:
            c_i = self.vq_layer2(cb_ind2, is_indices=True)
            decoder_input, _ = self.run_mhd_layer(
                c_base=c_base,
                secondary_latents=c_i,
                hypothese_count=hypo_count,
            )
            y_hat = decoder_input['y_hat']
            y_hat = y_hat[:, 0]
        else:
            y_hat = c_base

        # Obtain chart type embedding
        cond = None
        if self.conditional_decoder and ct_idx is not None:
            cond = self.get_chart_type_emb(ct_idx=ct_idx)
        if len(y_hat.shape) == 4 and self.use_mhd and cond is not None and hypo_bsz > 1:
            repeat_frame = make_repeat_frame(cond, hypo_bsz=hypo_bsz)
            cond = cond.unsqueeze(self.mhd_layer.hypothese_dim).repeat(repeat_frame)
            cond = torch.flatten(cond, start_dim=0, end_dim=1)

        # Prepend chart type embedding
        if cond is not None:
            y_hat = torch.cat([cond, y_hat], dim=1)

        # For decoder heads
        chart_type_dict = get_chart_type_from_idx(ct_idx)

        # Split y_hat into two paths: the scale path (row) and the continuous path (col)
        dec_hidden_col, dec_hidden_row = self.decoder.decode_y_hat(y_hat)
        scale_preds, hidden_row = self.generate_scale(dec_hidden_row, chart_type_dict)
        cont_preds, hidden_col = self.generate_continuous(dec_hidden_col, chart_type_dict)
        col_logit, row_logit, _, _ = self.decoder.decode_tab_shape(hidden_col, hidden_row)

        x_hat = self.reconstruct_x(row_logit, col_logit, scale_preds, cont_preds, eval=True)
        x_hat['ct_idx'] = ct_idx
        x_hat['chart_type_dict'] = chart_type_dict
        return x_hat

    def train_loop(self, inputs):
        encoder_output, loss = self.run_encoder(inputs)
        decoder_input, cb_loss = self.quantize(encoder_output)
        all_loss = cb_loss
        all_loss['na1'] = loss

        wta_idx = None
        if self.use_mhd:
            decoder_input, wta_loss = self.run_mhd_layer(**decoder_input)
            wta_idx = decoder_input.get('wta_idx')
            all_loss['wta'] = wta_loss

        ct_cond = None
        chart_type = inputs['chart_data']['chart_type']
        if self.conditional_decoder:
            ct_cond = self.get_chart_type_emb(chart_type=chart_type)

        y_hat = decoder_input['y_hat']
        if wta_idx is not None:
            y_hat = self.mhd_layer.get_win_hypo(y_hat, wta_idx)

        chart_type_dict = get_chart_type_dict(chart_type)
        decoder_output, dec_loss, logs = self.run_decoder(
            y_hat=y_hat,
            chart_type_dict=chart_type_dict,
            cond=ct_cond,
            labels=inputs,
            wta_idx=None
        )

        # Prepare loss and logs for outputs
        all_loss = {**all_loss, **dec_loss}
        row_logit = decoder_output['row']
        col_logit = decoder_output['col']
        scale_preds = decoder_output['scale']
        cont_preds = decoder_output['continuous']

        # Convert to input shape
        x_hat = self.reconstruct_x(row_logit, col_logit, scale_preds, cont_preds, eval=False)
        x_hat['chart_type'] = inputs['chart_data']['chart_type']
        x_hat['ct_idx'] = [CHART_TO_HEAD_IDX[ct] for ct in inputs['chart_data']['chart_type']]
        x_hat['chart_type_dict'] = chart_type_dict
        return x_hat, all_loss, logs

    def reconstruct_x(self, row_logit, col_logit, scale_preds, cont_preds, eval=False):
        tab_shape, tab_samples = self.reconstruct_tab_shape(row_logit, col_logit, eval=eval)

        # Convert predictions into values (unscaled)
        # scale_values = self.unscale_scales(scale_preds, tab_shape)
        # cont_values = self.unscale_continuous(cont_preds, scale_values, tab_shape)
        scale_embeds, scale_mask = self.embed_scale(scale_preds, tab_shape, tab_samples)
        cont_embeds, cont_mask = self.embed_continuous(cont_preds, tab_shape, tab_samples)

        samples = {}
        samples['shape'] = {}
        samples['shape']['counts'] = tab_shape
        samples['shape']['embeds'] = tab_samples
        samples['scale'] = {}
        # samples['scale']['output'] = scale_values
        samples['scale']['inputs_embeds'] = scale_embeds
        samples['scale']['attention_mask'] = scale_mask
        samples['continuous'] = {}
        # samples['continuous']['output'] = cont_values
        samples['continuous']['inputs_embeds'] = cont_embeds
        samples['continuous']['attention_mask'] = cont_mask
        return samples
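
    # For reference, `reconstruct_x` returns a nested dict of the form
    # (keys taken from the assignments above):
    #
    #     {
    #         'shape':      {'counts': ..., 'embeds': ...},
    #         'scale':      {'inputs_embeds': ..., 'attention_mask': ...},
    #         'continuous': {'inputs_embeds': ..., 'attention_mask': ...},
    #     }
    #
    # The calling loops then attach chart-type metadata ('ct_idx',
    # 'chart_type_dict', and in the training loop 'chart_type') before
    # returning it.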

    def forward(self, inputs, is_train=True, temp=1.0, **kwargs):
        if is_train:
            return self.train_loop(inputs)
        else:
            return self.eval_loop(inputs, temp)

    def generate_scale(self, dec_hidden_row, chart_type_dict):
        bsz = dec_hidden_row.size(0)
        scale_embd = None
        scale_preds = [[] for _ in range(bsz)]

        # Create starting vector
        for head_name, ind in chart_type_dict.items():
            head_dim = SCALE_DIMS[self.norm_mode][head_name]
            zero_vec = torch.zeros([1, head_dim], device=self.device, dtype=self.dtype_float)
            for bidx in ind:
                scale_preds[bidx].append(zero_vec.clone().detach())

        for sidx in range(self.max_series_blocks):
            # Expand scale embedding by one position
            new_embd = torch.zeros([bsz, 1, self.emb_dim1], device=self.device, dtype=self.dtype_float)
            scale_embd = torch.cat([scale_embd, new_embd], dim=1) if scale_embd is not None else new_embd

            # Encode scale_pred and assign new values to embedding
            for head_name, ind in chart_type_dict.items():
                scale_x = torch.stack([v[-1] for bidx, v in enumerate(scale_preds) if bidx in ind], dim=0)
                scale_enc = self.encoder.cont_encoder['scale'][head_name](scale_x)
                scale_embd[ind, -1:, :] = scale_enc

            # Add positional embeddings
            pos_emb = self.encoder.pos_emb_scale[:, sidx:sidx + 1, :] if self.encoder.pos_emb_scale is not None else 0.0
            scale_embd[:, -1:, :] += pos_emb

            # Generate rows starting with zeros
            hidden_state = self.decoder.dec_tf_row(
                inputs_embeds=scale_embd,
                encoder_hidden_states=dec_hidden_row
            ).last_hidden_state
            scale_pred, _, _ = self.decoder.decode_scale(hidden_state, chart_type_dict)

            # Collect only the last prediction for each instance
            for bidx, p in enumerate(scale_pred):
                scale_preds[bidx].append(p[-1:, :])

        # Stack each column
        for bidx in range(bsz):
            scale_preds[bidx] = torch.cat(scale_preds[bidx], dim=0)

        return scale_pred, hidden_state
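
    # Both generation routines follow the same autoregressive pattern, sketched
    # here with hypothetical names (`encode_step`, `decode_step`, `predict_head`
    # stand in for the encoder/decoder heads used above): start from a zero
    # vector, encode the latest prediction into the growing embedding sequence,
    # run the decoder transformer over the sequence, and keep only the newest
    # prediction.
    #
    #     pred = torch.zeros(bsz, 1, head_dim)          # zero starting vector
    #     embeds = None
    #     for step in range(max_blocks):
    #         emb = encode_step(pred)                   # [bsz, 1, emb_dim]
    #         embeds = emb if embeds is None else torch.cat([embeds, emb], dim=1)
    #         hidden = decode_step(embeds)              # [bsz, step + 1, emb_dim]
    #         pred = predict_head(hidden[:, -1:])       # keep only the newest step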

    def generate_continuous(self, dec_hidden_col, chart_type_dict):
        bsz = dec_hidden_col.size(0)
        hidden_col = []
        cont_embd = torch.zeros([bsz, self.max_series_blocks, self.max_cont_blocks, self.emb_dim1],
                                device=self.device, dtype=self.dtype_float)
        cont_preds = [[] for _ in range(bsz)]

        for row_idx in range(self.max_series_blocks):
            cont_col_preds = [[] for _ in range(bsz)]

            # Create starting vector
            for head_name, ind in chart_type_dict.items():
                head_dim = REG_DIMS[head_name]
                zero_vec = torch.zeros([1, head_dim], device=self.device, dtype=self.dtype_float)
                for bidx in ind:
                    cont_col_preds[bidx].append(zero_vec.clone().detach())

            for bidx in range(self.max_cont_blocks):
                # Encode vector
                for head_name, ind in chart_type_dict.items():
                    cont_x = torch.stack([l[-1] for idx, l in enumerate(cont_col_preds) if idx in ind], dim=0).to(self.device)
                    emb = self.encoder.cont_encoder['continuous'][head_name](cont_x)
                    cont_embd[ind, row_idx:row_idx + 1, bidx:bidx + 1, :] = emb.unsqueeze(1)

                # Add positional embeddings to cont_embd
                pos_emb = self.encoder.pos_emb_cont[:, row_idx:row_idx + 1, bidx:bidx + 1, :] if self.encoder.pos_emb_cont is not None else 0.0
                cont_embd[:, row_idx:row_idx + 1, bidx:bidx + 1, :] += pos_emb

                # Create hidden state using blocks up to and including the current one
                h = self.decoder.dec_tf_col(
                    inputs_embeds=cont_embd[:, row_idx, :bidx + 1, :],
                    encoder_hidden_states=dec_hidden_col
                ).last_hidden_state

                for head_name, ind in chart_type_dict.items():
                    # Obtain predictions using only the last hidden state
                    cont_logits = self.decoder.cont_decoder['continuous'][head_name](h[ind, -1:, :])
                    cont_pred = torch.sigmoid(cont_logits)
                    for idx, pred in zip(ind, cont_pred):
                        cont_col_preds[idx].append(pred)

            # Append only the final hidden state of this series
            hidden_col += [h]

            # Stack predictions
            for idx, preds in enumerate(cont_col_preds):
                p = torch.stack(preds, dim=1)
                cont_preds[idx].append(p)

        for bidx in range(bsz):
            cont_preds[bidx] = torch.cat(cont_preds[bidx], dim=0)
        hidden_col = torch.stack(hidden_col, dim=1)
        return cont_preds, hidden_col
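
# A minimal, self-contained sketch (illustrative only, not used by the model)
# of the left-to-right counting in `reconstruct_tab_shape`: the running `stop`
# mask drops to 0 after the first sampled 0, so `count` ends up being the
# length of the leading run of 1s in each sampled row/column mask.
#
#     sample = torch.tensor([[1, 1, 0, 1],
#                            [1, 0, 1, 1]])   # [bsz, blocks]
#     count = torch.zeros(2, dtype=torch.long)
#     stop = torch.ones(2, dtype=torch.long)
#     for idx in range(sample.size(1)):
#         col = sample[:, idx]
#         stop *= col                          # 0 from the first 0 onwards
#         count += stop * col                  # only counts while stop is 1
#     print(count.tolist())                    # -> [2, 1]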