mvq / utils / plot_helpers.py
plot_helpers.py
Raw
# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

def save_fid_line(data, save_path, **kwargs):
    """Plot train and eval FID against epoch and write the figure to disk.

    Args:
        data: Iterable of mappings, each providing the keys ``'epoch'``,
            ``'train_fid'`` and ``'eval_fid'``.
        save_path: Destination handed to ``plt.savefig``.
        **kwargs: Accepted for call-site compatibility; not used here.
    """
    # Start from a clean current figure and a known seaborn style.
    plt.clf()
    sns.reset_defaults()
    sns.set_style("ticks")

    # Long-form table: two rows per record, the eval row ahead of the train row.
    rows = []
    for record in data:
        train_row = {'epoch': record['epoch'], 'fid': record['train_fid'], 'split': 'train'}
        eval_row = {'epoch': record['epoch'], 'fid': record['eval_fid'], 'split': 'eval'}
        rows.extend((eval_row, train_row))

    frame = pd.DataFrame(rows)
    axis = sns.lineplot(x='epoch', y="fid", hue='split', data=frame)

    # Anchor the legend's upper-left corner at the axes' upper-right corner,
    # i.e. outside the plotting area.
    sns.move_legend(axis, "upper left", bbox_to_anchor=(1, 1))

    plt.savefig(save_path, bbox_inches='tight')
    plt.close('all')

def tile_image(batch_image, h, w=None):
    """Lay a batch of images out as one ``h`` x ``w`` grid image.

    Images are placed in row-major order: batch entry ``i * w + j`` ends up
    in grid row ``i``, grid column ``j``.

    Args:
        batch_image: Tensor indexed as (batch, channels, height, width).
        h: Number of grid rows.
        w: Number of grid columns; defaults to ``h`` (square grid).

    Returns:
        Tensor of shape (channels, h * height, w * width).

    Raises:
        AssertionError: If ``h * w`` differs from the batch size.
    """
    cols = h if w is None else w
    assert cols * h == batch_image.size(0)

    dims = batch_image.size()
    channels, height, width = dims[1], dims[2], dims[3]

    # Split the batch axis into (grid row, grid column).
    grid = batch_image.view(h, cols, channels, height, width)
    # (h, w, C, H, W) -> (C, h, H, w, W): each grid axis sits next to the
    # pixel axis it is merged with below.
    grid = grid.permute(2, 0, 3, 1, 4)
    # The permuted tensor must be made contiguous before it can be viewed
    # with the merged axes.
    return grid.contiguous().view(channels, h * height, cols * width)

def save_scatter_plot(save_path, all_df, x_lim=None, y_lim=None):
    """Draw a labelled scatter plot from ``all_df`` and save it.

    Args:
        save_path: Destination handed to ``plt.savefig``.
        all_df: DataFrame with the columns ``x``, ``y``, ``label`` (hue),
            ``size`` (marker size) and ``style``. The ``style`` values are
            mapped to markers: ``"data"``, ``"center"`` and ``"sample"``.
        x_lim: Optional x-axis limits passed to ``set_xlim``.
        y_lim: Optional y-axis limits passed to ``set_ylim``.
    """
    plt.clf()
    marker_by_style = {"data": "o", "center": "s", "sample": "X"}

    # Rows are ordered by ascending size, then label, before plotting.
    # NOTE(review): presumably so that larger markers are drawn last - confirm.
    ordered = all_df.sort_values(by=['size', 'label'])

    _, axis = plt.subplots(1)
    sns.scatterplot(
        data=ordered,
        x='x',
        y='y',
        hue='label',
        style="style",
        s=ordered['size'].values,
        markers=marker_by_style,
        palette="tab10",
        ax=axis,
    )

    # Axis limits are only touched when the caller supplies them.
    for limit, apply_limit in ((x_lim, axis.set_xlim), (y_lim, axis.set_ylim)):
        if limit is not None:
            apply_limit(limit)

    # Anchor the legend's upper-left corner at the axes' upper-right corner.
    sns.move_legend(axis, "upper left", bbox_to_anchor=(1, 1))

    plt.savefig(save_path, bbox_inches='tight')
    plt.close('all')

    
