honeyplotnet / utils / recon_plots.py
recon_plots.py
Raw
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
import random

from .constant import CHART_TO_HEAD_MAP, UNIQ_CHART_HEADS

from matplotlib.colors import ListedColormap

SINGLE_PLOT_TIGHT_LAYOUT_REC = [0, 0.2, 1.0, 1.0]
SINGLE_PLOT_TEXT_LOC = [0.5, 0.05]
MAX_SENT_CAPTION = 2
MAX_TEXT_CAPTION = 400

COLOR_PALETTES = [
    ['#a6cee3','#1f78b4','#b2df8a','#33a02c','#fb9a99','#e31a1c','#fdbf6f','#ff7f00','#cab2d6'],
    ['#e41a1c','#377eb8','#4daf4a','#984ea3','#ff7f00','#ffff33','#a65628','#f781bf','#999999'],
    ['#8dd3c7','#bebada','#fb8072','#80b1d3','#fdb462','#b3de69','#fccde5','#d9d9d9'],
    ['#d0d1e6','#a6bddb','#74a9cf','#3690c0','#0570b0','#045a8d','#023858'],
    ['#ccece6','#99d8c9','#66c2a4','#41ae76','#238b45','#006d2c','#00441b'],
    ['#d4b9da','#c994c7','#df65b0','#e7298a','#ce1256','#980043','#67001f']
]

def unscale_conts(cont_preds, scale_preds, tab_shape=None, scale_eps=1.00001, scale_exponent=10):
    if tab_shape is not None:
      row_counts = tab_shape['row']
      col_counts = tab_shape['col']
    
    cont_values = []
    for bidx, (all_cont_pred) in enumerate(cont_preds):

      #Can only generate as far as we have predictions available.
      if tab_shape is not None:
        rows = min(row_counts[bidx], all_cont_pred.size(0))
        cols = min(col_counts[bidx], all_cont_pred.size(1))
      else:
        rows = all_cont_pred.size(0)
        cols = all_cont_pred.size(1)

      cont_pred = all_cont_pred[:rows, :cols, :].clone()
      scale_pred = scale_preds[bidx].clone()
      if len(scale_pred.shape) == 1:
        scale_pred = scale_pred.view(1, -1)

      cont_dims = cont_pred.size(-1)
      
      #Reverse offseting. Increment from first point
      if cont_dims * 2 == scale_pred.size(-1):

        for dim in range(cont_dims):
          mindim = dim * 2
          rngdim = mindim + 1
          
          min_scale = scale_pred[:rows, mindim]
          rng_scale = scale_pred[:rows, rngdim]

          #scale_min_rng = (bsz, 2)
          min_scale = min_scale[:,None]
          rng_scale = rng_scale[:,None]

          #Perform scaling
          cont_pred[:,:,dim] = cont_pred[:,:,dim] * rng_scale + min_scale
      else:

        # for cidx in range(1, cont_pred.size(1)):
        #   cont_pred[:, cidx, :] += cont_pred[:, cidx - 1, :]
        for cidx in range(1, cont_pred.size(2)):
          cont_pred[:, :, cidx] += cont_pred[:, :, cidx - 1]

        min_scale = scale_pred[:rows, 0]
        rng_scale = scale_pred[:rows, 1]

        #scale_min_rng = (bsz, 2)
        min_scale = min_scale.view(-1,1,1)
        rng_scale = rng_scale.view(-1,1,1)

        cont_pred = cont_pred * rng_scale + min_scale

      cont_values.append(cont_pred)
    return cont_values
    

def unscale_scales(scale_logits, tab_shape=None,  scale_eps_min=1.100001, scale_eps_rng=1.100001, scale_exponent=10):
    if tab_shape is not None:
        row_counts = tab_shape['row']
    
    scale_values = []
    for bidx, logits in enumerate(scale_logits):

      if tab_shape is not None:
        rows = int(row_counts[bidx])
        logits = logits[:rows, :]

      base = torch.ones_like(logits) * scale_exponent
      scale_pred = torch.pow(base, logits)

      dims = scale_pred.size(-1)
      min_dims = [i for i in range(dims) if (i % 2) == 0]
      rng_dims  = [i for i in range(dims) if (i % 2) == 1]

      scale_pred[:,min_dims] -= scale_eps_min
      scale_pred[:,rng_dims] -= scale_eps_rng

      scale_pred = torch.abs(scale_pred)
      
      scale_values.append(scale_pred)

    return scale_values

