borderownership / src / rf_mapping / pasu / pasu_shape_rank.py
pasu_shape_rank.py
Raw
"""
Plot each unit's top- and bottom-ranked Pasupathy shapes, ordered by the unit's responses to them.

Note: all code assumes that the y-axis points downward.

Tony Fu, September 12th, 2022
"""
import os
import sys

from torchvision import models
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from tqdm import tqdm

sys.path.append('../../..')
import src.rf_mapping.constants as c
from src.rf_mapping.pasu_shape import make_pasu_shape
from src.rf_mapping.spatial import get_num_layers
from src.rf_mapping.result_txt_format import (CenterReponses as CR,
                                              PasuBWSplist as SP,)

# Please specify some details here:
model_name = 'alexnet'
model = models.alexnet(pretrained=True).to(c.DEVICE)
# model_name = 'vgg16'
# model = models.vgg16(pretrained=True).to(c.DEVICE)
# model_name = "resnet18"
# model = models.resnet18(pretrained=True).to(c.DEVICE)
this_is_a_test_run = False
top_n = 10  # Number of top (and, separately, bottom) shapes plotted per unit.

# Please double-check the directories:
# A test run reads from and writes to the shared 'test' subfolder; a real run
# uses a subfolder named after the model.
_run_subdir = 'test' if this_is_a_test_run else model_name
source_dir = os.path.join(c.REPO_DIR, 'results', 'pasu', 'mapping', _run_subdir)
result_dir = os.path.join(c.REPO_DIR, 'results', 'pasu', 'analysis', _run_subdir)


###############################################################################

# Script guard (currently disabled; uncomment to ask before a long run):
# if __name__ == "__main__":
#     print("Look for a prompt.")
#     user_input = input("This code may take time to run. Are you sure? [y/n] ")
#     if user_input == 'y':
#         pass
#     else:
#         raise KeyboardInterrupt("Interrupted by user")

# Number of conv layers to iterate over below (one PDF per layer).
num_layers = get_num_layers(model)


# Define some helper functions
def set_column_names(df, Format):
    """Rename the columns of ``df`` in place to the member names of ``Format``.

    ``Format`` is iterated in order, so its i-th member names the i-th
    column. Nothing is returned; ``df`` itself is modified.
    """
    names = [member.name for member in Format]
    df.columns = names


def plot_one_shape(stim_i, im, ax, splist=None):
    """Render Pasupathy shape ``stim_i`` into an existing image artist.

    Parameters
    ----------
    stim_i : int
        Value of the STIM_I column identifying the shape to draw.
    im : matplotlib image artist
        Artist (as returned by ``ax.imshow``) whose data is replaced by the
        rendered shape.
    ax : matplotlib Axes
        Axes holding ``im``; its x/y ticks are removed.
    splist : pandas.DataFrame or None
        Shape-parameter table with the columns STIM_I, XN, YN, X0, Y0, SI,
        RI, FGVAL, BGVAL and SIZE. When None (the default, kept for
        backward compatibility), the module-level ``splist_df`` loaded by
        the main loop for the current layer is used.

    Raises
    ------
    ValueError
        If ``stim_i`` does not match exactly one row (raised by ``.item()``).
    """
    if splist is None:
        # Fall back to the table the main loop loaded most recently.
        splist = splist_df
    params = splist.loc[splist.STIM_I == stim_i]
    # ``.item()`` both enforces a single matching row and converts each
    # value to a plain Python scalar for make_pasu_shape.
    shape = make_pasu_shape(params.XN.item(), params.YN.item(),
                            params.X0.item(), params.Y0.item(),
                            params.SI.item(), params.RI.item(),
                            params.FGVAL.item(), params.BGVAL.item(),
                            params.SIZE.item(), plot=False)
    im.set_data(shape)
    # Removing all ticks also removes their labels.
    ax.set_xticks([])
    ax.set_yticks([])


def _ranked_stim_ids(cr_df, unit_i, n):
    """Return the STIM_I values of ranks 0 .. n-1 for one unit, in rank order.

    Each (UNIT, RANK) pair must match exactly one row; ``.item()`` raises
    ValueError otherwise, so incomplete response files fail loudly.
    """
    # Filter by unit once, instead of scanning the full table for every rank.
    unit_rows = cr_df.loc[cr_df.UNIT == unit_i]
    return [unit_rows.loc[unit_rows.RANK == rank_i, 'STIM_I'].item()
            for rank_i in range(n)]


for conv_i in range(num_layers):
    layer_name = f"conv{conv_i+1}"

    # Load the splist and center responses.
    splist_path = os.path.join(source_dir, f"{layer_name}_splist.txt")
    top_cr_path = os.path.join(source_dir, f"{layer_name}_top5000_responses.txt")
    bot_cr_path = os.path.join(source_dir, f"{layer_name}_bot5000_responses.txt")

    try:
        splist_df = pd.read_csv(splist_path, sep=" ", header=None)
        top_cr_df = pd.read_csv(top_cr_path, sep=" ", header=None)
        bot_cr_df = pd.read_csv(bot_cr_path, sep=" ", header=None)
    except (FileNotFoundError, pd.errors.EmptyDataError) as err:
        # Layers whose mapping results are missing are skipped, but report
        # it rather than silently swallowing every possible exception.
        print(f"Skipping {layer_name}: {err}")
        continue

    set_column_names(splist_df, SP)
    set_column_names(top_cr_df, CR)
    set_column_names(bot_cr_df, CR)

    # Iterate over the unit IDs actually present instead of assuming 0..N-1.
    unit_ids = np.sort(top_cr_df['UNIT'].unique())
    # NOTE(review): assumes every shape in the splist shares the first row's XN.
    map_size = splist_df.loc[0, 'XN']

    # PdfPages does not create missing parent directories.
    os.makedirs(result_dir, exist_ok=True)
    pdf_path = os.path.join(result_dir, f"{model_name}_{layer_name}_top_bot_shapes.pdf")

    fig, plt_axes = plt.subplots(2, top_n)
    fig.set_size_inches(top_n * 2, 5)
    try:
        # Row-major order: the first top_n axes are the "top" row, the next
        # top_n the "bottom" row. ``.flat`` also works when top_n == 1 and
        # subplots() returns a 1-D array.
        ax_handles = list(plt_axes.flat)
        im_handles = [ax.imshow(np.zeros((map_size, map_size)),
                                vmin=-1, vmax=1, cmap='gray')
                      for ax in ax_handles]

        with PdfPages(pdf_path) as pdf:
            for unit_i in tqdm(unit_ids):
                fig.suptitle(f"{model_name} {layer_name} no.{unit_i} top/bottom Pasupathy shapes", fontsize=18)

                # Reuse the same figure for every unit: only image data and
                # titles change, then one PDF page is written per unit.
                subplot_i = 0
                for label, cr_df in (("top", top_cr_df), ("bottom", bot_cr_df)):
                    stim_ids = _ranked_stim_ids(cr_df, unit_i, top_n)
                    for rank_i, stim_i in enumerate(stim_ids):
                        plot_one_shape(stim_i, im_handles[subplot_i], ax_handles[subplot_i])
                        ax_handles[subplot_i].set_title(f"{label} {rank_i+1}\nshape no.{stim_i}")
                        subplot_i += 1

                pdf.savefig(fig)
    finally:
        # Close the per-layer figure once, after all pages are written.
        plt.close(fig)