borderownership / src / rf_mapping / reproducibility.py
reproducibility.py
Raw
import sys
import random

import torch
import torchvision.models as models
import numpy as np
from tqdm import tqdm

sys.path.append('../..')
import src.rf_mapping.constants as c

__all__ = ['set_seeds']

def set_seeds(seed=123):
    """
    Sets seeds for randomly initialized models. Should be called everytime
    an untrained instance of a model is created.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


if __name__ == "__main__":
    # Importing modules needed for tests below.
    from src.rf_mapping.hook import ConvUnitCounter
    from src.rf_mapping.net import get_truncated_model
    
    def truncated_model_test(model1, model2):
        counter = ConvUnitCounter(model1)
        layer_indices, _ = counter.count()

        for layer_idx in tqdm(layer_indices):
            tm1 = get_truncated_model(model1, layer_idx)
            tm2 = get_truncated_model(model2, layer_idx)
            y1 = tm1(torch.ones((1, 3, 227, 227)))
            y2 = tm2(torch.ones((1, 3, 227, 227)))
            assert(torch.allclose(y1, y2))
        print(f"All truncated model tests passed.")


if __name__ == "__main__":
    print("Testing Alexnet...")
    set_seeds()
    model1 = models.alexnet().to(c.DEVICE)
    set_seeds()
    model2 = models.alexnet().to(c.DEVICE)
    
    counter = ConvUnitCounter(model1)
    layer_indices, _ = counter.count()

    for conv_i, layer_idx in enumerate(layer_indices):
        model1_conv_weights = model1.features[0].weight
        model2_conv_weights = model2.features[0].weight
        assert(torch.allclose(model1_conv_weights, model2_conv_weights))
    print(f"All untruncated model tests passed.")

    truncated_model_test(model1, model2)


if __name__ == "__main__":
    print("Testing Vgg16...")
    set_seeds()
    model1 = models.vgg16().to(c.DEVICE)
    set_seeds()
    model2 = models.vgg16().to(c.DEVICE)
    
    truncated_model_test(model1, model2)


if __name__ == "__main__":
    print("Testing Resnet18...")
    set_seeds()
    model1 = models.resnet18().to(c.DEVICE)
    set_seeds()
    model2 = models.resnet18().to(c.DEVICE)
    
    truncated_model_test(model1, model2)