LCYE / util / FeatureExtractor.py
FeatureExtractor.py
Raw
import torch.nn as nn
from IPython import embed

class FeatureExtractor(nn.Module):
    def __init__(self,submodule,extracted_layers):
        super(FeatureExtractor,self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        outputs = []
        for name, module in self.submodule._modules.items():
            if name is "classfier":
                x = x.view(x.size(0),-1)
            if name is "base":
                for block_name, cnn_block in module._modules.items():
                    x = cnn_block(x)
                    if block_name in self.extracted_layers:
                        outputs.append(x)
        return outputs