# borderownership/src/rf_mapping/map_fit/gaussian_fit_script.py
"""
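Fit elliptical 2D Gaussians to the receptive-field maps produced by each
mapping method (gt, gt_composite, occlude, occlude_composite, rfmp4a,
rfmp4c7o, rfmp_sin1, pasu, block), saving the fitted parameters to .txt
files and the fit plots to per-layer .pdf files.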

Tony Fu, Sep 19, 2022
"""
import os
import sys

import numpy as np
import torch.nn as nn
from torchvision import models
# from torchvision.models import AlexNet_Weights, VGG16_Weights
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from tqdm import tqdm
from scipy.ndimage.filters import gaussian_filter

sys.path.append('../../..')
from src.rf_mapping.gaussian_fit import (gaussian_fit,
                                        calc_f_explained_var,
                                        theta_to_ori)
from src.rf_mapping.gaussian_fit import GaussianFitParamFormat as ParamFormat
from src.rf_mapping.hook import ConvUnitCounter
from src.rf_mapping.spatial import get_rf_sizes
from src.rf_mapping.reproducibility import set_seeds
import src.rf_mapping.constants as c


# Please specify some details here:
# set_seeds()
# The architecture is looked up by name on torchvision.models; other names
# used with this script: "vgg16", "resnet18".
model_name = 'alexnet'
model = getattr(models, model_name)(pretrained=True).to(c.DEVICE)

# True: fit only the first 5 units of each layer and write under a 'test' folder.
this_is_a_test_run = False
# True: use the '_random' ground-truth results (only affects 'gt'/'gt_composite').
is_random = False
# Mapping method to fit; must be one of the names load_maps() accepts.
map_name = 'block'
# Occlusion maps are blurred with sigma = rf_size * sigma_rf_ratio.
sigma_rf_ratio = 1/30

###############################################################################

# Script guard: ask for confirmation before the long-running fits begin.
# NOTE(review): only the prompt is guarded -- model loading above and the
# fitting loop below still execute if this module is imported.
if __name__ == "__main__":
    print("Look for a prompt.")
    reply = input("This code may take time to run. Are you sure? [y/n] ")
    if reply != 'y':
        raise KeyboardInterrupt("Interrupted by user")


# Per-conv-layer bookkeeping: layer indices, unit counts and RF sizes
# (rf_sizes[i][0] is used below as the RF size of conv layer i).
unit_counter = ConvUnitCounter(model)
layer_indices, nums_units = unit_counter.count()
_, rf_sizes = get_rf_sizes(model, (227, 227), layer_type=nn.Conv2d)

#############################  HELPER FUNCTIONS  ##############################

def write_txt(f, layer_name, unit_i, raw_params, explained_variance, map_size):
    """Append one space-separated line describing a unit's Gaussian fit to ``f``.

    Column order: layer_name, unit_i, mu_x, mu_y, sigma_1, sigma_2,
    orientation, amplitude, offset, explained variance.

    Before writing, the raw fit parameters are post-processed:
      * mu_x / mu_y are shifted by map_size / 2, moving the origin from the
        map's top-left corner to its center;
      * sigma_1 / sigma_2 are replaced by their absolute values;
      * theta is converted to an orientation with theta_to_ori(), which is
        given the absolute sigma values.
    """
    half_size = map_size / 2
    centered_x = raw_params[ParamFormat.MU_X_IDX] - half_size
    centered_y = raw_params[ParamFormat.MU_Y_IDX] - half_size
    sd_1 = abs(raw_params[ParamFormat.SIGMA_1_IDX])
    sd_2 = abs(raw_params[ParamFormat.SIGMA_2_IDX])
    orientation = theta_to_ori(sd_1, sd_2, raw_params[ParamFormat.THETA_IDX])

    columns = (
        f"{layer_name} {unit_i}",
        f"{centered_x:.2f} {centered_y:.2f}",
        f"{sd_1:.2f} {sd_2:.2f}",
        f"{orientation:.2f}",
        f"{raw_params[ParamFormat.A_IDX]:.3f} {raw_params[ParamFormat.OFFSET_IDX]:.3f}",
        f"{explained_variance:.4f}",
    )
    f.write(" ".join(columns) + "\n")


# File-name suffix of each "weighted" mapping method. All of them are stored as
# results/<map_name>/mapping/<model>/<layer>_weighted_<max|min>_<suffix>.npy
_WEIGHTED_MAP_SUFFIXES = {
    'rfmp4a': 'barmaps',
    'rfmp4c7o': 'barmaps',
    'rfmp_sin1': 'sinemaps',
    'pasu': 'shapemaps',
    'block': 'blockmaps',
}
# Weighted maps saved with an extra color axis ([unit, 3, yn, xn]); that axis
# is averaged out so every method returns [unit, yn, xn].
_COLOR_AXIS_MAPS = ('rfmp4c7o', 'block')


def _blur_maps(maps, rf_size):
    """Blur each unit's map in place with sigma = rf_size * sigma_rf_ratio and return it."""
    sigma = rf_size * sigma_rf_ratio
    # Filter unit by unit so neighbouring units never mix.
    for i in range(maps.shape[0]):
        maps[i] = gaussian_filter(maps[i], sigma=sigma)
    return maps


def load_maps(map_name, layer_name, max_or_min, is_random, rf_size):
    """Load the maps of one layer for the given mapping method.

    Reads the module-level ``model_name`` (path component) and, for the
    occlusion methods, ``sigma_rf_ratio`` (blur strength).

    Args:
        map_name: one of 'gt', 'gt_composite', 'occlude', 'occlude_composite',
            'rfmp4a', 'rfmp4c7o', 'rfmp_sin1', 'pasu', 'block'.
        layer_name: layer identifier used in the .npy file name (e.g. 'conv1').
        max_or_min: 'max' or 'min'. Ignored by 'gt_composite' and
            'occlude_composite', which always return the max maps plus the
            min maps.
        is_random: if True, read the '_random' ground-truth maps (only affects
            'gt' and 'gt_composite').
        rf_size: receptive-field size; occlusion maps are blurred with
            sigma = rf_size * sigma_rf_ratio.

    Returns:
        The maps as an array, [unit, yn, xn] according to the stored layout.

    Raises:
        KeyError: if ``map_name`` is not a known mapping method.
    """
    mapping_dir = os.path.join(c.REPO_DIR, 'results')

    if map_name in ('gt', 'gt_composite'):
        is_random_str = "_random" if is_random else ""
        gt_dir = os.path.join(mapping_dir,
                              'ground_truth',
                              f'backprop_sum{is_random_str}',
                              model_name,
                              'abs')
        if map_name == 'gt':
            return np.load(os.path.join(gt_dir, f"{layer_name}_{max_or_min}.npy"))
        max_maps = np.load(os.path.join(gt_dir, f"{layer_name}_max.npy"))
        min_maps = np.load(os.path.join(gt_dir, f"{layer_name}_min.npy"))
        return max_maps + min_maps  # [unit, yn, xn]

    if map_name in ('occlude', 'occlude_composite'):
        occlude_dir = os.path.join(mapping_dir, 'occlude', 'mapping', model_name)
        if map_name == 'occlude':
            maps = np.load(os.path.join(occlude_dir, f"{layer_name}_{max_or_min}.npy"))
        else:
            top_maps = np.load(os.path.join(occlude_dir, f"{layer_name}_max.npy"))
            bot_maps = np.load(os.path.join(occlude_dir, f"{layer_name}_min.npy"))
            maps = top_maps + bot_maps
        # Occlusion maps are smoothed before fitting.
        return _blur_maps(maps, rf_size)

    if map_name in _WEIGHTED_MAP_SUFFIXES:
        suffix = _WEIGHTED_MAP_SUFFIXES[map_name]
        mapping_path = os.path.join(mapping_dir,
                                    map_name,
                                    'mapping',
                                    model_name,
                                    f"{layer_name}_weighted_{max_or_min}_{suffix}.npy")
        maps = np.load(mapping_path)
        if map_name in _COLOR_AXIS_MAPS:
            maps = np.mean(maps, axis=1)  # [unit, 3, yn, xn] -> [unit, yn, xn]
        return maps

    raise KeyError(f"{map_name} does not exist.")


def get_result_dir(map_name, is_random, this_is_a_test_run):
    """Return the directory where Gaussian-fit results for ``map_name`` go.

    Ground-truth maps ('gt', 'gt_composite') use
    results/ground_truth/gaussian_fit[_random]/<model_name>/{abs|test}.
    Every other method uses results/<method>/gaussian_fit/{<model_name>|test};
    'occlude_composite' shares the 'occlude' folder. The module-level
    ``model_name`` is read for the path.

    Raises:
        KeyError: if ``map_name`` is not a known mapping method.
    """
    results_root = os.path.join(c.REPO_DIR, 'results')

    if map_name in ('gt', 'gt_composite'):
        random_suffix = "_random" if is_random else ""
        leaf = 'test' if this_is_a_test_run else 'abs'
        return os.path.join(results_root,
                            'ground_truth',
                            f'gaussian_fit{random_suffix}',
                            model_name,
                            leaf)

    other_methods = ('occlude', 'occlude_composite', 'rfmp4a', 'rfmp4c7o',
                     'rfmp_sin1', 'pasu', 'block')
    if map_name not in other_methods:
        raise KeyError(f"{map_name} does not exist.")

    # Composite occlusion fits live alongside the plain occlusion fits.
    method_folder = 'occlude' if map_name == 'occlude_composite' else map_name
    leaf = 'test' if this_is_a_test_run else model_name
    return os.path.join(results_root, method_folder, 'gaussian_fit', leaf)


###############################################################################

def _format_fit_title(label, fxvar, params, sems):
    """Build the multi-line subplot title listing each fitted parameter with its error (sems)."""
    def entry(name, idx):
        return f"{name}={params[idx]:.2f}(err={sems[idx]:.2f})"

    return (f"{label} (fxvar = {fxvar:.4f})\n"
            f"{entry('A', ParamFormat.A_IDX)}, "
            f"{entry('mu_x', ParamFormat.MU_X_IDX)}, "
            f"{entry('mu_y', ParamFormat.MU_Y_IDX)},\n"
            f"{entry('sigma_1', ParamFormat.SIGMA_1_IDX)}, "
            f"{entry('sigma_2', ParamFormat.SIGMA_2_IDX)},\n"
            f"{entry('theta', ParamFormat.THETA_IDX)}, "
            f"{entry('offset', ParamFormat.OFFSET_IDX)}")


def _fit_unit_panel(unit_map, label, txt_path, layer_name, unit_i, rf_size):
    """Fit one unit's map, append the fit to ``txt_path``, and title the current subplot."""
    # gaussian_fit(plot=True) presumably draws into the currently active axes,
    # so the caller selects the subplot before calling this helper.
    params, sems = gaussian_fit(unit_map, plot=True, show=False)
    fxvar = calc_f_explained_var(unit_map, params)
    with open(txt_path, 'a') as txt_f:
        write_txt(txt_f, layer_name, unit_i, params, fxvar, rf_size)
    plt.title(_format_fit_title(label, fxvar, params, sems), fontsize=14)


# NOTE(review): everything below runs at import time as well; the
# `__main__` guard near the top of the file only protects the prompt.
result_dir = get_result_dir(map_name, is_random, this_is_a_test_run)
# The result folder may not exist yet; open(..., 'a') and PdfPages would
# otherwise fail with FileNotFoundError.
os.makedirs(result_dir, exist_ok=True)

# Composite maps (max + min combined) get one fit per unit; all other methods
# get separate fits for the max and the min maps.
is_composite = map_name in ('gt_composite', 'occlude_composite')

if is_composite:
    composite_file_path = os.path.join(result_dir, f"{model_name}_{map_name}_gaussian.txt")
    txt_paths = (composite_file_path,)
else:
    top_file_path = os.path.join(result_dir, f"{model_name}_{map_name}_gaussian_top.txt")
    bot_file_path = os.path.join(result_dir, f"{model_name}_{map_name}_gaussian_bot.txt")
    txt_paths = (top_file_path, bot_file_path)

# Delete previous files, because results are appended unit by unit.
for txt_path in txt_paths:
    if os.path.exists(txt_path):
        os.remove(txt_path)

for conv_i in range(len(layer_indices)):
    layer_name = f"conv{conv_i + 1}"
    print(f"Fitting elliptical Gaussian for {layer_name}...")
    rf_size = rf_sizes[conv_i][0]

    if is_composite:
        # load_maps ignores max_or_min for composite maps, so load them once.
        composite_maps = load_maps(map_name, layer_name, 'max', is_random, rf_size)
        panels = [(composite_maps, 'composite', composite_file_path)]
        # Composite and max/min fits share result_dir (see get_result_dir);
        # a distinct PDF name keeps one run from overwriting the other's plots.
        pdf_name = f"{layer_name}_composite_gaussian_fit.pdf"
        figsize = (8, 10)
    else:
        max_maps = load_maps(map_name, layer_name, 'max', is_random, rf_size)
        min_maps = load_maps(map_name, layer_name, 'min', is_random, rf_size)
        panels = [(max_maps, 'max', top_file_path),
                  (min_maps, 'min', bot_file_path)]
        pdf_name = f"{layer_name}_gaussian_fit.pdf"
        figsize = (20, 10)

    # Like zip(): fit only as many units as every map stack provides.
    num_units_to_fit = min(len(maps) for maps, _, _ in panels)
    if this_is_a_test_run:
        # Do only the first 5 units during the testing phase.
        num_units_to_fit = min(num_units_to_fit, 5)

    pdf_path = os.path.join(result_dir, pdf_name)
    with PdfPages(pdf_path) as pdf:
        for unit_i in tqdm(range(num_units_to_fit)):
            # One page per unit: one subplot per panel (composite, or max | min).
            plt.figure(figsize=figsize)
            plt.suptitle(f"Elliptical Gaussian fit ({layer_name} no.{unit_i})", fontsize=20)
            for panel_i, (maps, label, txt_path) in enumerate(panels):
                plt.subplot(1, len(panels), panel_i + 1)
                _fit_unit_panel(maps[unit_i], label, txt_path, layer_name, unit_i, rf_size)
            pdf.savefig()
            plt.close()