honeyplotnet / models / continuous / helper.py
# ---------------------------------------------------------------
# Copyright (c) __________________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------


import torch

from ..constant import CHART_TO_HEAD_MAP, UNIQ_CHART_HEADS

def pad_vector(x, pad_len, pad_dim, device, dtype):
  """Zero-pad `x` along `pad_dim` so that it reaches length `pad_len`."""
  x_pad_len = pad_len - x.size(pad_dim)
  if x_pad_len > 0:
    x_shape = list(x.shape)
    x_shape[pad_dim] = x_pad_len
    padding = torch.zeros(x_shape, device=device, dtype=dtype)
    x = torch.cat([x, padding], dim=pad_dim)
  return x
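
# Illustrative sketch (the shapes are hypothetical): pad a (2, 3) tensor to
# length 5 along dim 1; the two appended columns are zeros.
#   x = torch.ones(2, 3)
#   pad_vector(x, pad_len=5, pad_dim=1, device=x.device, dtype=x.dtype).shape
#   # -> torch.Size([2, 5])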

def unflat_tensor(x, hypo_count, hypo_dim=1):
  """Regroup a flattened hypothesis batch by inserting `hypo_count` at `hypo_dim`."""
  x_shape = list(x.shape)
  x_shape.insert(hypo_dim, hypo_count)
  x_shape[0] = -1
  return x.reshape(x_shape)
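
# Illustrative sketch (hypo_count and shapes are hypothetical): a flattened
# (batch * hypo_count, feat) tensor is regrouped per hypothesis.
#   x = torch.randn(6, 4)
#   unflat_tensor(x, hypo_count=3).shape
#   # -> torch.Size([2, 3, 4])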

def get_topk_batch(topk_idx, x):
  """Gather, for each batch element, the entries of `x` selected by `topk_idx`."""
  x_batch = []
  for idx, _x in zip(topk_idx, x):
    x_batch.append(_x[idx])
  return torch.stack(x_batch, dim=0)
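
# Illustrative sketch (the shapes are hypothetical): keep the top-3 candidates
# per batch element out of 5.
#   x = torch.randn(2, 5, 8)
#   topk_idx = torch.tensor([[0, 2, 4], [1, 3, 4]])
#   get_topk_batch(topk_idx, x).shape
#   # -> torch.Size([2, 3, 8])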

def get_chart_type_dict(chart_types, hypo_bsz=1):
  """Group flattened hypothesis indices by the head each chart type maps to."""
  chart_type_dict = {}

  for idx, ct in enumerate(chart_types):
    hidx = idx * hypo_bsz
    head_map = CHART_TO_HEAD_MAP[ct]
    chart_type_dict.setdefault(head_map, []).extend(range(hidx, hidx + hypo_bsz))

  return chart_type_dict
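
# Illustrative sketch (the chart-type keys and head names are hypothetical and
# depend on CHART_TO_HEAD_MAP): with hypo_bsz=2, chart i owns flattened
# hypothesis indices [2*i, 2*i + 1], grouped under its head.
#   get_chart_type_dict(["line", "line", "bar"], hypo_bsz=2)
#   # -> e.g. {"head_a": [0, 1, 2, 3], "head_b": [4, 5]}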

def get_chart_type_from_idx(indices):
  """Group positions in `indices` by the head found at each index in UNIQ_CHART_HEADS."""
  chart_type_dict = {}
  for i, idx in enumerate(indices):
    head_map = UNIQ_CHART_HEADS[idx]
    chart_type_dict.setdefault(head_map, []).append(i)
  return chart_type_dict
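
# Illustrative sketch (the actual head names depend on UNIQ_CHART_HEADS):
# positions that share a head index are grouped together.
#   get_chart_type_from_idx([0, 1, 0])
#   # -> {UNIQ_CHART_HEADS[0]: [0, 2], UNIQ_CHART_HEADS[1]: [1]}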

def make_repeat_frame(x, hypo_bsz, only_frame=True, hypo_dim=1):
  """Build the repeat counts needed to tile `x` `hypo_bsz` times along a new hypothesis dim."""
  x_shape = list(x.shape)
  repeat_frame = [1] * (len(x_shape) + 1)
  repeat_frame[hypo_dim] = hypo_bsz
  if only_frame:
    return repeat_frame
  x_repeated = x.unsqueeze(hypo_dim).repeat(repeat_frame)
  return repeat_frame, x_repeated
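
# Illustrative sketch (the shapes are hypothetical): tile a (2, 4) tensor
# 3 times along a new hypothesis dimension.
#   frame, x_rep = make_repeat_frame(torch.randn(2, 4), hypo_bsz=3, only_frame=False)
#   frame, x_rep.shape
#   # -> [1, 3, 1], torch.Size([2, 3, 4])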

def get_win_hypo(hypos, wta_idx):
  """Select the winner-takes-all hypothesis for each batch element."""
  batch_list = torch.arange(hypos.size(0), device=hypos.device)
  winner_hypo = hypos[batch_list, wta_idx, :]
  return winner_hypo
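
# Illustrative sketch (the shapes are hypothetical): pick the winning hypothesis
# per batch element from a (batch, hypos, feat) tensor.
#   hypos = torch.randn(4, 3, 16)
#   wta_idx = torch.tensor([0, 2, 1, 1])
#   get_win_hypo(hypos, wta_idx).shape
#   # -> torch.Size([4, 16])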