borderownership / src / rf_mapping / fnat / show_top_fnat_units.py
show_top_fnat_units.py
Raw
"""
Plot the natural images of the units that have a large or small FNAT.

Tony Fu, Sep 26, 2022
"""
import os
import sys

import numpy as np
import pandas as pd
import torchvision.models as models
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from tqdm import tqdm

sys.path.append('../../..')
import src.rf_mapping.constants as c
from src.rf_mapping.hook import ConvUnitCounter
from src.rf_mapping.spatial import SpatialIndexConverter
from src.rf_mapping.image import preprocess_img_for_plot, make_box
from src.rf_mapping.guided_backprop import GuidedBackprop


# Please specify the model
# `model` and `model_name` must be switched together: `model_name` is used
# below to build the input/output directory paths and the PDF file names.
# NOTE(review): `models.alexnet()` is called with no arguments here, so no
# weights are requested at this point — confirm whether pretrained weights
# are expected (or loaded elsewhere) before trusting the gradient maps.
model = models.alexnet()
model_name = 'alexnet'
# model = models.vgg16()
# model_name = 'vgg16'
# model = models.resnet18()
# model_name = 'resnet18'

# `this_is_a_test_run` and `map1_name` are not referenced in the code visible
# in this script section.
this_is_a_test_run = False
map1_name = 'gt'                # ['gt', 'occlude']
# `map2_name` selects the FNAT sub-directory and is part of the FNAT file
# name and of the output PDF names.
map2_name = 'rfmp4a'            # ['rfmp4a', 'rfmp4c7o', 'rfmp_sin1', 'pasu']
# Units with TOP_AVG_FNAT above `fnat_upper_thres` go into the first PDF;
# units with TOP_AVG_FNAT below `fnat_lower_thres` go into the second PDF.
fnat_upper_thres = 0.8
fnat_lower_thres = 0.05
# Passed to SpatialIndexConverter and used as the size of the blank
# placeholder images the figure panels start with.
image_size = (227, 227)
# Number of top images plotted per unit (one figure column per image).
top_n = 5

# Please double-check the directories:
# img_dir:    where the `{img_idx}.npy` image files are read from.
# index_dir:  where the per-layer `conv{N}.npy` top-image index arrays are read from.
# fnat_dir:   where the FNAT text file is read from.
# result_dir: where the PDFs are written (same directory as the FNAT file).
img_dir = c.IMG_DIR
index_dir = os.path.join(c.REPO_DIR, 'results', 'ground_truth', 'top_n', model_name)
fnat_dir = os.path.join(c.REPO_DIR, 'results', 'fnat', map2_name, model_name)
result_dir = fnat_dir

###############################################################################

# Script guard
# if __name__ == "__main__":
#     print("Look for a prompt.")
#     user_input = input("This code may take time to run. Are you sure? [y/n]") 
#     if user_input == 'y':
#         pass
#     else: 
#         raise KeyboardInterrupt("Interrupted by user")

# Initiate helper objects.
# `converter` is used below to turn a patch index of a conv layer into the
# box that is drawn on the image (via `converter.convert(..., is_forward=False)`).
converter = SpatialIndexConverter(model, image_size)
conv_counter = ConvUnitCounter(model)

# Get info of conv layers.
# `layer_indices` is iterated below, one entry per conv layer; `nums_units`
# is not referenced in the code visible in this script section.
layer_indices, nums_units = conv_counter.count()


############################  HELPER FUNCTIONS ################################


def plot_one_img(im, ax, img_idx, box):
    """Show image `img_idx` on an existing axes and draw `box` on it.

    Parameters
    ----------
    im : image handle returned by ``ax.imshow``
        Its pixel data is replaced with the newly loaded image.
    ax : matplotlib axes
        The axes that owns `im`; its ticks are cleared and the box patch is
        added to it.
    img_idx : int
        Index of the image; ``{img_idx}.npy`` is loaded from `img_dir`.
    box : object
        Passed unchanged to `make_box` to build the patch that is drawn.
    """
    # Remove the box left over from the previous call on this axes, if any.
    # Artist.remove() is used instead of `ax.patches.pop()` because
    # `ax.patches` cannot be modified in place in recent matplotlib versions;
    # there the pop() failed silently (bare except) and boxes piled up on the
    # reused axes.
    if ax.patches:
        ax.patches[-1].remove()
    # Plot image
    img_path = os.path.join(img_dir, f"{img_idx}.npy")
    img = np.load(img_path)
    img = preprocess_img_for_plot(img)
    im.set_data(img)
    # Remove tick marks and labels
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.add_patch(make_box(box))


def plot_one_grad_map(im, ax, img_idx, unit_idx, patch_idx, box, grad_method):
    """Show the target unit's gradient map for image `img_idx` and draw `box`.

    Parameters
    ----------
    im : image handle returned by ``ax.imshow``
        Its pixel data is replaced with the preprocessed gradient map.
    ax : matplotlib axes
        The axes that owns `im`; its ticks are cleared and the box patch is
        added to it.
    img_idx : int
        Index of the image; ``{img_idx}.npy`` is loaded from `img_dir`.
    unit_idx, patch_idx
        Forwarded to ``grad_method.generate_gradients`` together with the
        loaded image to select the unit and spatial position.
    box : object
        Passed unchanged to `make_box` to build the patch that is drawn.
    grad_method : object
        Must provide ``generate_gradients(img, unit_idx, patch_idx)``.
    """
    # Remove the box left over from the previous call on this axes, if any.
    # Artist.remove() is used instead of `ax.patches.pop()` because
    # `ax.patches` cannot be modified in place in recent matplotlib versions;
    # there the pop() failed silently (bare except) and boxes piled up on the
    # reused axes.
    if ax.patches:
        ax.patches[-1].remove()
    # Plot gradients
    img_path = os.path.join(img_dir, f"{img_idx}.npy")
    img = np.load(img_path)
    gbp_map = grad_method.generate_gradients(img, unit_idx, patch_idx)
    im.set_data(preprocess_img_for_plot(gbp_map))
    # Remove tick marks and labels
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.add_patch(make_box(box))
    
    
##############################  LOAD FNAT DATA  ###############################


