borderownership / src / rf_mapping / mapping.py
mapping.py
Raw
"""
Receptive field mapping paradigms.

Tony Fu, July 8, 2022
"""
import os
import sys
import copy
import math

import numpy as np
import torch
import torch.nn as nn
from torchvision import models
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

sys.path.append('../..')
from src.rf_mapping.spatial import (get_conv_output_shapes,
                                    calculate_center,
                                    get_rf_sizes,
                                    RfGrid,
                                    SpatialIndexConverter,)
from src.rf_mapping.image import make_box
from src.rf_mapping.hook import ConvUnitCounter
from src.rf_mapping.bar import stimfr_bar, stimset_dict_rfmp_4a
import src.rf_mapping.constants as c
from src.rf_mapping.net import get_truncated_model


#######################################.#######################################
#                                                                             #
#                                RF Mapper Base                               #
#                                                                             #
###############################################################################
class RfMapper:
    def __init__(self, model, conv_i, image_shape):
        """
        The base class of receptive fields mapping of a single convolutional
        layer. Initializes the information needed about the conv layer.

        Parameters
        ----------
        model : torchvision.Module
            The pretrained neural network.
        conv_i : int
            The index of the convolutional layer starting from zero. That is,
            Conv1 should be 0.
        image_shape : (int, int)
            The dimension of the image in (yn, xn) format (pix).
        """
        # _get_num_units() and _get_output_shape() both read self.model, so
        # the model has to be stored before they are called below.
        self.model = model
        self.conv_i = conv_i
        self.yn, self.xn = image_shape
        print(f"The RF mapper is for Conv{conv_i + 1} (not Conv{conv_i}) "
              f"with input shape (yn = {self.yn}, xn = {self.xn}).")

        # Get basic info about the conv layer.
        layer_indices, rf_sizes = get_rf_sizes(model, (self.yn, self.xn),
                                               layer_type=nn.Conv2d)
        self.rf_size = rf_sizes[conv_i][0]
        self.num_units = self._get_num_units()
        self.layer_idx = layer_indices[conv_i]  # See hook.py module for
                                                # indexing convention.
        self.truncated_model = get_truncated_model(model, self.layer_idx)

        # Locate spatial center of the layer's output.
        self.output_shape = self._get_output_shape()
        self.output_yc, self.output_xc = calculate_center(self.output_shape)

        # Locate the RF of the center unit in pixel space. The box is unpacked
        # as (y_min, x_min, y_max, x_max) by _get_box_center().
        self.converter = SpatialIndexConverter(model, (self.yn, self.xn))
        self.box = self.converter.convert((self.output_yc, self.output_xc),
                                          self.layer_idx, 0, is_forward=False)
        self.box_yc, self.box_xc = self._get_box_center(self.box)

    def _get_num_units(self):
        """Finds how many unique units/kernels are in this layer."""
        unit_counter = ConvUnitCounter(self.model)
        _, nums_units = unit_counter.count()
        return nums_units[self.conv_i]

    def _get_output_shape(self):
        """Finds the spatial shape of this layer's output in (yn, xn) format."""
        conv_output_shapes = get_conv_output_shapes(self.model,
                                                    (self.yn, self.xn))
        # Keep only the last two (spatial) dimensions.
        return np.array(conv_output_shapes[self.conv_i][-2:])

    def _get_box_center(self, box):
        """
        Finds the center coordinates of the box in (y, x) format.

        The box is (y_min, x_min, y_max, x_max). The center is computed with
        floor division, so integer inputs give integer pixel indices.
        """
        y_min, x_min, y_max, x_max = box
        yc = (y_min + y_max) // 2
        xc = (x_min + x_max) // 2
        return yc, xc

    def _print_progress(self, progess, pre_text='progress = ', post_text=''):
        """
        Prints progress (whatever quantity) on the same console line. It
        writes a carriage return rather than a newline, so each call
        overwrites the previous message.

        The parameter keeps its original (misspelled) name `progess` for
        compatibility with keyword callers.
        """
        sys.stdout.write('\r')
        sys.stdout.write(f"{pre_text}{progess}{post_text}")
        sys.stdout.flush()
    

