# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

from .helper import (
    make_repeat_frame,
    unflat_tensor,
    get_win_hypo
)
from utils import PAD_IDX
from ..constant import SCALE_DIMS, REG_DIMS
from .coder import Coder


class Decoder(Coder):
    def __init__(self,
                 dec_conv,
                 dec_proj_col,
                 dec_proj_row,
                 dec_tf_col,
                 dec_tf_row,
                 emb_dim1,
                 cont_loss_fn,
                 scale_loss_fn,
                 max_series_blocks,
                 max_cont_blocks,
                 norm_mode,
                 scale_floor,
                 use_mhd,
                 hypothese_bsz,
                 hypothese_dim=1,
                 device='cuda:0',
                 debug=False,
                 ):
        super().__init__()
        self.dec_tf_col = dec_tf_col
        self.dec_tf_row = dec_tf_row
        self.dec_conv = dec_conv
        self.dec_proj_col = dec_proj_col
        self.dec_proj_row = dec_proj_row
        self.dec_col_head = nn.Linear(emb_dim1, 2)
        self.dec_row_head = nn.Linear(emb_dim1, 2)

        # Continuous embeddings
        self.cont_decoder = nn.ModuleDict()
        for name in ['scale', 'continuous']:
            self.cont_decoder[name] = nn.ModuleDict()
        for head_name, head_dim in SCALE_DIMS[norm_mode].items():
            self.cont_decoder['scale'][head_name] = nn.ModuleList(
                [nn.Linear(emb_dim1, head_dim) for _ in range(1)])
        for head_name, head_dim in REG_DIMS.items():
            self.cont_decoder['continuous'][head_name] = nn.Linear(emb_dim1, head_dim)

        self.emb_dim1 = emb_dim1
        self.use_mhd = use_mhd
        self.hypothese_bsz = hypothese_bsz
        self.hypothese_dim = hypothese_dim
        self.max_cont_blocks = max_cont_blocks
        self.max_series_blocks = max_series_blocks
        self.scale_floor = scale_floor
        self.device = device
        self.debug = debug
        self.dtype_float = torch.float32

        self.cont_loss_fn = F.smooth_l1_loss if cont_loss_fn == 'l1' else F.mse_loss
        self.scale_loss_fn = F.smooth_l1_loss if scale_loss_fn == 'l1' else F.mse_loss
        self.ce_loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX, reduction='none')
        self.apply(self._init_weights)

    def forward(self, y_hat, chart_type_dict, cond, labels, wta_idx,
                scale_embd, scale_mask, cont_embd, cont_mask):
        if len(y_hat.shape) == 4 and self.use_mhd and wta_idx is not None:
            if cond is not None:
                repeat_frame = make_repeat_frame(cond, hypo_bsz=self.hypothese_bsz)
                cond = cond.unsqueeze(self.hypothese_dim).repeat(repeat_frame)
                cond = torch.flatten(cond, start_dim=0, end_dim=1)
            y_hat = torch.flatten(y_hat, start_dim=0, end_dim=1)

        if cond is not None:
            y_hat = torch.cat([cond, y_hat], dim=1)

        dec_hidden_col, dec_hidden_row = self.decode_y_hat(y_hat)

        if self.use_mhd and wta_idx is not None:
            repeat_frame = make_repeat_frame(scale_embd, hypo_bsz=self.hypothese_bsz)
            scale_embd = scale_embd.unsqueeze(self.hypothese_dim).repeat(repeat_frame)
            scale_embd = torch.flatten(scale_embd, start_dim=0, end_dim=1)

            repeat_frame = make_repeat_frame(scale_mask, hypo_bsz=self.hypothese_bsz)
            scale_mask = scale_mask.unsqueeze(self.hypothese_dim).repeat(repeat_frame)
            scale_mask = torch.flatten(scale_mask, start_dim=0, end_dim=1)

            repeat_frame = make_repeat_frame(cont_embd, hypo_bsz=self.hypothese_bsz)
            cont_embd = cont_embd.unsqueeze(self.hypothese_dim).repeat(repeat_frame)
            cont_embd = torch.flatten(cont_embd, start_dim=0, end_dim=1)

            repeat_frame = make_repeat_frame(cont_mask, hypo_bsz=self.hypothese_bsz)
            cont_mask = cont_mask.unsqueeze(self.hypothese_dim).repeat(repeat_frame)
            cont_mask = torch.flatten(cont_mask, start_dim=0, end_dim=1)

        hidden_row = self.dec_tf_row(
            inputs_embeds=scale_embd,
            attention_mask=scale_mask,
            encoder_hidden_states=dec_hidden_row
        ).last_hidden_state

        row_len = cont_embd.size(1)
        hidden_col = []
        for ridx in range(row_len):
            row_cont_embd = cont_embd[:, ridx, :]
            row_cont_mask = cont_mask[:, ridx, :]
            h = self.dec_tf_col(
                inputs_embeds=row_cont_embd,
                attention_mask=row_cont_mask,
                encoder_hidden_states=dec_hidden_col
            ).last_hidden_state
            hidden_col += [h]
        hidden_col = torch.stack(hidden_col, dim=1)

        col_logit, row_logit, tab_loss, tab_logs = self.decode_tab_shape(
            hidden_col, hidden_row, labels,
            wta_idx=wta_idx, chart_type_dict=chart_type_dict)
        loss = {**tab_loss}
        logs = {**tab_logs}  # Logs are split by chart data. This is used for metrics.
        outputs = {}
        outputs['row'] = row_logit
        outputs['col'] = col_logit

        scale_pred, scale_loss, scale_logs = self.decode_scale(
            hidden_row, chart_type_dict, labels, wta_idx=wta_idx)
        outputs['scale'] = scale_pred
        loss['scale'] = scale_loss
        logs['scale'] = scale_logs

        cont_pred, cont_loss, cont_logs = self.decode_continuous(
            hidden_col, hidden_row, chart_type_dict, labels, wta_idx=wta_idx)
        outputs['continuous'] = cont_pred
        loss['cont'] = cont_loss
        logs['cont'] = cont_logs

        if len(y_hat.shape) == 4 and self.use_mhd and wta_idx is not None:
            keys = list(outputs.keys())
            for k in keys:
                outputs[k] = unflat_tensor(outputs[k], self.hypothese_bsz)

        return outputs, loss, logs
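    # Shape walk-through for the multi-hypothesis (MHD) path in forward(),
    # with illustrative sizes only (B = batch size, H = hypothese_bsz; these
    # names are not part of the API):
    #   y_hat          (B, H, T, D) --flatten(0, 1)-->             (B*H, T, D)
    #   cond / embeds  (B, ...)     --repeat over hypothese_dim--> (B*H, ...)
    # Every hypothesis therefore shares the same conditioning and inputs, and
    # unflat_tensor is used to restore the (B, H, ...) layout on the outputs.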
    def decode_y_hat(self, y_hat):
        h = self.dec_conv(y_hat)
        dec_hidden_col = self.dec_proj_col(h)
        dec_hidden_row = self.dec_proj_row(h)
        return dec_hidden_col, dec_hidden_row

    def decode_tab_shape(self, hidden_col, hidden_row, labels=None,
                         wta_idx=None, chart_type_dict=None):
        '''
        Predict the number of rows and columns
        '''
        # If the column hidden states are not flattened (still 4D), average
        # over the row dimension.
        if len(hidden_col.shape) == 4:
            hidden_col = hidden_col.mean(1)
        col_logit = self.dec_col_head(hidden_col)
        row_logit = self.dec_row_head(hidden_row)

        loss_dict = {}
        logs = {'row': {}, 'col': {}}
        if labels is not None:
            if self.use_mhd and wta_idx is not None:
                unflat_col_logit = unflat_tensor(col_logit, self.hypothese_bsz)
                unflat_row_logit = unflat_tensor(row_logit, self.hypothese_bsz)
                col_logit = get_win_hypo(unflat_col_logit, wta_idx)
                row_logit = get_win_hypo(unflat_row_logit, wta_idx)

            labels = labels['chart_data'].get('labels')
            col_label = labels['col'].to(self.device)
            row_label = labels['row'].to(self.device)
            bsz = col_logit.size(0)
            pad_col_len = col_logit.size(1) - col_label.size(1)
            pad_row_len = row_logit.size(1) - row_label.size(1)
            assert pad_col_len >= 0 and pad_row_len >= 0, \
                "Column or row label exceeds max size. Limit: C={} R={} Labels: C={} R={}".format(
                    col_logit.size(1), row_logit.size(1),
                    labels['col'].size(1), labels['row'].size(1))
            pad_col = torch.ones((bsz, pad_col_len), device=self.device, dtype=torch.long) * PAD_IDX
            pad_row = torch.ones((bsz, pad_row_len), device=self.device, dtype=torch.long) * PAD_IDX
            col_label = torch.cat([col_label, pad_col], dim=1)
            row_label = torch.cat([row_label, pad_row], dim=1)

            col_label_flat = torch.flatten(col_label, start_dim=0, end_dim=1)
            row_label_flat = torch.flatten(row_label, start_dim=0, end_dim=1)
            col_logit_flat = torch.flatten(col_logit, start_dim=0, end_dim=1)
            row_logit_flat = torch.flatten(row_logit, start_dim=0, end_dim=1)
            col_nll = self.ce_loss_fn(col_logit_flat, col_label_flat)
            row_nll = self.ce_loss_fn(row_logit_flat, row_label_flat)

            col_counts = torch.where(col_label == PAD_IDX,
                                     torch.tensor(0.0, device=self.device, dtype=self.dtype_float),
                                     torch.tensor(1.0, device=self.device, dtype=self.dtype_float))
            row_counts = torch.where(row_label == PAD_IDX,
                                     torch.tensor(0.0, device=self.device, dtype=self.dtype_float),
                                     torch.tensor(1.0, device=self.device, dtype=self.dtype_float))
            loss_dict['col'] = col_nll.sum() / col_counts.sum()
            loss_dict['row'] = row_nll.sum() / row_counts.sum()

            if chart_type_dict is not None:
                # Metrics remove the zeros.
                col_nll_unflat = col_nll.view(bsz, -1).sum(-1)
                row_nll_unflat = row_nll.view(bsz, -1).sum(-1)
                col_count_unflat = col_counts.view(bsz, -1).sum(-1)
                row_count_unflat = row_counts.view(bsz, -1).sum(-1)
                col_nll_unflat = torch.nan_to_num(col_nll_unflat / col_count_unflat, nan=0.0)
                row_nll_unflat = torch.nan_to_num(row_nll_unflat / row_count_unflat, nan=0.0)
                for chart_type, indices in chart_type_dict.items():
                    logs['col'][chart_type] = col_nll_unflat[indices].detach().cpu()
                    logs['row'][chart_type] = row_nll_unflat[indices].detach().cpu()

        return col_logit, row_logit, loss_dict, logs
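    # Note on the normalisation in decode_tab_shape(): ce_loss_fn is built
    # with ignore_index=PAD_IDX and reduction='none', so PAD positions already
    # contribute zero loss; col_counts/row_counts count the non-PAD positions
    # so that the summed losses are averaged over real positions only.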
    def decode_continuous(self, hidden_col, hidden_row, chart_type_dict,
                          labels=None, wta_idx=None):
        '''
        Predict continuous values for each row and column.
        Prepend each column prediction with the particular row.
        '''
        is_unflat = len(hidden_col.shape) == 4
        if self.use_mhd and wta_idx is not None:
            hidden_row = unflat_tensor(hidden_row, self.hypothese_bsz)
            hidden_col = unflat_tensor(hidden_col, self.hypothese_bsz)

        total_loss = torch.tensor(0.0, device=self.device, dtype=self.dtype_float)
        all_predictions = [[] for _ in range(hidden_col.size(0))]
        max_rows = hidden_row.size(1)
        if labels is not None:
            cd_label = labels['chart_data']
            max_rows = max(emb.size(0) for emb in cd_label['continuous']['inputs_embeds'])

        # Ensure all heads are used for torchrun.
        for head_name, head in self.cont_decoder['continuous'].items():
            if head_name not in chart_type_dict:
                total_loss += (head(hidden_col) * 0).sum()

        logs = {}
        # Loop through each head and create an embedding based on the number
        # of continuous variables and scale variables.
        for head_name, ind in chart_type_dict.items():
            logs[head_name] = []
            for row_idx in range(max_rows):
                if is_unflat:
                    hid_col = hidden_col[ind, row_idx, :]
                else:
                    hid_col = hidden_col[ind, :]
                if self.use_mhd and wta_idx is not None:
                    hid_col = torch.flatten(hid_col, start_dim=0, end_dim=1)
                cont_logits = self.cont_decoder['continuous'][head_name](hid_col)
                cont_pred = torch.sigmoid(cont_logits)

                if labels is not None:
                    if self.use_mhd and wta_idx is not None:
                        unflat_cont_logit = unflat_tensor(cont_pred, self.hypothese_bsz)
                        cont_pred = get_win_hypo(unflat_cont_logit, wta_idx[ind])

                    label = torch.stack(
                        [l[row_idx, :] for idx, l in
                         enumerate(cd_label['continuous']['inputs_embeds'])
                         if idx in ind], dim=0).to(self.device)
                    mask = torch.stack(
                        [l[row_idx, :] for idx, l in
                         enumerate(cd_label['continuous']['attention_mask'])
                         if idx in ind], dim=0).to(self.device)

                    # Pad labels to match the max length.
                    pad_label_len = cont_pred.size(1) - label.size(1)
                    assert pad_label_len >= 0, \
                        "Targets contain length={} larger than max={} | Shapes label={} pred={}".format(
                            label.size(1), cont_pred.size(1), label.shape, cont_pred.shape)
                    pad_label = torch.zeros((label.size(0), pad_label_len, label.size(-1)),
                                            device=self.device, dtype=self.dtype_float)
                    pad_mask = torch.zeros((label.size(0), pad_label_len),
                                           device=self.device, dtype=self.dtype_float)
                    label = torch.cat([label, pad_label], dim=1)
                    mask = torch.cat([mask, pad_mask], dim=1)

                    loss = self.cont_loss_fn(cont_pred, label, reduction='none').mean(-1)
                    rec_error = F.mse_loss(cont_pred.detach(), label, reduction='none').mean(-1).detach()
                    # Reduce this way to ignore the zeros; average across the batch.
                    counts = mask.sum(-1)
                    # Protect against divide by zero.
                    counts = torch.where(counts == 0,
                                         torch.tensor(0.0, device=self.device, dtype=self.dtype_float),
                                         1 / counts)
                    loss = (loss * mask * counts.unsqueeze(-1)).sum(-1)
                    rec_error = (rec_error * mask * counts.unsqueeze(-1)).sum(-1)
                    c = counts.nonzero()
                    logs[head_name].append(rec_error[c])
                    total_loss += loss.mean(-1)

                for pidx, pred in zip(ind, cont_pred):
                    all_predictions[pidx].append(pred)

            if labels is not None:
                logs[head_name] = torch.cat(logs[head_name], dim=0).view(-1).detach().cpu()

        # Stack all predictions; each chart type has a different dimension.
        for pidx in range(len(all_predictions)):
            p = all_predictions[pidx]
            assert len(p) > 0, "One of the continuous predictions was not recorded."
            all_predictions[pidx] = torch.stack(p, dim=0)

        return all_predictions, total_loss, logs
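    # Note on the `(head(hidden) * 0).sum()` terms in decode_continuous() and
    # decode_scale(): under torchrun/DDP every registered parameter must
    # participate in the backward pass, so heads whose chart type is absent
    # from the current batch get a zero-weighted forward pass instead of
    # being skipped outright.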
    def decode_scale(self, hidden_row, chart_type_dict, labels=None,
                     wta_idx=None, greedy=False):
        if labels is not None:
            cd_label = labels['chart_data']
        if self.use_mhd and wta_idx is not None:
            hidden_row = unflat_tensor(hidden_row, self.hypothese_bsz)

        logs = {}
        total_loss = torch.tensor(0.0, device=self.device, dtype=self.dtype_float)
        bsz = hidden_row.size(0)
        all_predictions = [None for _ in range(bsz)]

        # Ensure all heads are used for torchrun.
        for head_name, head in self.cont_decoder['scale'].items():
            if head_name not in chart_type_dict:
                total_loss += (head[0](hidden_row) * 0).sum()

        for head_name, ind in chart_type_dict.items():
            hid_row = hidden_row[ind, :, :]
            if self.use_mhd and wta_idx is not None:
                hid_row = torch.flatten(hid_row, start_dim=0, end_dim=1)
            scale_logit = self.cont_decoder['scale'][head_name][0](hid_row)
            scale_logit = self.dim_wise_relu(scale_logit, head_name=head_name,
                                             decoder_name='scale')

            if labels is not None:
                # Compute the loss: collect labels and mask.
                label = torch.stack(
                    [l for idx, l in enumerate(cd_label['scale']['inputs_embeds'])
                     if idx in ind], dim=0).to(self.device)
                mask = torch.stack(
                    [l for idx, l in enumerate(cd_label['scale']['attention_mask'])
                     if idx in ind], dim=0).to(self.device)

                pad_label_len = scale_logit.size(1) - label.size(1)
                assert pad_label_len >= 0, \
                    "Targets contain length={} larger than max={} | tensor shape = label: {} logits: {}".format(
                        label.size(1), scale_logit.size(1), label.shape, scale_logit.shape)
                pad_label = torch.zeros((label.size(0), pad_label_len, label.size(-1)),
                                        device=self.device, dtype=self.dtype_float)
                pad_mask = torch.zeros((label.size(0), pad_label_len),
                                       device=self.device, dtype=self.dtype_float)
                label = torch.cat([label, pad_label], dim=1)
                mask = torch.cat([mask, pad_mask], dim=1)
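    # In MHD mode the heads above score every hypothesis (the hidden states
    # arrive flattened to batch * hypothese_bsz); get_win_hypo then keeps only
    # the winning hypothesis selected by wta_idx (winner-takes-all training),
    # so losses and logs are computed on a single hypothesis per sample.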
                counts = mask.sum(-1)
                # Prevent division by zero.
                counts = torch.where(counts == 0,
                                     torch.tensor(0.0, device=self.device, dtype=self.dtype_float),
                                     1 / counts)

                if self.use_mhd and wta_idx is not None:
                    unflat_scale_logit = unflat_tensor(scale_logit, self.hypothese_bsz)
                    scale_logit = get_win_hypo(unflat_scale_logit, wta_idx[ind])

                # Compute the distance loss.
                dist_loss = self.scale_loss_fn(scale_logit, label, reduction='none')
                dist_loss = torch.clip(dist_loss, min=-1e2, max=1e2)
                dist_loss = dist_loss.mean(-1) * mask
                loss = (dist_loss.sum(-1) * counts).mean(-1)
                total_loss += loss.mean()
                logs[head_name] = loss.detach().cpu()

            # Cannot stack: the last dimension can vary by chart type.
            for idx, pred in zip(ind, scale_logit):
                all_predictions[idx] = pred

        assert all(p is not None for p in all_predictions), \
            "One of the scale predictions was not recorded."
        return all_predictions, total_loss, logs

    def dim_wise_relu(self, x, head_name=None, decoder_name=None):
        dims = x.size(-1)
        if head_name == 'boxplot' and decoder_name == 'continuous':
            # Every dimension after the first must be non-negative.
            increment_dims = [i for i in range(dims) if i > 0]
            x[:, :, increment_dims] = torch.relu(x[:, :, increment_dims])
        elif decoder_name == 'scale':
            assert len(x.shape) == 3, "unsupported shape"
            # Even dims hold scale minimums, odd dims hold scale ranges; the
            # shifted ReLU relu(x + f) - f lower-bounds each output at -f.
            min_dims = [i for i in range(dims) if (i % 2) == 0]
            rng_dims = [i for i in range(dims) if (i % 2) == 1]
            x[:, :, min_dims] = torch.relu(x[:, :, min_dims] + self.scale_floor[0]) - self.scale_floor[0]
            x[:, :, rng_dims] = torch.relu(x[:, :, rng_dims] + self.scale_floor[1]) - self.scale_floor[1]
        else:
            raise NotImplementedError()
        return x
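
# A minimal, self-contained sketch of the masked-average reduction shared by
# decode_continuous() and decode_scale() above: per-position losses are zeroed
# out at padded positions and averaged over the valid positions only, with a
# guard for rows that are entirely padding. The tensors and sizes below are
# illustrative only and are not part of the model API.
def _masked_mean_loss_sketch():
    pred = torch.rand(2, 5, 3)        # (batch, seq, dim)
    target = torch.rand(2, 5, 3)
    mask = torch.tensor([[1., 1., 1., 0., 0.],
                         [1., 1., 0., 0., 0.]])  # 1 = real position, 0 = pad

    loss = F.smooth_l1_loss(pred, target, reduction='none').mean(-1)  # (batch, seq)
    counts = mask.sum(-1)
    # 1/counts is inf where counts == 0; torch.where discards those entries.
    inv_counts = torch.where(counts == 0, torch.zeros_like(counts), 1 / counts)
    per_sample = (loss * mask).sum(-1) * inv_counts  # mean over valid positions
    return per_sample.mean()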