def prepare_mpl(x):
    chart_types      = x.get('chart_type')
    chart_idx        = x.get('chart_idx')
    if chart_idx is None and chart_types is None:
        raise ValueError("Provide one")

    scale_tens_list  = x['scale']['inputs_embeds']
    cont_tens_list   = x['continuous']['inputs_embeds']

    #if tab_shape not available, it must be from dataset
    tab_shape = x.get('shape')
    if tab_shape is not None:
        tab_shape = tab_shape['counts']
    else:
        attn_mask = x['continuous']['attention_mask'].sum(-1)
        num_cols = torch.max(attn_mask, dim=-1).values.cpu().tolist()
        num_rows = torch.where(attn_mask > 0, 1, 0).sum(-1).cpu().tolist()
        
        tab_shape = {}
        tab_shape['row'] = num_rows
        tab_shape['col'] = num_cols

    scale_values = unscale_scales(scale_tens_list, tab_shape=tab_shape)
    cont_values = unscale_conts(cont_tens_list, scale_values, tab_shape=tab_shape)

    #Convert into a list
    chart_list = []
    for idx, values in enumerate(cont_values):
        chart_dict = {}

        if chart_types is not None:
            ct = chart_types[idx]
            chart_dict['chart_type'] = ct
            chart_dict['head_type']  = CHART_TO_HEAD_MAP[ct]
        elif chart_idx is not None:
            ct_idx = chart_idx[idx]
            chart_dict['head_type'] = UNIQ_CHART_HEADS[ct_idx]
        else:
            raise

        if chart_dict['head_type'] == 'categorical':
            assert values.size(-1) == 1, f"invalid shape for categorical data, {values.shape}"
            values = torch.reshape(values, [values.size(0), values.size(1)])
            chart_dict['values'] = values.cpu().tolist()
        elif chart_dict['head_type'] == 'point':
            assert values.size(-1) == 2, f"invalid shape for point data, {values.shape}"
            x_val = values[:,:,0].cpu().tolist()
            y_val = values[:,:,1].cpu().tolist()
            chart_dict['values'] = {'x': x_val, 'y': y_val}
        elif chart_dict['head_type'] == 'boxplot':
            assert values.size(-1) == 5, f"invalid shape for boxplot data, {values.shape}"
            chart_dict['values'] = values.cpu().tolist()
            
        chart_list.append(chart_dict)
    return chart_list

def create_bar_chart(x_data, text, f_name, xhat_data=None, bar_width=0.75, step=None):

    if xhat_data is not None:
        _create_bar_chart_double(x_data=x_data, xhat_data=xhat_data, text=text, f_name=f_name, bar_width=bar_width, step=step)
    else:
        _create_bar_chart_single(x_data=x_data, text=text, f_name=f_name, base_width=bar_width, step=step)


