"""
Code for truncating models and tracing data flow inside them. Need these
functions because many models like resnets are not single-path, i.e., the data
do not neccessarily flow from one layer to another in the order they are
presented in model.children().

Tony Fu, Aug 4th, 2022
"""
import sys
import copy

import torch
import torch.fx as fx
import torch.nn as nn
from torchvision import models

sys.path.append('../..')
import src.rf_mapping.constants as c


#######################################.#######################################
#                                                                             #
#                             GET_TRUNCATED_MODEL                             #
#                                                                             #
###############################################################################
def get_truncated_model(model, layer_index):
    """
    Create a truncated version of the neural network.

    Parameters
    ----------
    model : torch.nn.Module
        The neural network to be truncated.
    layer_index : int
        The last layer (inclusive) to be included in the truncated model.

    Returns
    -------
    truncated_model : torch.nn.Module
        The truncated model.

    Example
    -------
    model = models.alexnet(pretrained=True)
    model_to_conv2 = get_truncated_model(model, 3)
    y = model_to_conv2(torch.ones(1, 3, 200, 200))
    """
    model = copy.deepcopy(model).to(c.DEVICE)
    model.eval()  # Make sure to trace the eval() version of the net.
    graph = fx.Tracer().trace(model)
    new_graph = fx.Graph()
    layer_counter = 0
    value_remap = {}

    for node in graph.nodes:
        # Copy this node into the new graph, remapping its inputs to the
        # nodes that have already been copied.
        value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n])

        # If the node is a module...
        if node.op == 'call_module':
            # Get the layer object using the node.target attribute.
            layer = model
            for level in node.target.split("."):
                layer = getattr(layer, level)
            # Stop at the desired layer (i.e., truncate). The output must
            # reference the *copied* node, not the node of the original graph.
            if layer_counter == layer_index:
                new_graph.output(value_remap[node])
                break

            layer_counter += 1

    return fx.GraphModule(model, new_graph)


if __name__ == "__main__":
    model = models.resnet18(pretrained=True)
    model.eval()
    dummy_input = torch.ones(1,3,200,200)
    tm = get_truncated_model(model, 100)
    torch.testing.assert_allclose(tm(dummy_input), model(dummy_input))
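
    # A further sanity check, as a minimal sketch: assuming fx visits
    # AlexNet's features in order, index 3 is the second Conv2d, so the
    # truncated model should return the 192-channel conv2 feature map.
    alexnet = models.alexnet(pretrained=True)
    truncated = get_truncated_model(alexnet, 3)
    y = truncated(dummy_input)
    print(y.shape)  # expected: torch.Size([1, 192, 24, 24])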


#######################################.#######################################
#                                                                             #
#                                  LayerNode                                  #
#                                                                             #
###############################################################################
class LayerNode:
    def __init__(self, name, layer=None, parents=(), children=(), idx=None):
        self.idx = idx
        self.name = name
        self.layer = layer
        self.parents = parents
        self.children = children

    def __repr__(self):
        return f"LayerNode '{self.name}' (idx = {self.idx})\n"\
               f"       parents  = {self.parents}\n"\
               f"       children = {self.children}"


#######################################.#######################################
#                                                                             #
#                                 MAKE_GRAPH                                  #
#                                                                             #
###############################################################################
def make_graph(truncated_model):
    """
    Generate a directed, acyclic graph representation of the model.

    Parameters
    ----------
    truncated_model : Union[fx.graph_module.GraphModule, torch.nn.Module]
        The neural network. Can be truncated or not.

    Returns
    -------
    nodes : dict
        key : the unique name of each operation performed on the input tensor.
        value : a LayerNode object containing the information about the
                operation.
    """
    # Make sure that the truncated_model is a GraphModule. 
    if not isinstance(truncated_model, fx.graph_module.GraphModule):
        truncated_model = copy.deepcopy(truncated_model)
        graph = fx.Tracer().trace(truncated_model.eval())
        truncated_model = fx.GraphModule(truncated_model, graph)

    nodes = {}
    idx_count = 0  # for layer indexing
    # Populate the nodes dictionary with the initialized Nodes.
    for node in truncated_model.graph.nodes:
        # Get the layer torch.nn object.
        if node.op == 'call_module':
            layer = truncated_model
            idx = idx_count
            idx_count += 1
            for level in node.target.split("."):
                layer = getattr(layer, level)
        else:
            layer = None
            idx = None

        # Get the name of the parents.
        parents = []
        for parent in node.args:
            if isinstance(parent, fx.node.Node):
                parents.append(parent.name)

        # Initialize Nodes.
        nodes[node.name] = LayerNode(node.name, layer, parents=tuple(parents),
                                     idx=idx)

    # Determine the children of the nodes.
    for node in truncated_model.graph.nodes:
        for parent in nodes[node.name].parents:
            existing_children = nodes[parent].children
            nodes[parent].children = (*existing_children, node.name)

    return nodes


if __name__ == '__main__':
    model = models.resnet18()
    model.eval()
    for layer in make_graph(model).values():
        print(layer)
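
    # A minimal sketch of one use of the graph: in an fx trace of a ResNet,
    # the residual additions show up as the nodes with more than one parent,
    # so the skip connections' merge points can be read off the LayerNode
    # records directly.
    nodes = make_graph(model)
    merge_points = [name for name, node in nodes.items()
                    if len(node.parents) > 1]
    print(merge_points)  # e.g., the 'add' nodes of resnet18's BasicBlocks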


#######################################.#######################################
#                                                                             #
#                            GET_CONV_LAYER_INDICES                           #
#                                                                             #
###############################################################################
def get_conv_layer_indices(model, layer_types=(nn.Conv2d,)):
    """
    Gets the indices of all layers of the types {layer_types}.

    Parameters
    ----------
    model : Union[fx.graph_module.GraphModule, torch.nn.Module]
        The neural network. Can be truncated or not.
    layer_types : (type, ...)
        The types of layers to include in the indexing. Pass a tuple: note
        that (nn.Conv2d) is just nn.Conv2d, whereas (nn.Conv2d,) is a
        one-element tuple.

    Returns
    -------
    layer_indices : [int, ...]
        Indices of the layers, as numbered by the torch.fx.Tracer() trace.
    """
    layer_indices = []
    for layer in make_graph(model).values():
        if isinstance(layer.layer, layer_types):
            layer_indices.append(layer.idx)
    return layer_indices


if __name__ == "__main__":
    model = models.alexnet()
    print(get_conv_layer_indices(model, layer_types=(nn.Conv2d,)))
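
    # A minimal sketch of how the indices pair with get_truncated_model():
    # truncate at each conv layer and inspect the resulting feature map.
    for idx in get_conv_layer_indices(model, layer_types=(nn.Conv2d,)):
        truncated = get_truncated_model(model, idx)
        out = truncated(torch.ones(1, 3, 200, 200).to(c.DEVICE))
        print(f"conv at index {idx}: output shape = {tuple(out.shape)}")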


#######################################.#######################################
#                                                                             #
#                                 IS_RESIDUAL                                 #
#                                 (NOT USED)                                  #
#                                                                             #
###############################################################################
def is_residual(container_layer):
    """
    Check if the container layer has a residual connection.

    Returns 0 if there is no residual connection, 1 if the input is added
    directly to the output (identity shortcut), and 2 if the input is first
    projected by a downsampling layer (projection shortcut).
    """
    has_conv = False
    first_in_channels = None

    for sublayer in container_layer.children():
        if isinstance(sublayer, nn.Conv2d) and not has_conv:
            has_conv = True
            first_in_channels = sublayer.in_channels

    if not has_conv:
        return 0

    container_layer = container_layer.to(c.DEVICE)
    dummy_input = torch.ones((1, first_in_channels, 100, 100)).to(c.DEVICE)

    with torch.no_grad():
        # Pass the input through the sublayers sequentially, i.e., as if
        # there were no shortcut connection. The 'downsample' submodule (the
        # name torchvision's ResNet blocks use for the projection shortcut)
        # belongs to the shortcut path, so it is skipped here.
        x = dummy_input
        for name, sublayer in container_layer.named_children():
            if name == 'downsample':
                continue
            x = sublayer(x)

        original_output = container_layer(dummy_input)

        # No shortcut: the sequential pass reproduces the output exactly.
        if torch.equal(x, original_output):
            return 0

        # Identity shortcut: output = relu(x + input). Use clone() rather
        # than detach(): detach() shares memory with x, so the in-place +=
        # below would corrupt x before the projection check.
        if x.shape == dummy_input.shape:
            rectified_x = x.clone()
            rectified_x += dummy_input
            rectified_x[rectified_x < 0] = 0
            if torch.equal(rectified_x, original_output):
                return 1

        # Projection shortcut: output = relu(x + downsample(input)). A
        # freshly initialized 1x1 convolution could never reproduce the
        # trained projection, so use the block's own downsample layers.
        downsample = getattr(container_layer, 'downsample', None)
        if downsample is not None:
            projected_x = x + downsample(dummy_input)
            projected_x[projected_x < 0] = 0
            if torch.equal(projected_x, original_output):
                return 2

    # Fall back to 0 if no shortcut structure could be confirmed.
    return 0


if __name__ == "__main__":
    res_block = list(list(models.resnet18().children())[4].children())[0]
    print(is_residual(res_block))

    not_res_block = list(models.alexnet().children())[0]
    print(is_residual(not_res_block))

    not_res_block2 = nn.Sequential()
    print(is_residual(not_res_block2))