def save_scatter_subplot(save_path, titles, all_df1, all_df2, all_df3=None, x_lim=None, y_lim=None, base_width=5, palettes=None):
    """Save two or three scatter plots side by side in a single figure.

    Each DataFrame is drawn in its own panel, left to right. Legends are
    disabled on every panel, and the y tick labels and y axis label are
    removed from every panel except the first.

    Args:
        save_path: Destination handed to ``savefig``.
        titles: One title per panel; its length must equal the panel count.
        all_df1: DataFrame for the first panel. Needs the columns ``x``,
            ``y``, ``label`` (hue), ``size`` (marker size) and ``style``
            (``"data"``, ``"center"`` or ``"sample"``).
        all_df2: DataFrame for the second panel, same columns.
        all_df3: Optional DataFrame for a third panel, same columns.
        x_lim: Optional x-axis limits applied to every panel.
        y_lim: Optional y-axis limits applied to every panel.
        base_width: Width of one panel and height of the figure, in the
            units of matplotlib's ``figure.figsize``.
        palettes: Optional sequence with at least one palette per panel.
            Defaults to ``['flare', 'viridis', 'flare']``.

    Raises:
        AssertionError: If ``titles`` or ``palettes`` do not match the
            number of panels.
    """
    plt.clf()

    frames = [all_df1, all_df2]
    if all_df3 is not None:
        frames.append(all_df3)
    plot_count = len(frames)

    assert len(titles) == plot_count
    if palettes is None:
        palettes = ['flare', 'viridis', 'flare']
    else:
        assert len(palettes) >= plot_count

    # NOTE: sns.set changes seaborn/matplotlib defaults globally, so this
    # figure size stays in effect after the function returns.
    sns.set(rc={"figure.figsize": (plot_count * base_width, base_width)})

    markers = {"data": "o", "center": "s", "sample": "X"}

    # Order every frame by ascending size, then label, before plotting.
    frames = [frame.sort_values(by=['size', 'label']) for frame in frames]

    fig, axes = plt.subplots(1, plot_count)

    for panel_idx, (frame, ax) in enumerate(zip(frames, axes)):
        sns.scatterplot(x='x', y='y', hue='label', s=frame['size'].values, style="style",
            data=frame, ax=ax, markers=markers, palette=palettes[panel_idx], legend=False)

        if panel_idx > 0:
            # Only the left-most panel keeps its y tick labels and y label.
            ax.set(yticklabels=[])
            ax.set(ylabel=None)
        ax.set(title=titles[panel_idx])

    if x_lim is not None:
        for ax in axes:
            ax.set_xlim(x_lim)
    if y_lim is not None:
        for ax in axes:
            ax.set_ylim(y_lim)

    # Save the figure created above explicitly rather than whichever figure
    # happens to be current.
    fig.savefig(save_path, bbox_inches='tight')
    # plt.clf() above creates a figure when none exists; closing only `fig`
    # would leave that one open, so close everything like the other helpers.
    plt.close('all')


def save_mega_subplot(save_path, plots, col_titles, row_titles=None):
    """Save a grid of images, one subplot per image, to ``save_path``.

    The grid has ``len(plots)`` rows and ``len(plots[0])`` columns. A row
    shorter than the first leaves its trailing cells empty; entries beyond
    the first row's length are not drawn.

    Args:
        save_path: Destination handed to ``plt.savefig``.
        plots: Sequence of rows. Each row is a sized, indexable collection
            (for example a list or a tensor) whose items provide
            ``.squeeze()`` and can be displayed by ``imshow``.
        col_titles: Titles shown above the cells of the first row, indexed
            by column.
        row_titles: Optional labels, indexed by row, shown as horizontal
            y-axis labels on the first column.
    """
    plt.clf()
    row_count = len(plots)
    col_count = len(plots[0])

    for row_idx, plot_row in enumerate(plots):
        # len() covers lists, tuples, arrays and tensors alike, matching how
        # the column count is taken from the first row.
        cells_in_row = min(col_count, len(plot_row))
        for col_idx in range(cells_in_row):
            # Subplot indices are 1-based and row-major. Deriving the index
            # from (row, column) keeps later rows aligned with the grid even
            # when an earlier row is short; a running counter would shift
            # them into the wrong cells.
            ax = plt.subplot(row_count, col_count, row_idx * col_count + col_idx + 1)
            ax.imshow(plot_row[col_idx].squeeze())

            if row_idx == 0:
                ax.set_title(col_titles[col_idx])

            if row_titles is not None and col_idx == 0:
                ax.set_ylabel(row_titles[row_idx], rotation=0)

            # Remove borders
            for pos in ['right', 'top', 'bottom', 'left']:
                ax.spines[pos].set_visible(False)

            # Hide tick marks and tick labels on both axes.
            ax.tick_params(left=False,
                bottom=False,
                labelleft=False,
                labelbottom=False)

    plt.savefig(save_path, bbox_inches='tight')
    plt.close('all')