def _create_bar_chart_single(x_data, text, f_name, base_width=0.25, flip_thres=4, step=None):
    recon_values = x_data['values']
    
    _ = plt.figure()
    plt.figure(facecolor='white')

    #Remove repeats between categorical and series names
    text['categorical'] = [s for s in text['categorical'] if 'unnamed' not in s]
    text['series_name'] = [s for s in text['series_name'] if s not in text['categorical'] and 'unnamed'  not in s]
    text['axis_titles'] = [s for s in text['axis_titles'] if s not in text['categorical'] and 'unnamed'  not in s]

    #clip by number of available categorical values
    max_cat = min(len(text['categorical']), len(recon_values[0]))
    max_series = min(len(text['series_name']), len(recon_values))

    if step is None:
        cidx = random.choice(range(len(COLOR_PALETTES)))
        cmap = COLOR_PALETTES[cidx]
    else:
        cmap = COLOR_PALETTES[step % len(COLOR_PALETTES)]

    #Manually remove
    if max_cat <= 1 or max_series <= 1:
        return

    if max_series == 1:
        bar_width = base_width
        ticks = [r for r in range(max_cat)]
    else:
        bar_width = base_width / max_series
        #ticks = [r + bar_width for r in range(max_cat)]
        ticks = [r + base_width / 3 for r in range(max_cat)]
    
    label_count = 0
    max_label_len = 0
    for idx, (y, label)in enumerate(zip(recon_values, text['series_name'])):
        if idx > max_series:
            break
        
        color = cmap[idx % len(cmap)]

        y = y[:max_cat]
        if idx == 0:
            br1 = np.arange(len(y))
        else:
            br1 = [x + bar_width for x in br1]

        label = text['series_name'][idx]
        if 'unnamed' in label:
            label = None
        else:
            label_count += 1
            if len(label) > max_label_len:
                max_label_len = len(label)
        
        if max_cat > flip_thres:
            plt.barh(br1, y, height=bar_width, label=label, color=color)
        else:
            plt.bar(br1, y, width=bar_width, label=label, color=color)

        #plt.bar(br1, y, width=bar_width, label=label)
    
    rotation = 0
    fontsize = 'small'
    wrap = True
    if max_cat > flip_thres:
        rotation = 30
        wrap = False
        fontsize=8

    categorical_labels = text['categorical'][:max_cat]
    max_cat_label_len = max(len(c) for c in categorical_labels)
    if max_cat > flip_thres:
        plt.yticks(ticks, categorical_labels, wrap=wrap, fontsize=fontsize)
    else:
        plt.xticks(ticks, categorical_labels, wrap=wrap, fontsize=fontsize, rotation=rotation)

    if max_series > 1 and label_count > 0:
        if max_label_len <= 5:
            #Short legend box on the right
            plt.legend(bbox_to_anchor=(1,1), loc="upper left")
        elif max_label_len >= 40:
            #Long legend box should be put on the top
            plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=1)
        elif max_cat_label_len >= 40:
            #Long categorical labels mean it must be put on the top
            if max_label_len < 10 and label_count < 4:
                plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=label_count)
            else:
                plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=1)
        elif label_count >= 4:
            #Lots of labels, put on the right
            plt.legend(bbox_to_anchor=(1,1), loc="upper left")

        else:
            if max_label_len > 20 or max_cat_label_len > 40:
                ncol = 1
            else:
                ncol = label_count
            
            plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=ncol)

    if len(text['axis_titles']) > 0:
        axis_title1 = text['axis_titles'][0].split('.')[0]
        if max_cat > flip_thres:
            plt.xlabel(axis_title1)
        else:
            plt.ylabel(axis_title1)

    #Only keep first 3 sentences
    caption = text['caption'].replace('\n','')
    caption = '.'.join(caption.split('.')[:MAX_SENT_CAPTION])[:MAX_TEXT_CAPTION]
    s = SINGLE_PLOT_TEXT_LOC
    plt.figtext(s[0], s[1], caption, wrap=True, horizontalalignment='center')

    plt.tight_layout(rect=SINGLE_PLOT_TIGHT_LAYOUT_REC)
    plt.savefig(fname=f_name)
    plt.close()