# Read the per-unit FNAT table. The file is read with header=None, so the
# column names are attached after loading.
top_n_r = 10  # part of the FNAT file name; presumably the number of maps averaged — TODO confirm
fnat_filename = f"{map2_name}_fnat_{top_n_r}_avg.txt"
fnat_path = os.path.join(fnat_dir, fnat_filename)
fnat_df = pd.read_csv(fnat_path, header=None, sep=" ")
fnat_df.columns = ['LAYER', 'UNIT', 'TOP_AVG_FNAT', 'BOT_AVG_FNAT']


#################################  MAKE PDF  ##################################


def _make_fnat_pdf(pdf_path, fnat_is_selected):
    """Write a PDF with one page per selected unit, showing its top images.

    For every conv layer except the first, the units of that layer whose
    TOP_AVG_FNAT satisfies `fnat_is_selected` are visualized. Each page is a
    2 x `top_n` figure: the first row shows the unit's top images with the
    box returned by `converter.convert`, the second row shows the matching
    gradient maps from `GuidedBackprop(..., remove_neg_grad=True)`.

    Parameters
    ----------
    pdf_path : str
        Path of the PDF file to create.
    fnat_is_selected : callable
        Takes the TOP_AVG_FNAT column of `fnat_df` and returns a boolean mask
        marking the rows (units) to visualize.
    """
    with PdfPages(pdf_path) as pdf:
        for conv_i, layer_idx in enumerate(layer_indices):
            # Skipping Conv1
            if conv_i == 0:
                continue

            top_grad_method = GuidedBackprop(model, layer_idx, remove_neg_grad=True)
            # Not used yet (see the TODO at the end of the loop body).
            # NOTE(review): kept because constructing GuidedBackprop may have
            # side effects that are not visible from this file — confirm
            # before removing.
            bot_grad_method = GuidedBackprop(model, layer_idx, remove_neg_grad=False)
            layer_name = f"conv{conv_i + 1}"

            # Load top image indices
            index_path = os.path.join(index_dir, f"{layer_name}.npy")
            max_min_indices = np.load(index_path).astype(int)
            # with dimension: [units, top_n_img, [max_img_idx, max_idx, min_img_idx, min_idx]]

            # Extract the FNAT of the selected units of this layer
            is_selected = (fnat_df.LAYER == layer_name) & fnat_is_selected(fnat_df.TOP_AVG_FNAT)
            top_fnat_df = fnat_df.loc[is_selected, ['UNIT', 'TOP_AVG_FNAT']]
            num_units = len(top_fnat_df)
            print(f"Making pdf for {layer_name}...")

            # Nothing to draw for this layer: do not create a figure at all.
            if num_units == 0:
                continue

            # squeeze=False keeps the axes array 2-D even when top_n == 1.
            fig, plt_axes = plt.subplots(2, top_n, squeeze=False)
            fig.set_size_inches(20, 8)

            # Collect axis and imshow handles in row-major order: the first
            # top_n entries are the image row, the next top_n the gradient row.
            # The same figure is reused for every unit of the layer.
            ax_handles = list(plt_axes.flat)
            im_handles = [ax.imshow(np.zeros((*image_size, 3))) for ax in ax_handles]

            # Plot the natural images of the top units.
            for _, unit_data in tqdm(top_fnat_df.iterrows(), total=num_units):
                unit_i = int(unit_data['UNIT'])
                this_fnat = unit_data['TOP_AVG_FNAT']

                fig.suptitle(f"{layer_name} unit no.{unit_i} (fnat = {this_fnat})", fontsize=20)
                # Get top image indices and patch spatial indices
                max_n_img_indices   = max_min_indices[unit_i, :top_n, 0]
                max_n_patch_indices = max_min_indices[unit_i, :top_n, 1]

                # Top N images and gradient patches:
                for i, (max_img_idx, max_patch_idx) in enumerate(zip(max_n_img_indices,
                                                                  max_n_patch_indices)):
                    box = converter.convert(max_patch_idx, layer_idx, 0, is_forward=False)
                    plot_one_img(im_handles[i], ax_handles[i], max_img_idx, box)
                    ax_handles[i].set_title(f"top {i+1} image")

                    plot_one_grad_map(im_handles[i+top_n], ax_handles[i+top_n],
                                      max_img_idx, unit_i, max_patch_idx,
                                      box, top_grad_method)
                    ax_handles[i+top_n].set_title(f"top {i+1} gradient")

                pdf.savefig(fig)

            # Close the figure once all units of the layer are saved, rather
            # than after the first page while it is still being reused.
            plt.close(fig)

            # TODO: show the natural images of the bottom units as well.


# PDF 1. Visualize those with large FNAT.
pdf_path = os.path.join(result_dir, f"{model_name}_{map2_name}_fnat_greater_than_{fnat_upper_thres}.pdf")
_make_fnat_pdf(pdf_path, lambda fnat: fnat > fnat_upper_thres)


# PDF 2. Visualize those with small FNAT.
pdf_path = os.path.join(result_dir, f"{model_name}_{map2_name}_fnat_less_than_{fnat_lower_thres}.pdf")
_make_fnat_pdf(pdf_path, lambda fnat: fnat < fnat_lower_thres)