# MVA-2021 / dl_in_practice / hw2_gradcam / load_model.py
import threading

import torch

import resnet
from model import Net, apply_attention, tile_2d_over_nd


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ResNetLayer4(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.r_model = resnet.resnet152(pretrained=True)
        self.r_model.eval()
        self.r_model.to(device)

        self.buffer = {}
        lock = threading.Lock()

        # We only need the output of the ResNet's 4th layer block (layer4),
        # not the final classification logits. A forward hook on layer4 saves
        # its output in a buffer keyed by device, so the feature map can be
        # read back after the forward pass and DataParallel replicas running
        # on different GPUs do not overwrite each other. (The forward pass
        # could be cut short by raising a custom exception from the hook once
        # the output is captured, but that is not done here.)
        def save_output(module, input, output):
            with lock:
                self.buffer[output.device] = output

        self.r_model.layer4.register_forward_hook(save_output)

    def forward(self, x):
        # run the full ResNet forward pass; the layer4 hook stores the
        # feature map in self.buffer, keyed by the input's device
        self.r_model(x)
        return self.buffer[x.device]
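
# A minimal sanity-check sketch (not part of the original file): assuming the
# bundled `resnet` module mirrors torchvision's ResNet-152, a 448x448 RGB input
# should yield a [N, 2048, 14, 14] feature map from layer4 (448 / 32 = 14).
# The helper name below is illustrative, not something the original code defines.
def _check_resnet_layer4_shape():
    extractor = ResNetLayer4()
    dummy = torch.zeros(1, 3, 448, 448, device=device)
    with torch.no_grad():
        features = extractor(dummy)
    print('layer4 features:', tuple(features.shape))  # expected: (1, 2048, 14, 14)
    return features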

class VQA_Resnet_Model(Net):
    def __init__(self, embedding_tokens):
        super().__init__(embedding_tokens)
        self.resnet_layer4 = ResNetLayer4()
    
    def forward(self, v, q, q_len):
        q = self.text(q, list(q_len.data))
        v = self.resnet_layer4(v)

        # L2-normalise the ResNet features along the channel dimension
        v = v / (v.norm(p=2, dim=1, keepdim=True).expand_as(v) + 1e-8)

        # attend over spatial locations using the question, then pool the features
        a = self.attention(v, q)
        v = apply_attention(v, a)

        combined = torch.cat([v, q], dim=1)
        answer = self.classifier(combined)
        return answer

def load_model():
    # load the pretrained VQA checkpoint (weights and vocabulary)
    saved_state = torch.load('2017-08-04_00.55.19.pth', map_location=device)

    vocab = saved_state['vocab']

    # read the question vocabulary from the saved checkpoint
    token_to_index = vocab['question']
    num_tokens = len(token_to_index) + 1  # +1 for the padding index (0)

    # the checkpoint was saved from a DataParallel-wrapped model, so wrap the
    # network the same way before loading its state dict
    vqa_net = torch.nn.DataParallel(Net(num_tokens))
    vqa_net.load_state_dict(saved_state['weights'])
    vqa_net.to(device)
    vqa_net.eval()
    
    vqa_resnet = VQA_Resnet_Model(vqa_net.module.text.embedding.num_embeddings)

    # copy the trained text, attention and classifier weights into the new model,
    # which computes the ResNet layer4 features itself instead of expecting
    # precomputed features as input
    vqa_resnet.text.load_state_dict(vqa_net.module.text.state_dict())
    vqa_resnet.attention.load_state_dict(vqa_net.module.attention.state_dict())
    vqa_resnet.classifier.load_state_dict(vqa_net.module.classifier.state_dict())
    
    vqa_resnet.to(device)
    vqa_resnet.eval()
    
    return vqa_resnet
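

# Hedged usage sketch (not part of the original file): load the composed model
# and run a forward pass on dummy inputs. The shapes below (a 448x448 image and
# an arbitrary 10-token question) are assumptions, and the checkpoint file
# referenced in load_model() must be available in the working directory.
if __name__ == "__main__":
    model = load_model()

    image = torch.zeros(1, 3, 448, 448, device=device)              # dummy RGB image batch
    question = torch.zeros(1, 10, dtype=torch.long, device=device)  # token indices (0 = padding)
    q_len = torch.tensor([10], device=device)                       # true question lengths

    with torch.no_grad():
        logits = model(image, question, q_len)
    print('answer logits:', tuple(logits.shape))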