# borderownership/src/rf_mapping/images/image_reconstruction_script.py
"""
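Load one image, pass it through a model truncated at a chosen layer, and
plot the feature map of a single unit in that layer.

TODO: Expand this description as the script is completed.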

Tony Fu, August 29 2022
"""
import os
import sys

import numpy as np
from torchvision import models
# from torchvision.models import AlexNet_Weights, VGG16_Weights
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

sys.path.append('../../..')
from src.rf_mapping.hook import ConvUnitCounter
from src.rf_mapping.image import preprocess_img_for_plot, make_box, preprocess_img_to_tensor
from src.rf_mapping.spatial import SpatialIndexConverter
from src.rf_mapping.guided_backprop import GuidedBackprop
import src.rf_mapping.constants as c
from src.rf_mapping.net import get_truncated_model

# ---------------------------------------------------------------------------
# User settings: pick the network and the image/layer/unit to inspect.
# ---------------------------------------------------------------------------
model_name = 'alexnet'
model = models.alexnet()
model = model.to(c.DEVICE)
# Alternative networks (change the constructor and model_name together):
# model = models.vgg16().to(c.DEVICE)
# model_name = "vgg16"
# model = models.resnet18().to(c.DEVICE)
# model_name = "resnet18"

image_size = (227, 227)  # given to SpatialIndexConverter together with model
img_idx = 44011          # the image is read from f"{img_idx}.npy" in img_dir
layer_idx = 8            # layer index handed to get_truncated_model
unit_idx = 0             # unit whose feature map is plotted
pos_idx = 100            # NOTE(review): not used in the visible part of the script

# ---------------------------------------------------------------------------
# Directories: verify these before running.
# ---------------------------------------------------------------------------
img_dir = c.IMG_DIR
rank_dir = os.path.join(
    c.REPO_DIR,
    'results',
    'ground_truth',
    'top_n',
    model_name,
)
result_dir = rank_dir  # results share the directory of the rank files

# Get info of conv layers.
# NOTE(review): SpatialIndexConverter and ConvUnitCounter are project helpers
# that both receive the model object; their internals are not visible here,
# so the construction order is deliberately left untouched.
converter = SpatialIndexConverter(model, image_size)
unit_counter = ConvUnitCounter(model)
# count() returns two values, unpacked as layer_indices and nums_units
# (presumably the conv-layer indices and the unit count of each layer,
# as parallel sequences -- TODO confirm against ConvUnitCounter).
layer_indices, nums_units = unit_counter.count()
num_layers = len(layer_indices)

# Load the selected image, stored as "<img_idx>.npy" inside img_dir, and
# convert it with the project helper into the tensor fed to the model.
img = np.load(os.path.join(img_dir, f"{img_idx}.npy"))
img_tensor = preprocess_img_to_tensor(img)

# Run the image through the model truncated at layer_idx.
truncated_model = get_truncated_model(model, layer_idx)
y = truncated_model(img_tensor)

# Pick out the response of unit_idx for the first item of the output
# (assumes y is laid out as (batch, unit, height, width) -- TODO confirm).
# The tensor is moved to host memory before the NumPy conversion because the
# model lives on c.DEVICE: Tensor.numpy() raises a TypeError for a tensor on
# a CUDA device, while .cpu() is a no-op for a tensor that is already on CPU.
feature_map = y[0, unit_idx].detach().cpu().numpy()
plt.imshow(feature_map)