def _create_bar_chart_double(x_data, xhat_data, text, f_name, bar_width=0.25, step=None):

    # Create original plot
    original_values = x_data['values']
    plt.subplot(2, 1, 1)  # row 1, column 2, count 1

    #Pick horizontal or horizontal based on xticks

    #plt.title('Original')
    if step is None:
        cidx = random.choice(range(len(COLOR_PALETTES)))
        cmap = COLOR_PALETTES[cidx]
    else:
        cmap = COLOR_PALETTES[step % len(COLOR_PALETTES)]

    #plt.xlabel('x-axis')
    if len(text['axis_titles']) > 0:
        axis_title1 = text['axis_titles'][0].split('.')[0]
        if len(text['categorical']) > 5:
            plt.xlabel(axis_title1)
        else:
            plt.ylabel(axis_title1)
    
    if len(original_values) == 1:
        bar_width = 0.75
        ticks = [r for r in range(len(original_values[0]))]
    else:
        bar_width = 0.75 / len(original_values)
        ticks = [r + bar_width for r in range(len(original_values[0]))]
        
    for idx, y in enumerate(original_values):
        color = cmap[idx % len(cmap)]

        if idx == 0:
            br1 = np.arange(len(y))
        else:
            br1 = [x + bar_width for x in br1]
        
        label = text['series_name'][idx]
        if 'unnamed' in label:
            label = None
        
        if len(text['categorical']) > 5:
            plt.barh(br1, y, height=bar_width, label=label, color=color)
        else:
            plt.bar(br1, y, width=bar_width, label=label, color=color)

    rotation = 0
    wrap = True
    fontsize = 'small'
    if len(text['categorical']) > 5:
        rotation = 30
        wrap = True
        fontsize = 8    
    
    cat_labels = [c for c in text['categorical'] if 'unnamed' not in c]
    if len(cat_labels):
        if len(text['categorical']) > 5:
            plt.yticks(ticks[:len(cat_labels)], cat_labels, wrap=wrap, fontsize=fontsize)
        else:
            plt.xticks(ticks[:len(cat_labels)], cat_labels, wrap=wrap, fontsize=fontsize, rotation=rotation)


    if len(original_values) > 1:
        plt.legend(bbox_to_anchor=(1,1), loc="upper left")
        
    # Create reconstruction
    plt.subplot(2, 1, 2)
    #plt.title('Reconstruction')

    if len(text['axis_titles']) > 0:
        axis_title1 = text['axis_titles'][0].split('.')[0]
        if len(text['categorical']) > 5:
            plt.xlabel(axis_title1)
        else:
            plt.ylabel(axis_title1)
    
    recon_values = xhat_data['values']

    #clip by number of available categorical values
    clip_val = min(len(text['categorical']), len(recon_values[0]))

    for idx, (y, label)in enumerate(zip(recon_values, text['series_name'])):
        color = cmap[idx % len(cmap)]

        #label = [l for l in label if 'unnamed' not in l]
        if 'unnamed' in label:
            label = None

        y = y[:clip_val]
        if idx == 0:
            br1 = np.arange(len(y))
        else:
            br1 = [x + bar_width for x in br1][:clip_val]

        if len(text['categorical']) > 5:
            plt.barh(br1, y, height=bar_width, label=label, color=color)
        else:
            plt.bar(br1, y, width=bar_width, label=label, color=color)

    rotation = 0
    wrap = True
    fontsize = 'small'
    if clip_val > 5:
        rotation = 30
        wrap = True
        fontsize = 8

    if len(text['categorical']) > 5:
        plt.yticks(ticks[:clip_val], text['categorical'][:clip_val], wrap=wrap, fontsize=fontsize)
    else:
        plt.xticks(ticks[:clip_val], text['categorical'][:clip_val], wrap=wrap, fontsize=fontsize, rotation=rotation)

    if len(recon_values) > 1:
        plt.legend(bbox_to_anchor=(1,1), loc="upper left")

    #Add caption
    #plt.subplots_adjust(bottom=0.2) 
    plt.tight_layout(rect=[0, 0.0, 1.0, 0.8])

    #Only keep first 3 sentences
    caption = text['caption'].replace('\n','')
    caption = '.'.join(caption.split('.')[:MAX_SENT_CAPTION])[:MAX_TEXT_CAPTION]
    plt.figtext(0.5, 0.9, caption, wrap=True, horizontalalignment='center')

    # space between the plots
    plt.savefig(fname=f_name)
    plt.close()

def create_scatter(x_data, text, f_name, xhat_data=None, bar_width=0.25, step=None):

    if xhat_data is not None:
        _create_scatter_double(x_data=x_data, xhat_data=xhat_data, text=text, f_name=f_name, step=step)
    else:
        _create_scatter_single(x_data=x_data, text=text, f_name=f_name, step=step)

