# borderownership / src / rf_mapping / fnat / calculate_fnat.py
"""
Code to compute fnat, the ratio of the rfmp responses to the ground-truth (GT)
responses, for the units of each conv layer.

Tony Fu, August 19, 2022
"""
import os
import sys

import numpy as np
import pandas as pd
import torch.nn as nn
from scipy.stats import pearsonr
import torchvision.models as models
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

sys.path.append('../../..')
import src.rf_mapping.constants as c
from src.rf_mapping.spatial import get_rf_sizes
from src.rf_mapping.result_txt_format import (CenterReponses as CR)
from src.rf_mapping.hook import ConvUnitCounter

# Please specify some details here:
model = models.alexnet(pretrained=True).to(c.DEVICE)
model_name = 'alexnet'
# model = models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).to(c.DEVICE)
# model_name = "vgg16"
# model = models.resnet18(pretrained=True).to(c.DEVICE)
# model_name = "resnet18"

# Number of top (and bottom) GT responses that are averaged for each unit.
top_n_r = 10
# NOTE(review): this flag is not read anywhere in the visible part of the file.
this_is_a_test_run = True
# Name of the mapping method; used to build the input and output paths.
rfmp_name = 'rfmp4a'

# Please double-check the directories:
gt_response_dir = os.path.join(c.REPO_DIR, 'results', 'ground_truth', 'top_n', model_name)
rfmp_response_dir = os.path.join(c.REPO_DIR, 'results', rfmp_name, 'mapping', model_name)
result_dir = os.path.join(c.REPO_DIR, 'results', 'fnat', rfmp_name, model_name)

# The txt and pdf outputs are written into result_dir, so make sure it exists
# before anything tries to open a file in it.
os.makedirs(result_dir, exist_ok=True)

# Txt file to save fnat. The results are appended to this file one layer at a
# time, so a file left over from a previous run is deleted first.
fnat_txt_path = os.path.join(result_dir, f"{rfmp_name}_fnat_{top_n_r}_avg.txt")
if os.path.exists(fnat_txt_path):
    os.remove(fnat_txt_path)

###############################################################################

# Script guard.
# if __name__ == "__main__":
#     user_input = input("This code takes time to run. Are you sure? "\
#                        "Enter 'y' to proceed. Type any other key to stop: ")
#     if user_input == 'y':
#         pass
#     else: 
#         raise KeyboardInterrupt("Interrupted by user")


# Collect the conv-layer bookkeeping of the chosen model.
unit_counter = ConvUnitCounter(model)
layer_indices, nums_units = unit_counter.count()

# Only the second item returned by get_rf_sizes() is kept. Its length is used
# as the number of conv layers (conv1, conv2, ...) processed below.
rf_size_info = get_rf_sizes(model, (227, 227), layer_type=nn.Conv2d)
_, rf_sizes = rf_size_info
num_layers = len(rf_sizes)


def config_plot(x, y, padding=5):
    """
    Give the current axes identical, padded x- and y-limits and draw the
    identity line.

    Parameters
    ----------
    x, y : array-like with .min() and .max() methods
        The values plotted on the x- and y-axis. Only their extrema are used.
    padding : float, default 5
        Margin subtracted from the smallest value and added to the largest
        value when computing the shared axis limits.
    """
    rmin = min(x.min(), y.min()) - padding
    rmax = max(x.max(), y.max()) + padding

    # Semi-transparent black y = x line spanning the whole axis range.
    plt.plot([rmin, rmax], [rmin, rmax], 'k', alpha=0.4)
    plt.xlim((rmin, rmax))
    plt.ylim((rmin, rmax))
    # Equal aspect so that the identity line is drawn at 45 degrees.
    plt.gca().set_aspect('equal')


def _load_rank0_responses(response_path):
    """
    Load a center-response txt file and return the R value of every
    RANK == 0 row as a 1D numpy array, in file order.

    The file is read as space-separated text without a header row; the column
    names are taken from the members of the CenterReponses (CR) enum.
    """
    response_df = pd.read_csv(response_path, sep=" ", header=None)
    response_df.columns = [e.name for e in CR]
    return response_df.loc[response_df.RANK == 0, 'R'].to_numpy()


# Per-layer fnat arrays, kept for the histogram page of the pdf.
all_top_fnats = []
all_bot_fnats = []