#######################################.#######################################
#                                                                             #
#                                BAR RF MAPPER                                #
#                                                                             #
###############################################################################
class BarRfMapper(RfMapper):
    def __init__(self, model, conv_i, image_shape):
        """
        The class of receptive fields mapping of a single convolutional layer
        using bar stimuli. For each unit (a.k.a. kernel), only the response of
        the spatial center is recorded.

        Parameters
        ----------
        model : torchvision.Module
            The neural network.
        conv_i : int
            The index of the convolutional layer starting from zero. That is,
            Conv1 should be 0.
        image_shape : (int, int)
            The dimension of the image in (yn, xn) format (pix).
        """
        super().__init__(model, conv_i, image_shape)
        self.cumulate_threshold = None

    def _bar_full_image(self, bar, bgval):
        """
        Embeds a bar image drawn over the RF box into a full (yn, xn) image
        filled with the background value bgval.
        """
        # Force a float canvas. With an integer bgval, np.full would create an
        # integer array and silently truncate the bar's fractional
        # (anti-aliased) pixel values on assignment.
        bar_full = np.full((self.yn, self.xn), bgval, dtype=float)
        y_min, x_min, y_max, x_max = self.box
        bar_full[y_min:y_max+1, x_min:x_max+1] = bar
        return bar_full

    def _bar_to_tensor(self, bar):
        """
        Converts a batch of grayscale images [batch, yn, xn] into a float32
        tensor [batch, 3, yn, xn] by copying each image into all three color
        channels.
        """
        gray = np.asarray(bar, dtype=float)
        bar_3_chan = np.repeat(gray[:, np.newaxis, :, :], 3, axis=1)
        return torch.tensor(bar_3_chan).type('torch.FloatTensor')

    def _get_center_responses(self, input):
        """
        Gets the responses of the spatial centers of all bars and units.

        Parameters
        ----------
        input : numpy.array
            A batch of grayscale images in [batch, yn, xn] format.

        Returns
        -------
        numpy.array
            The responses at (output_yc, output_xc) in [batch, unit] format.
        """
        input_tensor = self._bar_to_tensor(input)
        # Inference only: avoid building an autograd graph for every batch.
        with torch.no_grad():
            y = self.truncated_model(input_tensor)
        return y[:, :, self.output_yc, self.output_xc].cpu().detach().numpy()

    def map(self):
        raise NotImplementedError("The map method is not implemented.")
    

