# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import torch
from .scale import ScaleModel


class ContinuousModel(ScaleModel):
    """ScaleModel extension that embeds and unscales continuous cell values."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def embed_continuous(self, cont_logits, tab_shape, tab_samples):
        """Convert decoder output into input embeddings that can be used by GANs.

        Each example's continuous logits are cropped to a common
        ``(max_rows, max_cols)`` window and multiplied by the sampled row and
        column gates.

        Args:
            cont_logits: sequence of per-example tensors indexed
                ``[row, col, feature]``.
            tab_shape: dict with ``'row'`` and ``'col'`` per-example counts.
            tab_samples: dict with ``'row'`` and ``'col'`` per-example 1-D
                tensors used as multiplicative gates (presumably row/column
                existence predictions -- TODO confirm against the caller).

        Returns:
            ``(cont_embds, cont_masks)``, lists with one entry per example.
            Each mask is a ``(max_rows, max_cols)`` int64 tensor on
            ``self.device`` with ones over the first ``rows x cols`` cells.
            NOTE(review): an example whose logits are smaller than the crop
            window yields an embedding smaller than its mask.
        """
        if len(cont_logits) == 0:
            return [], []

        row_sample = tab_samples['row']
        col_sample = tab_samples['col']
        row_counts = tab_shape['row']
        col_counts = tab_shape['col']

        # Common crop window: no larger than the biggest requested table and
        # no larger than the biggest decoder output.
        max_rows = min(int(max(row_counts)), max(lg.size(0) for lg in cont_logits))
        max_cols = min(int(max(col_counts)), max(lg.size(1) for lg in cont_logits))

        cont_embds = []
        cont_masks = []
        for bidx, logits in enumerate(cont_logits):
            # Row gates become (R, 1, 1) and column gates (1, C, 1) so they
            # broadcast over the (rows, cols, feature) logits.
            row_preds = row_sample[bidx].unsqueeze(-1).unsqueeze(-1)
            col_preds = col_sample[bidx].unsqueeze(0).unsqueeze(-1)

            logits = logits[:max_rows, :max_cols, :]
            embed = logits * row_preds[:max_rows, :, :]
            embed = embed * col_preds[:, :max_cols, :]

            # Clamp BOTH counts to the crop window. Leaving the column count
            # unclamped produced a ragged (or over-wide) mask whenever it
            # exceeded max_cols.
            rows = max(0, min(int(row_counts[bidx]), max_rows))
            cols = max(0, min(int(col_counts[bidx]), max_cols))
            mask = torch.zeros(max_rows, max_cols, dtype=torch.long, device=self.device)
            mask[:rows, :cols] = 1

            cont_embds.append(embed)
            cont_masks.append(mask)
        return cont_embds, cont_masks

    def unscale_continuous(self, cont_preds, scale_preds, tab_shape=None):
        """Map normalised continuous predictions back to their original scale.

        Args:
            cont_preds: sequence of per-example tensors indexed
                ``[row, col, dim]``.
            scale_preds: sequence of per-example scale tensors. A 1-D tensor
                is treated as a single row that broadcasts over all table
                rows; columns hold ``(min, range)`` values.
            tab_shape: optional dict with ``'row'`` and ``'col'`` per-example
                counts used to crop each prediction; when ``None`` the full
                prediction is kept.

        Returns:
            List with one tensor per example, cropped to ``(rows, cols, dims)``
            and rescaled as ``x * range + min``.

        Two layouts of the scale columns are supported:

        * exactly ``2 * dims`` columns: interleaved ``(min, range)`` pairs,
          one per continuous dimension, applied independently;
        * anything else: each column is treated as an increment over the
          previous one and accumulated from the first point, then the single
          ``(min, range)`` pair in columns 0 and 1 is applied to every
          dimension.

        The input tensors are not modified.
        """
        if tab_shape is not None:
            row_counts = tab_shape['row']
            col_counts = tab_shape['col']

        cont_values = []
        for bidx, all_cont_pred in enumerate(cont_preds):
            # Can only generate as far as we have predictions available.
            if tab_shape is not None:
                rows = min(row_counts[bidx], all_cont_pred.size(0))
                cols = min(col_counts[bidx], all_cont_pred.size(1))
            else:
                rows = all_cont_pred.size(0)
                cols = all_cont_pred.size(1)

            # Cut the continuous predictions to match the table shape. This is
            # a VIEW of the caller's tensor, so every update below is done
            # out-of-place (the previous in-place writes corrupted the input).
            cont_pred = all_cont_pred[:rows, :cols, :]

            scale_pred = scale_preds[bidx]
            if scale_pred.dim() == 1:
                scale_pred = scale_pred.view(1, -1)

            cont_dims = cont_pred.size(-1)
            if cont_dims * 2 == scale_pred.size(-1):
                # Columns 0, 2, 4, ... are mins and 1, 3, 5, ... are ranges;
                # unsqueeze to (rows or 1, 1, dims) to broadcast over columns.
                min_scale = scale_pred[:rows, 0::2].unsqueeze(1)
                rng_scale = scale_pred[:rows, 1::2].unsqueeze(1)
                cont_pred = cont_pred * rng_scale + min_scale
            else:
                # Reverse offsetting: accumulate increments from the first
                # point along the column axis.
                cont_pred = torch.cumsum(cont_pred, dim=1)
                # One (min, range) pair per row, applied to every dimension.
                min_scale = scale_pred[:rows, 0].view(-1, 1, 1)
                rng_scale = scale_pred[:rows, 1].view(-1, 1, 1)
                cont_pred = cont_pred * rng_scale + min_scale

            cont_values.append(cont_pred)
        return cont_values