# Source: borderownership/src/rf_mapping/occlude/occlude_script.py
import os
import sys

import numpy as np
from torchvision import models
# from torchvision.models import AlexNet_Weights, VGG16_Weights
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

sys.path.append('../../..')
from src.rf_mapping.hook import ConvUnitCounter
from src.rf_mapping.spatial import SpatialIndexConverter, get_rf_sizes
from src.rf_mapping.net import get_truncated_model
import src.rf_mapping.constants as c
from src.rf_mapping.reproducibility import set_seeds
from src.rf_mapping.occluder_discrepancy import (get_occluder_params,
                                                 get_random_discrepancy_map)


# User settings -- edit these before running.
# Keep exactly one model / model_name pair uncommented.
model_name = "alexnet"
model = models.alexnet(pretrained=True).to(c.DEVICE)
# model = models.vgg16(pretrained=True).to(c.DEVICE)
# model_name = "vgg16"
# model = models.resnet18(pretrained=True).to(c.DEVICE)
# model_name = "resnet18"
top_n = 100                  # number of top/bottom images used per unit
image_size = (227, 227)
this_is_a_test_run = False   # True: results are written to the 'test' folder
batch_size = 10

# Input/output locations -- double-check these before running.
img_dir = c.IMG_DIR
index_dir = os.path.join(c.REPO_DIR, 'results', 'ground_truth', 'top_n', model_name)
# Test runs share one 'test' folder; real runs get a folder per model.
result_dir = os.path.join(c.REPO_DIR, 'results', 'occlude', 'mapping',
                          'test' if this_is_a_test_run else model_name)

###############################################################################

# Script guard
# if __name__ == "__main__":
#     print("Look for a prompt.")
#     user_input = input("This code may take hours to run. Are you sure? [y/n] ")
#     if user_input == 'y':
#         pass
#     else: 
#         raise KeyboardInterrupt("Interrupted by user")

# Helper objects built from the chosen model (construction order kept as-is).
converter = SpatialIndexConverter(model, image_size)
conv_counter = ConvUnitCounter(model)

# Per-conv-layer info. The layer indices returned by get_rf_sizes() are the
# ones used from here on; only the unit counts are taken from the counter.
_, nums_units = conv_counter.count()
layer_indices, rf_sizes = get_rf_sizes(model, image_size)

# At most five images per row are drawn on each PDF page.
top_n_to_plot = top_n if top_n < 5 else 5

# Seed the RNGs so repeated runs are reproducible.
set_seeds()

def normalize_for_display(discrepancy_map):
    """Scale a discrepancy map by its maximum for plotting.

    The image panels are created with ``vmin=0, vmax=1``, so each map is
    divided by its own maximum before being drawn.

    Parameters
    ----------
    discrepancy_map : numpy.ndarray
        The map to be drawn.

    Returns
    -------
    numpy.ndarray
        ``discrepancy_map / discrepancy_map.max()``. When the maximum is
        exactly zero the map is returned unchanged, because dividing would
        fill the panel with NaN and emit a RuntimeWarning.
    """
    peak = discrepancy_map.max()
    if peak == 0:
        return discrepancy_map
    return discrepancy_map / peak


