# ---------------------------------------------------------------
# Copyright (c) ___________________ 2023.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# ---------------------------------------------------------------
"""Plotting helpers: FID curves, scatter plots, and tiled image grids."""
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd


def save_fid_line(data, save_path, **kwargs):
    """Plot train and eval FID against epoch as two lines and save the figure.

    Args:
        data: Iterable of records; each must provide ``'epoch'``,
            ``'train_fid'`` and ``'eval_fid'`` keys.
        save_path: Destination passed to ``plt.savefig``.
        **kwargs: Accepted for call-site compatibility; currently unused.
    """
    rows = []
    for record in data:
        # 'eval' is appended before 'train' so it comes first in the hue order.
        rows.append({'epoch': record['epoch'], 'fid': record['eval_fid'], 'split': 'eval'})
        rows.append({'epoch': record['epoch'], 'fid': record['train_fid'], 'split': 'train'})
    df = pd.DataFrame(rows)

    # Scope the style reset so it does not leak into the caller's rcParams.
    with plt.rc_context():
        plt.clf()
        sns.reset_defaults()
        sns.set_style("ticks")
        try:
            ax = sns.lineplot(data=df, x='epoch', y="fid", hue='split')
            # Put the legend outside the axes, to the right of the plot.
            sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
            plt.savefig(save_path, bbox_inches='tight')
        finally:
            plt.close('all')


def tile_image(batch_image, h, w=None):
    """Tile a batch of images into a single image laid out as an ``h`` x ``w`` grid.

    Args:
        batch_image: Tensor read as (N, C, H, W) via ``.size(0..3)``. It must
            support the torch-style ``size``/``view``/``permute``/``contiguous``
            API; presumably a ``torch.Tensor``.
        h: Number of grid rows.
        w: Number of grid columns; defaults to ``h`` (square grid).

    Returns:
        Tensor of shape (C, h * H, w * W). Image ``i`` occupies grid row
        ``i // w`` and column ``i % w``.
    """
    if w is None:
        w = h
    assert w * h == batch_image.size(0), (
        f"grid {h}x{w} does not match batch size {batch_image.size(0)}")
    channels, height, width = batch_image.size(1), batch_image.size(2), batch_image.size(3)
    # Split the batch dim row-major into (grid_row, grid_col).
    batch_image = batch_image.view(h, w, channels, height, width)
    # (h, w, C, H, W) -> (C, h, H, w, W): interleave grid and pixel dims so the
    # final view stacks images vertically per grid row and horizontally per column.
    batch_image = batch_image.permute(2, 0, 3, 1, 4)
    return batch_image.contiguous().view(channels, h * height, w * width)


def save_scatter_plot(save_path, all_df, x_lim=None, y_lim=None):
    """Scatter-plot ``all_df`` and save the figure.

    Args:
        save_path: Destination passed to ``plt.savefig``.
        all_df: DataFrame with columns ``'x'``, ``'y'``, ``'label'`` (colour),
            ``'style'`` (marker; values expected to be ``"data"``, ``"center"``
            or ``"sample"``) and ``'size'`` (per-point marker size).
        x_lim: Optional x-axis limits.
        y_lim: Optional y-axis limits.
    """
    markers = {"data": "o", "center": "s", "sample": "X"}
    # Sorting fixes the draw order; presumably so larger markers end up on top.
    all_df = all_df.sort_values(by=['size', 'label'])
    _fig, ax = plt.subplots(1)
    try:
        sns.scatterplot(x='x', y='y', hue='label', s=all_df['size'].values, style="style",
                        data=all_df, ax=ax, markers=markers, palette="tab10")
        if x_lim is not None:
            ax.set_xlim(x_lim)
        if y_lim is not None:
            ax.set_ylim(y_lim)
        # Put the legend outside the axes, to the right of the plot.
        sns.move_legend(ax, "upper left", bbox_to_anchor=(1, 1))
        plt.savefig(save_path, bbox_inches='tight')
    finally:
        plt.close('all')