#######################################.#######################################
#                                                                             #
#                            RF MAPPING PARADIGM 4a                           #
#      Get the top N most activating bars, take the absolute value, and       #
#         add them up or apply pixel-wise OR to the cumulative map.           #
#                                                                             #
###############################################################################
class BarRfMapperP4a(BarRfMapper):
    def __init__(self, model, conv_i, image_shape, percent_max_min_to_cumulate=0.1):
        """
        Receptive field mapper for paradigm 4a. For every unit it selects the
        most (and least) activating bars and accumulates them into cumulative
        maps. A map is either a sum weighted by the absolute response or an
        "OR" map built from bars that do not overlap earlier ones.

        Parameters
        ----------
        model : torchvision.Module
            The neural network.
        conv_i : int
            The index of the convolutional layer starting from zero. That is,
            Conv1 should be 0.
        image_shape : (int, int)
            The dimension of the image in (yn, xn) format (pix).
        percent_max_min_to_cumulate : float
            Fraction of a unit's max (min) response that a bar's response must
            reach for the bar to be included in the max (min) maps. See
            _sort_responses().
        """
        super().__init__(model, conv_i, image_shape)

        # Bar parameters
        self.stim_dicts = stimset_dict_rfmp_4a(self.rf_size, self.rf_size)
        self.num_stim = len(self.stim_dicts)

        # Mapping parameters
        self.percent_max_min_to_cumulate = percent_max_min_to_cumulate
        self.batch_size = 100
        # Pixel threshold used by the OR maps in _make_maps().
        self.bar_thres = 0.2

        # Debugging parameters
        self.DEBUG = False
        self.DEBUG_NUM_UNITS = 10

        # Set by self._present_and_record():
        self.center_responses = None  # [stim, unit]

        # Set by self._sort_responses():
        self.max_bar_indices = None  # [unit, bar_indices]
        self.min_bar_indices = None  # [unit, bar_indices]

        # Set by self._make_maps(); each is [unit, rf_size, rf_size]:
        self.max_weighted_bar_sum = None
        self.min_weighted_bar_sum = None
        self.max_or_bar_sum = None
        self.min_or_bar_sum = None

    def set_debug(self, debug):
        """
        If debug is set to True, the number of stimuli and units processed is
        reduced significantly so that map() runs much faster. Each presented
        bar is also plotted. This mode only exists to assess whether the bar
        mapper works; its results should not be used.
        """
        self.DEBUG = debug

    def _present_and_record(self):
        """
        Presents all bars in batches and records every unit's response at the
        spatial center of the layer output in self.center_responses.
        """
        self.center_responses = np.zeros((self.num_stim, self.num_units))

        bar_i = 0
        while bar_i < self.num_stim:
            # In debug mode only roughly the first thousand stimuli are shown.
            if self.DEBUG and bar_i > 1000:
                break
            real_batch_size = min(self.batch_size, self.num_stim - bar_i)
            new_bars = np.zeros((real_batch_size, self.yn, self.xn))
            for i in range(real_batch_size):
                params = self.stim_dicts[bar_i + i]

                # Draw the bar over the RF box and embed it in a full image.
                new_bar = stimfr_bar(params['xn'], params['yn'], params['x0'], params['y0'],
                                     params['theta'], params['len'], params['wid'],
                                     params['aa'], params['fgval'], params['bgval'])
                new_bars[i] = self._bar_full_image(new_bar, params['bgval'])

                if self.DEBUG:
                    plt.imshow(new_bars[i], cmap='gray')
                    boundary = 10
                    plt.xlim([self.box[1] - boundary, self.box[3] + boundary])
                    plt.ylim([self.box[0] - boundary, self.box[2] + boundary])
                    rect = make_box(self.box, linewidth=2)
                    ax = plt.gca()
                    ax.add_patch(rect)
                    ax.invert_yaxis()
                    plt.show()

            self.center_responses[bar_i:bar_i+real_batch_size, :] = self._get_center_responses(new_bars)
            self._print_progress(bar_i, pre_text="Presenting ", post_text=" stimuli...")
            bar_i += real_batch_size

    def _sort_responses(self):
        """
        After presenting the bars, call this function. For each unit it stores
        the indices (into self.stim_dicts) of the bars to include in the max
        and min maps. The max indices are ordered by descending response and
        the min indices by ascending response.
        """
        # Clear existing elements in the lists.
        self.max_bar_indices = []
        self.min_bar_indices = []

        for unit_i in range(self.num_units):
            if self.DEBUG and unit_i > self.DEBUG_NUM_UNITS:
                break
            unit_responses = self.center_responses[:, unit_i].copy()

            # Get the max and min of the unit (over all bars).
            unit_max_response = unit_responses.max()
            unit_min_response = unit_responses.min()

            # Max threshold: include every bar whose response is at least this
            # fraction of the unit's max response (clamped to be >= 0).
            max_threshold = max(0, self.percent_max_min_to_cumulate * unit_max_response)
            num_max_units = np.count_nonzero(unit_responses >= max_threshold)

            # Min threshold: include every bar whose response is at most this
            # fraction of the unit's min response (clamped to be <= 0).
            min_threshold = min(0, self.percent_max_min_to_cumulate * unit_min_response)
            num_min_units = np.count_nonzero(unit_responses <= min_threshold)

            sorted_bar_index = unit_responses.argsort(axis=None)  # Ascending
            self.max_bar_indices.append(sorted_bar_index[::-1][:num_max_units])
            self.min_bar_indices.append(sorted_bar_index[:num_min_units])

            if self.DEBUG:
                print(f"unit {unit_i}, unit_responses.shape: {unit_responses.shape}")
                print(f"unit_max_response: {unit_max_response}, max_threshold: {max_threshold}, num_max_units: {num_max_units}")
                print(f"unit_min_response: {unit_min_response}, min_threshold: {min_threshold}, num_min_units: {num_min_units}")
            if len(self.max_bar_indices[-1]) != 0:
                print(f"unit {unit_i}, max_ranking: {self.max_bar_indices[-1][:5]}, "
                      f"r_max = {unit_responses[self.max_bar_indices[-1][0]]:.4f}")

    def index_to_params(self, index):
        """Given a bar index, returns the corresponding bar parameters."""
        return self.stim_dicts[index]

    def _accumulate_bars(self, bar_indices, unit_i, weighted_sum, or_sum):
        """Adds the given bars of one unit to its weighted and OR maps in place."""
        for bar_index in bar_indices:
            params = self.index_to_params(bar_index)
            # Maps always use white-on-zero bars (aa=0, fgval=1, bgval=0) so
            # that bars of opposite contrast do not cancel each other out.
            new_bar = stimfr_bar(params['xn'], params['yn'], params['x0'], params['y0'],
                                 params['theta'], params['len'], params['wid'],
                                 0, 1, 0)

            # Weighted sum: weight each bar by the magnitude of the response.
            response = self.center_responses[bar_index, unit_i]
            weighted_sum[unit_i] += new_bar * abs(response)

            # OR sum: only add a bar if none of its lit pixels land on a pixel
            # that is already above bar_thres in the map.
            if not np.any(np.logical_and(or_sum[unit_i] > self.bar_thres, new_bar > 0)):
                new_bar[new_bar < self.bar_thres] = 0
                or_sum[unit_i] += new_bar

    def _make_maps(self):
        """Builds all four cumulative maps (weighted/OR x max/min) for all units."""
        map_shape = (self.num_units, self.rf_size, self.rf_size)
        self.max_weighted_bar_sum = np.zeros(map_shape)
        self.min_weighted_bar_sum = np.zeros(map_shape)
        self.max_or_bar_sum = np.zeros(map_shape)
        self.min_or_bar_sum = np.zeros(map_shape)

        for unit_i in range(self.num_units):
            if self.DEBUG and unit_i > self.DEBUG_NUM_UNITS:
                break
            self._print_progress(unit_i, pre_text="Making maps for unit no.", post_text="...")

            self._accumulate_bars(self.max_bar_indices[unit_i], unit_i,
                                  self.max_weighted_bar_sum, self.max_or_bar_sum)
            self._accumulate_bars(self.min_bar_indices[unit_i], unit_i,
                                  self.min_weighted_bar_sum, self.min_or_bar_sum)

    def map(self):
        """
        Apply receptive field mapping paradigm 4a.

        Returns
        -------
        max_weighted_bar_sum : numpy.array
            The weighted sum of the top bars in dimension [unit, rf_size, rf_size].
        min_weighted_bar_sum : numpy.array
            The weighted sum of the bottom bars in dimension [unit, rf_size, rf_size].
        max_or_bar_sum : numpy.array
            The OR sum of the top bars in dimension [unit, rf_size, rf_size].
        min_or_bar_sum : numpy.array
            The OR sum of the bottom bars in dimension [unit, rf_size, rf_size].
        """
        self._present_and_record()
        self._sort_responses()
        self._make_maps()

        return self.max_weighted_bar_sum, self.min_weighted_bar_sum,\
               self.max_or_bar_sum, self.min_or_bar_sum

    def save_maps(self, map_dir):
        """
        Saves the maps as npy files under map_dir/weighted and map_dir/or,
        creating those subdirectories if needed.
        """
        maps_by_mode = (('weighted', self.max_weighted_bar_sum, self.min_weighted_bar_sum),
                        ('or', self.max_or_bar_sum, self.min_or_bar_sum))
        for mode, max_map, min_map in maps_by_mode:
            mode_dir = os.path.join(map_dir, mode)
            os.makedirs(mode_dir, exist_ok=True)
            np.save(os.path.join(mode_dir, f"conv{self.conv_i+1}_max_maps.npy"), max_map)
            np.save(os.path.join(mode_dir, f"conv{self.conv_i+1}_min_maps.npy"), min_map)

    def plot_one_unit(self, cumulate_mode, unit):
        """
        Plots the max, min, and both (max + min) cumulative bar maps of one
        unit.

        Raises
        ------
        ValueError
            If cumulate_mode is neither 'weighted' nor 'or'.
        """
        if cumulate_mode == 'weighted':
            max_bar_sum = self.max_weighted_bar_sum
            min_bar_sum = self.min_weighted_bar_sum
        elif cumulate_mode == 'or':
            max_bar_sum = self.max_or_bar_sum
            min_bar_sum = self.min_or_bar_sum
        else:
            raise ValueError(f"cumulate_mode: {cumulate_mode} is not supported.")

        plt.figure(figsize=(15, 5))
        plt.suptitle(f"Cumulative map with bars (conv{self.conv_i+1}, no.{unit}, cumulate mode = {cumulate_mode})", fontsize=20)

        plt.subplot(1, 3, 1)
        plt.imshow(max_bar_sum[unit], cmap='gray')
        plt.title("max", fontsize=16)

        plt.subplot(1, 3, 2)
        plt.imshow(min_bar_sum[unit], cmap='gray')
        plt.title("min", fontsize=16)

        plt.subplot(1, 3, 3)
        both_map = (max_bar_sum[unit] + min_bar_sum[unit])/2
        if cumulate_mode == 'or':
            # Binarize the combined OR map.
            both_map[both_map > 0] = 1
        plt.imshow(both_map, cmap='gray')
        plt.title("max + min", fontsize=16)

    def make_pdf(self, pdf_path, cumulate_mode, show=False):
        """
        Makes a pdf, with each page printing the cumulative bar maps of a unit.

        Parameters
        ----------
        pdf_path : str
            The file path (must end with .pdf) of the pdf file.
        cumulate_mode : str
            Either 'weighted' or 'or'.
        show : bool
            If True, show the plots as they are printed to the pdf.
        """
        with PdfPages(pdf_path) as pdf:
            for unit_i in range(self.num_units):
                if self.DEBUG and unit_i > self.DEBUG_NUM_UNITS:
                    break
                self.plot_one_unit(cumulate_mode, unit_i)
                if show:
                    plt.show()
                pdf.savefig()
                plt.close()