def _map_image_group(unit_i, layer_idx, rf_size, truncated_model,
                     img_indices, patch_indices, mean_maps,
                     im_handles, ax_handles, handle_offset, label):
    """Compute discrepancy maps for one group (top or bottom) of a unit's images.

    Each map is added to ``mean_maps[unit_i]`` (scaled by ``1 / top_n``) and
    the first ``top_n_to_plot`` maps are drawn into the figure panels starting
    at ``handle_offset``.

    Reads the script-level settings ``img_dir``, ``converter``, ``image_size``,
    ``batch_size``, ``this_is_a_test_run``, ``top_n`` and ``top_n_to_plot``.

    Parameters
    ----------
    unit_i : int
        Index of the unit within the layer.
    layer_idx : int
        Index of the conv layer, as passed to ``converter.convert``.
    rf_size : int
        Side length of the (square) discrepancy maps.
    truncated_model
        Model returned by ``get_truncated_model`` for this layer.
    img_indices, patch_indices : sequences of int
        Image indices and the matching spatial indices, iterated in lockstep.
    mean_maps : numpy.ndarray
        Accumulator of shape (num_units, rf_size, rf_size); modified in place.
    im_handles, ax_handles : list
        Image and axes handles of the figure, in row-major order.
    handle_offset : int
        Index of this group's first panel in the handle lists.
    label : str
        Prefix of the panel titles ("top" or "bottom").
    """
    for i, (img_idx, patch_idx) in enumerate(zip(img_indices, patch_indices)):
        img = np.load(os.path.join(img_dir, f"{img_idx}.npy"))
        # NOTE(review): presumably the image-space box of the unit's receptive
        # field at this spatial index -- confirm against SpatialIndexConverter.
        box = converter.convert(patch_idx, layer_idx, 0, is_forward=False)
        occluder_params = get_occluder_params(box, rf_size, image_size)
        discrepancy_map = get_random_discrepancy_map(
            img, occluder_params, truncated_model, rf_size, patch_idx,
            unit_i, box, batch_size=batch_size,
            _debug=this_is_a_test_run, image_size=image_size)
        mean_maps[unit_i] += discrepancy_map / top_n
        if i < top_n_to_plot:
            panel = handle_offset + i
            im_handles[panel].set_data(normalize_for_display(discrepancy_map))
            ax_handles[panel].set_title(f"{label} {i+1} image")


if __name__ == "__main__":
    # Neither PdfPages nor np.save creates missing directories.
    os.makedirs(result_dir, exist_ok=True)

    for conv_i, layer_idx in enumerate(layer_indices):
        truncated_model = get_truncated_model(model, layer_idx)
        layer_name = f"conv{conv_i + 1}"
        index_path = os.path.join(index_dir, f"{layer_name}.npy")
        max_min_indices = np.load(index_path).astype(int)
        # with dimension: [units, top_n_img, [max_img_idx, max_idx, min_img_idx, min_idx]]

        rf_size = rf_sizes[conv_i][0]
        num_units = nums_units[conv_i]

        # Per-unit accumulators for the averaged discrepancy maps.
        max_discrepancy_maps = np.zeros((num_units, rf_size, rf_size))
        min_discrepancy_maps = np.zeros((num_units, rf_size, rf_size))

        pdf_path = os.path.join(result_dir, f"{layer_name}.pdf")
        with PdfPages(pdf_path) as pdf:
            # squeeze=False keeps the axes array 2-D even with a single column.
            fig, plt_axes = plt.subplots(2, top_n_to_plot, squeeze=False)
            fig.set_size_inches(5 * top_n_to_plot, 8)

            # Axes and imshow handles in row-major order: the first row holds
            # the top images, the second row the bottom images.
            ax_handles = list(plt_axes.flat)
            im_handles = [ax.imshow(np.zeros((rf_size, rf_size)),
                                    vmin=0, vmax=1, cmap='gray')
                          for ax in ax_handles]

            for unit_i in tqdm(range(num_units)):
                if this_is_a_test_run and unit_i > 2:
                    break

                fig.suptitle(f"{layer_name} unit no.{unit_i}", fontsize=20)

                # Top N images and their patch spatial indices.
                _map_image_group(unit_i, layer_idx, rf_size, truncated_model,
                                 max_min_indices[unit_i, :top_n, 0],
                                 max_min_indices[unit_i, :top_n, 1],
                                 max_discrepancy_maps,
                                 im_handles, ax_handles, 0, "top")

                # Bottom N images and their patch spatial indices.
                _map_image_group(unit_i, layer_idx, rf_size, truncated_model,
                                 max_min_indices[unit_i, :top_n, 2],
                                 max_min_indices[unit_i, :top_n, 3],
                                 min_discrepancy_maps,
                                 im_handles, ax_handles, top_n_to_plot, "bottom")

                # No plt.show() here: it would block this long batch job on
                # interactive backends. Each page goes straight into the PDF.
                pdf.savefig(fig)

            # The same figure is reused for every page, so close it only once
            # the whole layer has been written.
            plt.close(fig)

        # Save the results in npy files.
        max_map_path = os.path.join(result_dir, f'{layer_name}_max.npy')
        min_map_path = os.path.join(result_dir, f'{layer_name}_min.npy')
        np.save(max_map_path, max_discrepancy_maps)
        np.save(min_map_path, min_discrepancy_maps)