pdf_path = os.path.join(result_dir, f"{rfmp_name}_fnat_{top_n_r}_avg.pdf")
with PdfPages(pdf_path) as pdf:
    # Page 1: scatter plots of GT vs. rfmp responses. One column per layer,
    # top responses in the first row and bottom responses in the second row.
    plt.figure(figsize=(num_layers*5, 10))
    for conv_i in range(num_layers):
        layer_name = f"conv{conv_i+1}"

        # Load the rfmp center responses. Only the top-1 and bottom-1
        # (RANK == 0) responses are used, so despite the "avg" in the names no
        # averaging is done on the rfmp side.
        # NOTE(review): assumes RANK == 0 yields exactly one row per unit, in
        # the same unit order as the GT array - TODO confirm.
        top_rfmp_path = os.path.join(rfmp_response_dir, f'{layer_name}_top5000_responses.txt')
        bot_rfmp_path = os.path.join(rfmp_response_dir, f'{layer_name}_bot5000_responses.txt')
        avg_top_rfmp_responses = _load_rank0_responses(top_rfmp_path)
        avg_bot_rfmp_responses = _load_rank0_responses(bot_rfmp_path)

        print(f"avg top rfmp has {np.sum(avg_top_rfmp_responses < 1):3} units less than 1")
        print(f"avg bot rfmp has {np.sum(avg_bot_rfmp_responses > -1):3} units greater than -1")

        # Load the GT responses.
        # Shape = [num_units, num_images, 2]. There are 2 columns:
        # 0. Max responses of the given image and unit
        # 1. Min responses of the given image and unit
        gt_response_path = os.path.join(gt_response_dir, f"{layer_name}_responses.npy")
        gt_responses = np.load(gt_response_path)
        # Sort along the image axis in ascending order (each of the two
        # columns is sorted independently).
        gt_responses = np.sort(gt_responses, axis=1)

        # Average the top- and bottom-N responses for each unit: the N largest
        # max responses and the N smallest min responses.
        avg_top_gt_responses = np.mean(gt_responses[:, -top_n_r:, 0], axis=1)
        avg_bot_gt_responses = np.mean(gt_responses[:, :top_n_r, 1], axis=1)

        plt.subplot(2, num_layers, conv_i+1)
        plt.scatter(avg_top_gt_responses, avg_top_rfmp_responses)
        config_plot(avg_top_gt_responses, avg_top_rfmp_responses)
        plt.xlabel("GT")
        plt.ylabel(f"{rfmp_name}")
        r_val, _ = pearsonr(avg_top_gt_responses, avg_top_rfmp_responses)
        plt.title(f"{layer_name}, top, r = {r_val:.4f}")

        plt.subplot(2, num_layers, conv_i+1+num_layers)
        plt.scatter(avg_bot_gt_responses, avg_bot_rfmp_responses)
        config_plot(avg_bot_gt_responses, avg_bot_rfmp_responses)
        plt.xlabel("GT")
        plt.ylabel(f"{rfmp_name}")
        r_val, _ = pearsonr(avg_bot_gt_responses, avg_bot_rfmp_responses)
        plt.title(f"{layer_name}, bottom, r = {r_val:.4f}")

        # Compute fnat. A GT average of exactly 0 would trigger divide/invalid
        # warnings here; those entries are masked right below, so the warnings
        # are silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            top_fnat = avg_top_rfmp_responses / avg_top_gt_responses
            bot_fnat = avg_bot_rfmp_responses / avg_bot_gt_responses

        # Remove avg_gt_r that is less than 1 (or greater than -1 for bottom).
        # np.nan is used because the np.NaN alias no longer exists in NumPy 2.
        top_fnat[avg_top_gt_responses < 1] = np.nan
        bot_fnat[avg_bot_gt_responses > -1] = np.nan

        # Remove negative rfmp responses (or positive for bottom map).
        # NOTE(review): this runs after the NaN masking, so a unit that fails
        # both criteria ends up as 0 rather than NaN - confirm this is intended.
        top_fnat[avg_top_rfmp_responses < 0] = 0
        bot_fnat[avg_bot_rfmp_responses > 0] = 0

        # Append to lists for the next figure
        all_top_fnats.append(top_fnat)
        all_bot_fnats.append(bot_fnat)

        # Save fnat in a txt file (one "layer unit top_fnat bot_fnat" line per
        # unit, appended so that all layers end up in the same file).
        with open(fnat_txt_path, 'a') as f:
            f.writelines(
                f"{layer_name} {unit_i} {t:.4f} {b:.4f}\n"
                for unit_i, (t, b) in enumerate(zip(top_fnat, bot_fnat))
            )

    pdf.savefig()
    plt.show()
    plt.close()

    # Page 2: histograms of fnat, using the same row/column layout as page 1.
    plt.figure(figsize=(num_layers*5, 10))
    plt.suptitle("Distribution of fnat", fontsize=24)
    for conv_i in range(num_layers):
        layer_name = f"conv{conv_i+1}"

        # Histogram only the finite values: the masked (NaN) units are not
        # part of the distribution, and a layer whose fnat values are all NaN
        # would otherwise make the histogram call fail. The "n" in each title
        # is the number of values actually shown.
        finite_top_fnats = all_top_fnats[conv_i][np.isfinite(all_top_fnats[conv_i])]
        finite_bot_fnats = all_bot_fnats[conv_i][np.isfinite(all_bot_fnats[conv_i])]

        plt.subplot(2, num_layers, conv_i+1)
        plt.hist(finite_top_fnats)
        plt.title(f"{layer_name}, top (n = {finite_top_fnats.size})", fontsize=14)

        plt.subplot(2, num_layers, conv_i+1+num_layers)
        plt.hist(finite_bot_fnats)
        plt.title(f"{layer_name}, bottom (n = {finite_bot_fnats.size})", fontsize=14)

    pdf.savefig()
    plt.show()
    plt.close()