borderownership / src / not_used / layer_finder.py
layer_finder.py
Raw
"""
Code to find the natural image patches that drive the units the most/least.

Tony Fu, June 18, 2022
"""
import numpy as np
import torch
import torch.nn as nn
from torchvision import models


def layer_indices(model, layer_type=nn.Conv2d):
    """
    Recursively find all layers of the type <layer_type> in a given model, then
    returns their indices of the model. All Pytorch neural networks are tree
    data structures: the "model" is the root node, the container layers such as
    Sequential layers are the intermediate nodes, and layers like Conv2d, ReLU,
    MaxPool2d, and Linear are leave nodes.

    Parameters
    ----------
    model: torchvision.models or torch.nn.Module
        The neural network model or layer object.
    layer_type: type
        The target type of the leave node, e.g., nn.Conv2d, nn.ReLU, etc.

    Returns
    -------
    indices: list of lists
        Each sublist is a sequence of child indices from the "model" root node
        to a target layer. For example, if the indices returned for layer_type
        nn.Conv2d is [0, [0, 3], 1, []], then the first nn.Conv2d layer can be
        accessed with the code:
            list(list(model.children())[0].children())[0]
        And the second nn.Conv2d layer of the model can be found using:
            list(list(model.children())[0].children())[3]
        The 1, [] in the example output means that the second child of the
        model has been checked, and there is no nn.Conv2d in it. There are only
        two nn.Conv2d layers in this example model.
    """
    indices = _layer_indices(model, layer_type, [])

    # Remove trailing comma.
    indices[-1] = indices[-1].split(',')[0]

    return indices


def _layer_indices(layer, layer_type, indices):
    """
    Private function used for recursion in layer_indices().
    """
    # Return the index if layer is a leave node and match the target type.
    if (len(list(layer.children())) == 0):
        if (not isinstance(layer, layer_type)):
            indices.pop(-1)
        return indices

    # Recurse otherwise.
    else:
        indices.append("[")
        for i, sublayer in enumerate(layer.children()):
            indices.append(f"{i}, ")
            indices = _layer_indices(sublayer, layer_type, indices)

        # Remove trailing comma.
        indices[-1] = indices[-1].split(',')[0]

        indices.append("], ")
    return indices


if __name__ == '__main__':
    model = models.alexnet()
    print(''.join(layer_indices(model)))


def layer_counter(model, layer_type=nn.Conv2d):
    """
    Recursively find all layers of the type <layer_type> in a given model, then
    returns the counts.

    Parameters
    ----------
    model : torchvision.models or torch.nn.Module
        The neural network model or layer object.
    layer_type : type
        The target type of the leave node, e.g., nn.Conv2d, nn.ReLU, etc.

    Returns
    -------
    count : int
        The number of <layer_type> layers in the model.
    """
    count = _layer_counter(model, layer_type, 0)
    return count


def _layer_counter(layer, layer_type, count):
    """
    Private function used for recursion in layer_counter().
    """
    # Return the index if layer is a leave node and match the target type.
    if (len(list(layer.children())) == 0):
        if (isinstance(layer, layer_type)):
            return count + 1
        else: 
            return count

    # Recurse otherwise.
    else:
        for i, sublayer in enumerate(layer.children()):
            count = _layer_counter(sublayer, layer_type, count)
    return count


if __name__ == '__main__':
    model = models.alexnet()
    count = layer_counter(model, nn.Conv2d)
    assert  count == 5, "count should be 5 for Alexnet"


def num_units_in_layers(model, layer_type=nn.Conv2d):
    """
    Recursively find all layers of the type <layer_type> in a given model, then
    returns the number of their units.

    Parameters
    ----------
    model : torchvision.models or torch.nn.Module
        The neural network model or layer object.
    layer_type : type
        The target type of the leave node, e.g., nn.Conv2d, nn.ReLU, etc.

    Returns
    -------
    num_units_list : list of int
        The number of units in each layer of the type <layer_type>.
    """
    num_units_list = []
    _num_units_in_layers(model, layer_type, num_units_list)
    return num_units_list


def _num_units_in_layers(layer, layer_type, num_units_list):
    """
    Private function used for recursion in lnum_units_in_layers().
    """
    # Return the index if layer is a leave node and match the target type.
    if (len(list(layer.children())) == 0):
        if (isinstance(layer, layer_type)):
            num_units_list.append(layer.weight.shape[0])
        return

    # Recurse otherwise.
    else:
        for i, sublayer in enumerate(layer.children()):
            _num_units_in_layers(sublayer, layer_type, num_units_list)


if __name__ == '__main__':
    model = models.alexnet()
    num_units_list = num_units_in_layers(model, nn.Conv2d)
    print(num_units_list)