#######################################.#######################################
#                                                                             #
#                            RF MAPPING PARADIGM z                            #
#      Summing all black and white bars, weighted by rectified responses.     #
#                 Incorrect implementation of paradigm 4a.                    #
#                 Can use it for Conv1 for cool animation.                    #
#                                                                             #
###############################################################################
class BarRfMapperPz(BarRfMapper):
    def __init__(self, model, conv_i, image_shape):
        """
        Bar RF mapper for paradigm z. It sums black and white bars presented
        on a grid, weighted by the rectified center responses. This is an
        incorrect implementation of paradigm 4a, kept because it makes nice
        animations for Conv1.

        Parameters
        ----------
        model : torchvision.Module
            The neural network.
        conv_i : int
            The index of the convolutional layer starting from zero.
        image_shape : (int, int)
            The dimension of the image in (yn, xn) format (pix).
        """
        super().__init__(model, conv_i, image_shape)

        # Bar parameters
        self.rf_blen_ratios = [3/4, 3/8, 3/16, 3/32]
        self.rf_blen_ratio_strs = ['3/4', '3/8', '3/16', '3/32']
        self.aspect_ratios = [1/2, 1/5, 1/10]
        self.thetas = np.arange(0, 180, 22.5)
        self.fgval_bgval_pairs = [(1, -1), (-1, 1)]
        self.laa = 0.5  # anti-alias distance

        # Mapping parameters
        self.cumulate_threshold = 1
        self.DEBUG = False

        # Array initializations.
        # all_responses: [unit, blen, aspect ratio, theta, contrast pair].
        self.all_responses = np.zeros((self.num_units,
                                       len(self.rf_blen_ratios),
                                       len(self.aspect_ratios),
                                       len(self.thetas),
                                       len(self.fgval_bgval_pairs)))
        self.weighted_bar_sum = np.zeros((self.num_units, self.yn, self.xn))
        self.threshold_bar_sum = np.zeros((self.num_units, self.yn, self.xn))
        self.center_only_bar_sum = np.zeros((self.num_units, self.yn, self.xn))

    def _weighted_cumulate(self, new_bar, bar_sum, unit, response):
        """
        Adds the new_bar, weighted by the unit's response to that bar, to the
        cumulative bar map.

        Parameters
        ----------
        new_bar : numpy.array
            The new bar.
        bar_sum : numpy.array
            The cumulated weighted sum of all previous bars. This is modified
            in-place.
        unit : int
            The unit's number.
        response : float
            The unit's response (spatial center only) to the new bar.
        """
        bar_sum[unit, :, :] += new_bar * response

    def _threshold_cumulate(self, new_bar, bar_sum, unit, response):
        """
        Adds the new_bar (unweighted) to the cumulative map only if the
        response exceeds self.cumulate_threshold.

        Parameters
        ----------
        See _weighted_cumulate().
        """
        if response > self.cumulate_threshold:
            bar_sum[unit, :, :] += new_bar

    def _center_only_cumulate(self, bar_sum, unit, response):
        """
        Adds the response at the RF box center (box_yc, box_xc) of the
        cumulative map if the response exceeds self.cumulate_threshold.

        Parameters
        ----------
        See _weighted_cumulate() for repeated parameters.
        bar_sum : numpy.array
            The cumulated weighted sum of all previous bar centers. This is
            modified in-place.
        """
        if response > self.cumulate_threshold:
            bar_sum[unit, self.box_yc, self.box_xc] += response

    def _update_bar_sums(self, new_bar, responses):
        """Updates all three cumulative maps at once for all units."""
        for unit in range(self.num_units):
            self._weighted_cumulate(new_bar, self.weighted_bar_sum, unit, responses[unit])
            self._threshold_cumulate(new_bar, self.threshold_bar_sum, unit, responses[unit])
            self._center_only_cumulate(self.center_only_bar_sum, unit, responses[unit])

    def set_debug(self, debug):
        """
        If debug is set to True, only a few grid locations are used per bar
        configuration. This lets map() run much faster and is only meant to
        assess whether the bar mapper works; its results should not be used.
        """
        self.DEBUG = debug

    def _map(self, animation=False, unit=None, cumulate_mode=None, bar_sum=None):
        """
        Generator that presents every bar configuration at every grid location.

        If animation is False, all three cumulative maps are updated for every
        unit and nothing is yielded, but the generator must still be exhausted
        (see map()). If animation is True, only bar_sum[unit] is updated using
        cumulate_mode. After each bar, it yields (bar_sum[unit], response,
        num_stimuli, new_bar, mean theta tuning so far).

        NOTE(review): self.grid_calculator is never assigned and draw_bar is
        never imported in this module, so this method cannot run until both
        are provided.
        """
        num_stimuli = 0
        for blen_i, rf_blen_ratio in enumerate(self.rf_blen_ratios):
            for bwid_i, aspect_ratio in enumerate(self.aspect_ratios):
                for theta_i, theta in enumerate(self.thetas):
                    for val_i, (fgval, bgval) in enumerate(self.fgval_bgval_pairs):
                        # Some bar parameters
                        blen = round(rf_blen_ratio * self.rf_size)
                        bwid = round(aspect_ratio * blen)
                        grid_spacing = blen/2

                        # Get grid coordinates.
                        grid_coords = self.grid_calculator.get_grid_coords(self.layer_idx, (self.output_yc, self.output_xc), grid_spacing)
                        grid_coords_np = np.array(grid_coords)

                        # Create bars.
                        for grid_coord_i, (xc, yc) in enumerate(grid_coords_np):
                            if self.DEBUG and grid_coord_i > 10:
                                break

                            new_bar = draw_bar(self.xn, self.yn, xc, yc, theta, blen, bwid, self.laa, fgval, bgval)
                            # _get_center_responses() expects a batch
                            # [bar, yn, xn]. Present this bar as a batch of one
                            # and keep its [unit] response vector.
                            center_responses = self._get_center_responses(np.expand_dims(new_bar, 0))[0]
                            center_responses[center_responses < 0] = 0  # ReLU
                            self.all_responses[:, blen_i, bwid_i, theta_i, val_i] += center_responses

                            num_stimuli += 1

                            if not animation:
                                self._print_progress(num_stimuli)
                                self._update_bar_sums(new_bar, center_responses)
                            else:
                                if cumulate_mode == 'weighted':
                                    self._weighted_cumulate(new_bar, bar_sum, unit, center_responses[unit])
                                elif cumulate_mode == 'threshold':
                                    self._threshold_cumulate(new_bar, bar_sum, unit, center_responses[unit])
                                elif cumulate_mode == 'center_only':
                                    self._center_only_cumulate(bar_sum, unit, center_responses[unit])
                                else:
                                    raise ValueError(f"cumulate_mode: {cumulate_mode} is not supported.")
                                yield bar_sum[unit], center_responses[unit], num_stimuli, new_bar, np.mean(self.all_responses[unit,...], axis=(0,1,3))

    def map(self):
        """Presents all bars, fills the cumulative maps, and returns all_responses."""
        # _map() contains a `yield`, so calling it only creates a generator.
        # It has to be iterated for any stimulus to actually be presented.
        for _ in self._map(animation=False):
            pass
        return self.all_responses

    def animate(self, unit, cumulate_mode='weighted'):
        """Returns a generator that yields one animation frame per bar for a unit."""
        bar_sum = np.zeros((self.num_units, self.yn, self.xn))
        return self._map(animation=True, unit=unit, cumulate_mode=cumulate_mode, bar_sum=bar_sum)

    def _tuning_mean_sem(self, unit, axes):
        """
        Returns the mean and the standard error of the mean of one unit's
        responses, averaging all_responses[unit] over the given axes.
        """
        responses = self.all_responses[unit, ...]
        num_samples = int(np.prod([responses.shape[axis] for axis in axes]))
        mean = np.mean(responses, axis=axes)
        sem = np.std(responses, axis=axes) / math.sqrt(num_samples)
        return mean, sem

    def plot_one_unit(self, cumulate_mode, unit):
        """
        Plots one unit's cumulative bar map and its bar length, width, theta
        and contrast tuning curves (mean +/- SEM).

        Raises
        ------
        ValueError
            If cumulate_mode is not 'weighted', 'threshold' or 'center_only'.
        """
        if cumulate_mode == 'weighted':
            bar_sum = self.weighted_bar_sum
        elif cumulate_mode == 'threshold':
            bar_sum = self.threshold_bar_sum
        elif cumulate_mode == 'center_only':
            bar_sum = self.center_only_bar_sum
        else:
            raise ValueError(f"cumulate_mode: {cumulate_mode} is not supported.")

        plt.figure(figsize=(25, 5))
        plt.suptitle(f"RF mapping with bars no.{unit}", fontsize=20)

        plt.subplot(1, 5, 1)
        plt.imshow(bar_sum[unit, :, :], cmap='gray')
        plt.title("Cumulated bar maps")
        boundary = 10
        plt.xlim([self.box[1] - boundary, self.box[3] + boundary])
        plt.ylim([self.box[0] - boundary, self.box[2] + boundary])
        rect = make_box(self.box, linewidth=2)
        ax = plt.gca()
        ax.add_patch(rect)
        ax.invert_yaxis()

        plt.subplot(1, 5, 2)
        blen_tuning, blen_sem = self._tuning_mean_sem(unit, (1, 2, 3))
        plt.errorbar(self.rf_blen_ratios, blen_tuning, yerr=blen_sem)
        plt.title("Bar length tuning")
        plt.xlabel("blen/RF ratio")
        plt.ylabel("avg response")
        plt.grid()

        plt.subplot(1, 5, 3)
        bwid_tuning, bwid_sem = self._tuning_mean_sem(unit, (0, 2, 3))
        plt.errorbar(self.aspect_ratios, bwid_tuning, yerr=bwid_sem)
        plt.title("Bar width tuning")
        plt.xlabel("aspect ratio")
        plt.ylabel("avg response")
        plt.grid()

        plt.subplot(1, 5, 4)
        theta_tuning, theta_sem = self._tuning_mean_sem(unit, (0, 1, 3))
        plt.errorbar(self.thetas, theta_tuning, yerr=theta_sem)
        plt.title("Theta tuning")
        plt.xlabel("theta")
        plt.ylabel("avg response")
        plt.grid()

        plt.subplot(1, 5, 5)
        val_tuning, val_sem = self._tuning_mean_sem(unit, (0, 1, 2))
        plt.bar(['white on black', 'black on white'], val_tuning, yerr=val_sem, width=0.4)
        plt.title("Contrast tuning")
        plt.ylabel("avg response")
        plt.grid()

    def make_pdf(self, pdf_path, cumulate_mode, show=False):
        """Makes a pdf with one page of plot_one_unit() output per unit."""
        with PdfPages(pdf_path) as pdf:
            for unit in range(self.num_units):
                self.plot_one_unit(cumulate_mode, unit)
                if show:
                    plt.show()
                pdf.savefig()
                plt.close()