def _create_scatter_single(x_data, text, f_name, markersize=4, step=None):
    x_values = x_data['values']['x']
    y_values = x_data['values']['y']
    fig = plt.figure()
    plt.figure(facecolor='white')

    if step is None:
        cidx = random.choice(range(len(COLOR_PALETTES)))
        cmap = ListedColormap(COLOR_PALETTES[cidx])
    else:
        cmap = ListedColormap(COLOR_PALETTES[step % len(COLOR_PALETTES)])

    #Remove repeats between categorical and series names
    text['series_name'] = [s for s in text['series_name'] if 'unnamed'  not in s]

    label_count = 0
    max_label_len = 0
    if len(x_values[0]) == 1:
        return

    for idx, (x, y, label) in enumerate(zip(x_values, y_values, text['series_name'])):
        if 'unnamed' in label:
            label = None
        else:
            label_count += 1
            if len(label) > max_label_len:
                max_label_len = len(label)

        plt.scatter(x, y, cmap=cmap, label=label, s=markersize)

    if len(x_values) > 1 and label_count > 0:
        if max_label_len <= 5:
            #Short legend box on the right
            plt.legend(bbox_to_anchor=(1,1), loc="upper left")
        elif max_label_len >= 40:
            #Long legend box should be put on the top
            plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=1)

        elif label_count >= 4:
            #Lots of labels, put on the right
            plt.legend(bbox_to_anchor=(1,1), loc="upper left")

        else:
            if max_label_len > 20:
                ncol = 1
            else:
                ncol = label_count
            
            plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=ncol)

    if len(text['axis_titles']) > 0:
        axis_title1 = text['axis_titles'][0].split('.')[0]
        plt.xlabel(axis_title1)
    if len(text['axis_titles']) > 1:
        axis_title2 = text['axis_titles'][1].split('.')[0]
        plt.ylabel(axis_title2)

    #Only keep first 3 sentences
    caption = text['caption'].replace('\n','')
    caption = '.'.join(caption.split('.')[:MAX_SENT_CAPTION])[:MAX_TEXT_CAPTION]

    s = SINGLE_PLOT_TEXT_LOC
    plt.figtext(s[0], s[1], caption, wrap=True, horizontalalignment='center')

    plt.tight_layout(rect=SINGLE_PLOT_TIGHT_LAYOUT_REC)
    plt.savefig(fname=f_name)
    plt.close()


def _create_scatter_double(x_data, xhat_data, text, f_name, step=None):

    if step is None:
        cidx = random.choice(range(len(COLOR_PALETTES)))
        cmap = ListedColormap(COLOR_PALETTES[cidx])
    else:
        cmap = ListedColormap(COLOR_PALETTES[step % len(COLOR_PALETTES)])

    x_values = x_data['values']['x']
    y_values = x_data['values']['y']
    plt.subplot(2, 1, 1)  

    if len(text['axis_titles']) > 0:
        axis_title1 = text['axis_titles'][0].split('.')[0]
        plt.xlabel(axis_title1)
    if len(text['axis_titles']) > 1:
        axis_title2 = text['axis_titles'][1].split('.')[0]
        plt.ylabel(axis_title2)

    for idx, (x,y, label) in enumerate(zip(x_values, y_values, text['series_name'])):
        if 'unnamed' in label:
            label = None
        plt.scatter(x, y, cmap=cmap, label=label)

    if len(x_values) > 1:
        plt.legend(bbox_to_anchor=(1,1), loc="upper left")

    x_values = xhat_data['values']['x']
    y_values = xhat_data['values']['y']

    plt.subplot(2, 1, 2)
    #plt.title('Reconstruction')
    if len(text['axis_titles']) > 0:
        axis_title1 = text['axis_titles'][0].split('.')[0]
        plt.xlabel(axis_title1)
    if len(text['axis_titles']) > 1:
        axis_title2 = text['axis_titles'][1].split('.')[0]
        plt.ylabel(axis_title2)
    
    for idx, (x,y, label) in enumerate(zip(x_values, y_values, text['series_name'])):
        if 'unnamed' in label:
            label = None
        plt.scatter(x, y, cmap=cmap, label=label)

    if len(x_values) > 1:
        plt.legend(bbox_to_anchor=(1,1), loc="upper left")
    
    #Add caption
    #plt.subplots_adjust(bottom=0.2) 
    plt.tight_layout(rect=[0, 0.0, 1.0, 0.8])

    #Only keep first 3 sentences
    caption = text['caption'].replace('\n','')
    caption = '.'.join(caption.split('.')[:MAX_SENT_CAPTION])[:MAX_TEXT_CAPTION]
    plt.figtext(0.5, 0.9, caption, wrap=True, horizontalalignment='center')

    # space between the plots
    plt.savefig(fname=f_name)
    plt.close()