def save_scatter_subplot(save_path, titles, all_df1, all_df2, all_df3=None, x_lim=None,
                         y_lim=None, base_width=5, palettes=None):
    """Draw two or three scatter plots side by side (no legends) and save the figure.

    Args:
        save_path: Destination passed to ``plt.savefig``.
        titles: One title per panel; its length must equal the panel count.
        all_df1, all_df2: DataFrames for the first two panels. They need the
            same columns as in ``save_scatter_plot``.
        all_df3: Optional DataFrame for a third panel.
        x_lim: Optional x-axis limits applied to every panel.
        y_lim: Optional y-axis limits applied to every panel.
        base_width: Width and height of each panel, in inches.
        palettes: Per-panel seaborn palettes; defaults to
            ``['flare', 'viridis', 'flare']``. It needs at least one entry per panel.
    """
    frames = [all_df1, all_df2] if all_df3 is None else [all_df1, all_df2, all_df3]
    plot_count = len(frames)
    assert len(titles) == plot_count
    if palettes is None:
        palettes = ['flare', 'viridis', 'flare']
    else:
        assert len(palettes) >= plot_count
    markers = {"data": "o", "center": "s", "sample": "X"}

    # sns.set() rewrites global rcParams (theme and figure size). Restore them
    # afterwards so later plots in the process are not affected.
    with plt.rc_context():
        sns.set(rc={"figure.figsize": (plot_count * base_width, base_width)})
        fig, axes = plt.subplots(1, plot_count)
        try:
            for idx, (ax, frame) in enumerate(zip(axes, frames)):
                frame = frame.sort_values(by=['size', 'label'])
                sns.scatterplot(x='x', y='y', hue='label', s=frame['size'].values,
                                style="style", data=frame, ax=ax, markers=markers,
                                palette=palettes[idx], legend=False)
                if idx > 0:
                    # Only the left-most panel keeps y tick labels and the y label.
                    ax.set(yticklabels=[])
                    ax.set(ylabel=None)
                ax.set(title=titles[idx])
                if x_lim is not None:
                    ax.set_xlim(x_lim)
                if y_lim is not None:
                    ax.set_ylim(y_lim)
            plt.savefig(save_path, bbox_inches='tight')
        finally:
            plt.close(fig)


def save_mega_subplot(save_path, plots, col_titles, row_titles=None):
    """Draw a grid of images, one grid row per entry of ``plots``, and save it.

    Args:
        save_path: Destination passed to ``plt.savefig``.
        plots: Non-empty sequence of rows. Each row is a list or any indexable
            object supporting ``len()`` (e.g. a tensor whose first dimension
            indexes images). Each element is ``.squeeze()``-d and drawn with
            ``imshow``. Rows may differ in length: the grid is as wide as the
            longest row, and missing cells are left blank.
        col_titles: Titles for the cells of the first row; it needs at least
            ``len(plots[0])`` entries.
        row_titles: Optional labels, one per row, drawn unrotated as the
            y-label of each row's first cell.

    Raises:
        ValueError: If ``plots`` is empty.
    """
    row_count = len(plots)
    if row_count == 0:
        raise ValueError("plots must contain at least one row")
    # Use the longest row so no image is silently dropped.
    col_count = max(len(plot_row) for plot_row in plots)

    plt.clf()
    try:
        for row_idx, plot_row in enumerate(plots):
            for col_idx, image in enumerate(plot_row):
                # Address the cell by its (row, col) position so a short row leaves
                # blank cells instead of shifting later images into the wrong row.
                ax = plt.subplot(row_count, col_count, row_idx * col_count + col_idx + 1)
                ax.imshow(image.squeeze())
                if row_idx == 0:
                    ax.set_title(col_titles[col_idx])
                if row_titles is not None and col_idx == 0:
                    ax.set_ylabel(row_titles[row_idx], rotation=0)
                # Hide the frame and ticks so only images and labels remain.
                for pos in ('right', 'top', 'bottom', 'left'):
                    ax.spines[pos].set_visible(False)
                ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        plt.savefig(save_path, bbox_inches='tight')
    finally:
        plt.close('all')