def create_boxplot(x_data, text, f_name, xhat_data=None, step=None):

    if xhat_data is not None:
        create_boxplot_double(x_data=x_data, xhat_data=xhat_data, text=text, f_name=f_name)
    else:
        _create_boxplot_single(x_data=x_data, text=text, f_name=f_name)


def _create_boxplot_single(x_data, text, f_name):

    fig = plt.figure()
    plt.figure(facecolor='white')

    text['categorical'] = [s for s in text['categorical'] if 'unnamed'  not in s]

    recon_values = x_data['values']
    clip_val = min(len(text['categorical']), len(recon_values[0]))

    #Manually remove
    if clip_val <= 1:
        return

    plt.boxplot(recon_values[0][:clip_val], labels=text['categorical'][:clip_val])

    if len(text['axis_titles']) > 0:
        axis_title = text['axis_titles'][0].split('.')[0]
        plt.ylabel(axis_title)


    #Only keep first 3 sentences
    caption = text['caption'].replace('\n','')
    caption = '.'.join(caption.split('.')[:MAX_SENT_CAPTION])[:MAX_TEXT_CAPTION]

    s = SINGLE_PLOT_TEXT_LOC
    plt.figtext(s[0], s[1], caption, wrap=True, horizontalalignment='center')
    plt.tight_layout(rect=SINGLE_PLOT_TIGHT_LAYOUT_REC)

    # space between the plots
    plt.savefig(fname=f_name)
    plt.close()

def create_boxplot_double(x_data, xhat_data, text, f_name):
    
    # Create original plot
    original_values = x_data['values']
    plt.subplot(2, 1, 1)  # row 1, column 2, count 1
    #plt.title('Original')

    if len(text['axis_titles']) > 0:
        plt.ylabel(text['axis_titles'][0])
    plt.boxplot(original_values[0], labels=text['categorical'])

    # Create reconstruction
    plt.subplot(2, 1, 2)
    #plt.title('Reconstruction')
    if len(text['axis_titles']) > 0:
        axis_title = text['axis_titles'][0].split('.')[0]
        plt.ylabel(axis_title)
        plt.ylabel(axis_title)


    recon_values = xhat_data['values']
    clip_val = min(len(text['categorical']), len(recon_values[0]))

    plt.boxplot(recon_values[0][:clip_val], labels=text['categorical'][:clip_val])

    #Add caption
    #plt.subplots_adjust(bottom=0.2) 
    plt.tight_layout(rect=[0, 0.0, 1.0, 0.8])

    #Only keep first 3 sentences
    caption = text['caption'].replace('\n','')
    caption = '.'.join(caption.split('.')[:MAX_SENT_CAPTION])[:MAX_TEXT_CAPTION]
    plt.figtext(0.5, 0.9, caption, wrap=True, horizontalalignment='center')

    # space between the plots
    plt.savefig(fname=f_name)
    plt.close()

def create_recon_plots(x, x_hat, text_data, step, epoch_dir):

    if 'chart_data' in x:
        x = x['chart_data']
    
    if 'chart_data' in x_hat:
        x_hat  = x_hat['chart_data']

    x_chart_data    = prepare_mpl(x)
    xhat_chart_data = prepare_mpl(x_hat)

    for idx, (x_data, xhat_data, text) in enumerate(zip(x_chart_data, xhat_chart_data, text_data)):
        assert x_data['head_type'] == xhat_data['head_type']

        f_name = os.path.join(epoch_dir, f"{step}-{idx}-{xhat_data['head_type']}.png")
        if x_data['head_type'] == 'categorical':
            create_bar_chart(x_data=x_data, xhat_data=xhat_data, text=text, f_name=f_name, step=step)
        elif x_data['head_type'] == 'point':
            create_line(x_data=x_data, xhat_data=xhat_data, text=text, f_name=f_name, step=step)
        else:
            create_boxplot(x_data=x_data, xhat_data=xhat_data, text=text, f_name=f_name, step=step)


def create_line(x_data, text, f_name, xhat_data=None, bar_width=0.25, markersize=3, step=None):

    if xhat_data is not None:
        _create_line_double(x_data=x_data, xhat_data=xhat_data, text=text, f_name=f_name, markersize=markersize, step=step)
    else:
        _create_line_single(x_data=x_data, text=text, f_name=f_name, markersize=markersize, step=step)

def _create_line_single(x_data, text, f_name, markersize=4, linewidth=2, step=None):

    x_values = x_data['values']['x']
    y_values = x_data['values']['y']
    fig, ax = plt.subplots()

    plt.figure(facecolor='white')
    
    if step is None:
        cidx = random.choice(range(len(COLOR_PALETTES)))
        cmap = COLOR_PALETTES[cidx]
    else:
        cmap = COLOR_PALETTES[step % len(COLOR_PALETTES)]

    #Remove repeats between categorical and series names
    text['series_name'] = [s for s in text['series_name'] if 'unnamed'  not in s]

    label_count = 0
    max_label_len = 0
    if len(x_values[0]) == 1:
        return

    for idx, (x, y, label) in enumerate(zip(x_values, y_values, text['series_name'])):
        if 'unnamed' in label:
            label = None
        else:
            label_count += 1
            if len(label) > max_label_len:
                max_label_len = len(label)

        x, y = zip(*sorted(zip(x, y)))

        color = cmap[idx % len(cmap)]
        plt.plot(x, y, label=label, 
            markersize=markersize,  color=color, 
            marker='o', linewidth=linewidth)
    
    if len(x_values) > 1 and label_count > 0:
        if max_label_len <= 5:
            #Short legend box on the right
            plt.legend(bbox_to_anchor=(1,1), loc="upper left")
        elif max_label_len >= 40:
            #Long legend box should be put on the top
            plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=1)

        elif label_count >= 4:
            #Lots of labels, put on the right
            plt.legend(bbox_to_anchor=(1,1), loc="upper left")

        else:
            if max_label_len > 20:
                ncol = 1
            else:
                ncol = label_count
            
            plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=ncol)

    if len(text['axis_titles']) > 0:
        axis_title1 = text['axis_titles'][0].split('.')[0]
        plt.xlabel(axis_title1)
    if len(text['axis_titles']) > 1:
        axis_title2 = text['axis_titles'][1].split('.')[0]
        plt.ylabel(axis_title2)

    #Only keep first 3 sentences
    caption = text['caption'].replace('\n','')
    caption = '.'.join(caption.split('.')[:MAX_SENT_CAPTION])[:MAX_TEXT_CAPTION]

    s = SINGLE_PLOT_TEXT_LOC
    plt.figtext(s[0], s[1], caption, wrap=True, horizontalalignment='center')

    plt.tight_layout(rect=SINGLE_PLOT_TIGHT_LAYOUT_REC)
    plt.savefig(fname=f_name)
    plt.close()


def _create_line_double(x_data, xhat_data, text, f_name,  markersize=4, linewidth=2, step=None):

    x_values = x_data['values']['x']
    y_values = x_data['values']['y']
    fig, ax = plt.subplots()

    plt.subplot(2, 1, 1)  

    if step is None:
        cidx = random.choice(range(len(COLOR_PALETTES)))
        cmap = COLOR_PALETTES[cidx]
    else:
        cmap = COLOR_PALETTES[step % len(COLOR_PALETTES)]

    text['series_name'] = [s for s in text['series_name'] if 'unnamed'  not in s]

    ## start from here...
    label_count = 0
    max_label_len = 0
    if len(x_values[0]) == 1:
        return

    if len(text['axis_titles']) > 0:
        axis_title1 = text['axis_titles'][0].split('.')[0]
        plt.xlabel(axis_title1)
    if len(text['axis_titles']) > 1:
        axis_title2 = text['axis_titles'][1].split('.')[0]
        plt.ylabel(axis_title2)

    for idx, (x,y, label) in enumerate(zip(x_values, y_values, text['series_name'])):
        if 'unnamed' in label:
            label = None
        else:
            label_count += 1
            if len(label) > max_label_len:
                max_label_len = len(label)
        
        x, y = zip(*sorted(zip(x, y)))
        color = cmap[idx % len(cmap)]

        plt.plot(x, y, label=label,
            markersize=markersize,  color=color, 
            marker='o', linewidth=linewidth)

    if len(x_values) > 1 and label_count > 0:
        if max_label_len <= 5:
            #Short legend box on the right
            plt.legend(bbox_to_anchor=(1,1), loc="upper left")
        elif max_label_len >= 40:
            #Long legend box should be put on the top
            plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=1)

        elif label_count >= 4:
            #Lots of labels, put on the right
            plt.legend(bbox_to_anchor=(1,1), loc="upper left")

        else:
            if max_label_len > 20:
                ncol = 1
            else:
                ncol = label_count
            
            plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=ncol)

    if len(text['axis_titles']) > 0:
        axis_title1 = text['axis_titles'][0].split('.')[0]
        plt.xlabel(axis_title1)
    if len(text['axis_titles']) > 1:
        axis_title2 = text['axis_titles'][1].split('.')[0]
        plt.ylabel(axis_title2)

    # Reconstruction
    x_values = xhat_data['values']['x']
    y_values = xhat_data['values']['y']

    plt.subplot(2, 1, 2)
    #plt.title('Reconstruction')
    if len(text['axis_titles']) > 0:
        axis_title1 = text['axis_titles'][0].split('.')[0]
        plt.xlabel(axis_title1)
    if len(text['axis_titles']) > 1:
        axis_title2 = text['axis_titles'][1].split('.')[0]
        plt.ylabel(axis_title2)
    
    for idx, (x,y, label) in enumerate(zip(x_values, y_values, text['series_name'])):
        if 'unnamed' in label:
            label = None

        x, y = zip(*sorted(zip(x, y)))
        color = cmap[idx % len(cmap)]

        plt.plot(x, y, label=label,
            markersize=markersize,  color=color, 
            marker='o', linewidth=linewidth)

    label_count = 0
    max_label_len = 0
    if len(x_values) > 1 and label_count > 0:
        if max_label_len <= 5:
            #Short legend box on the right
            plt.legend(bbox_to_anchor=(1,1), loc="upper left")
        elif max_label_len >= 40:
            #Long legend box should be put on the top
            plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=1)

        elif label_count >= 4:
            #Lots of labels, put on the right
            plt.legend(bbox_to_anchor=(1,1), loc="upper left")

        else:
            if max_label_len > 20:
                ncol = 1
            else:
                ncol = label_count
            
            plt.legend(bbox_to_anchor=(0, 1, 1, 0), loc="lower left", mode="expand", ncol=ncol)

    #Add caption
    #plt.subplots_adjust(bottom=0.2) 
    plt.tight_layout(rect=[0, 0.0, 1.0, 0.8])

    #Only keep first 3 sentences
    caption = text['caption'].replace('\n','')
    caption = '.'.join(caption.split('.')[:MAX_SENT_CAPTION])[:MAX_TEXT_CAPTION]
    plt.figtext(0.5, 0.9, caption, wrap=True, horizontalalignment='center')

    # space between the plots
    plt.savefig(fname=f_name)
    plt.close()


def create_single_plot(x, text_data, save_dir, step):

    if 'chart_data' in x:
      x = x['chart_data']
    
    x_data_mpl = prepare_mpl(x)

    for idx, x_data in enumerate(x_data_mpl):
      f_name = os.path.join(save_dir, f"{step}-{idx}-{x_data['head_type']}.png")

      text = text_data[idx]

      if x_data['head_type'] == 'categorical':
          create_bar_chart(x_data, text, f_name, step=step)
      elif x_data['head_type'] == 'point':
          create_line(x_data, text, f_name, step=step)
      elif x_data['head_type'] == 'boxplot':
          create_boxplot(x_data, text, f_name, step=step)
      else:
        raise ValueError(f"Invalid chart type given: {x_data['head_type']}")
      plt.close()