# FOT-OOD/model.py
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import DistilBertForSequenceClassification, DistilBertModel
from transformers import BertTokenizerFast, DistilBertTokenizerFast

import os
# Work around SSL certificate errors when downloading pretrained weights
# (an empty CURL_CA_BUNDLE disables certificate verification for these requests).
os.environ['CURL_CA_BUNDLE'] = ''

def EfficientNetB4(num_classes, seed=123, pretrained=True):
    # Seed so the freshly initialized classification head is reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    model = models.efficientnet_b4(pretrained=pretrained)
    # Replace the final classifier layer (EfficientNet-B4 feature dim is 1792).
    model.classifier[1] = nn.Linear(1792, num_classes)
    return model


def ViT_B_16(num_classes=10, seed=123, pretrained=True):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    model = models.vit_b_16(pretrained=pretrained)
    model.heads.head = nn.Linear(768, num_classes)
    return model

def ResNet18(num_classes=10, seed=123, pretrained=True):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    resnet18 = models.resnet18(pretrained=pretrained)
    resnet18.fc = nn.Linear(512, num_classes)
    return resnet18


def ResNet50(num_classes=10, seed=123, pretrained=True):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    resnet50 = models.resnet50(pretrained=pretrained).cuda()
    resnet50.fc = nn.Linear(2048, num_classes).cuda()
    return resnet50


def ResNet101(num_classes=10, seed=123, pretrained=True):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    resnet101 = models.resnet101(pretrained=pretrained).cuda()
    resnet101.fc = nn.Linear(2048, num_classes).cuda()
    return resnet101

def DenseNet121(num_classes=10, seed=123, pretrained=True):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    densenet121 = models.densenet121(pretrained=pretrained).cuda()
    densenet121.classifier = nn.Linear(1024, num_classes).cuda()
    return densenet121


def VGG11(num_classes=10, seed=123, pretrained=True):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    vgg11 = models.vgg11(pretrained=pretrained).cuda()
    vgg11.classifier[-1] = nn.Linear(4096, num_classes).cuda()
    return vgg11
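

# A minimal usage sketch (added for illustration, not part of the original FOT-OOD
# API): build one of the vision factories above and run a dummy forward pass.
# The 224x224 input size is an assumption that matches the ResNet/VGG/DenseNet/ViT-B/16
# defaults; EfficientNet-B4 is typically run at a larger native resolution (380x380).
def _demo_vision_forward():
    # pretrained=False avoids downloading ImageNet weights for this smoke test.
    net = ResNet18(num_classes=10, pretrained=False)
    net.eval()
    with torch.no_grad():
        dummy = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width)
        logits = net(dummy)
    print(logits.shape)  # expected: torch.Size([2, 10])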


class DistilBertClassifier(DistilBertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)

    def __call__(self, x):
        # x packs the tokenizer output as (batch, max_token_length, 2):
        # channel 0 is input_ids and channel 1 is attention_mask
        # (see initialize_bert_transform below).
        input_ids = x[:, :, 0]
        attention_mask = x[:, :, 1]
        # Index [0] of the Hugging Face output is the classification logits.
        outputs = super().__call__(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )[0]
        return outputs
    

def initialize_bert_based_model(num_classes):
    model = DistilBertClassifier.from_pretrained(
        'distilbert-base-uncased',
        num_labels=num_classes
    )
    return model


def initialize_bert_transform(net, max_token_length=512):
    # Build a transform that tokenizes raw text and packs input_ids,
    # attention_mask (and token_type_ids for 'bert-base-uncased') into a
    # single tensor of shape (max_token_length, num_channels).
    tokenizer = getBertTokenizer(net)
    def transform(text):
        tokens = tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=max_token_length,
            return_tensors='pt')
        if net == 'bert-base-uncased':
            x = torch.stack(
                (tokens['input_ids'],
                 tokens['attention_mask'],
                 tokens['token_type_ids']),
                dim=2)
        elif net == 'distilbert-base-uncased':
            x = torch.stack(
                (tokens['input_ids'],
                 tokens['attention_mask']),
                dim=2)
        x = torch.squeeze(x, dim=0)  # tokenizer adds a leading batch dim of 1 for a single string; drop it
        return x
    return transform


def getBertTokenizer(model):
    if model == 'bert-base-uncased':
        tokenizer = BertTokenizerFast.from_pretrained(model)
    elif model == 'distilbert-base-uncased':
        tokenizer = DistilBertTokenizerFast.from_pretrained(model)
    else:
        raise ValueError(f'Model: {model} not recognized.')

    return tokenizer
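

# End-to-end sketch (added for illustration, not part of the original module): pair the
# tokenization transform with the DistilBERT classifier. Running this downloads the
# pretrained 'distilbert-base-uncased' weights and tokenizer from the Hugging Face hub;
# the example sentence and max_token_length=128 are arbitrary choices.
if __name__ == '__main__':
    _demo_vision_forward()

    transform = initialize_bert_transform('distilbert-base-uncased', max_token_length=128)
    model = initialize_bert_based_model(num_classes=2)
    model.eval()

    # transform() returns a (max_token_length, 2) tensor with input_ids and attention_mask
    # stacked along the last dim, which is exactly what DistilBertClassifier.__call__ unpacks.
    x = transform("An out-of-distribution detection example sentence.")
    with torch.no_grad():
        logits = model(x.unsqueeze(0))  # add a batch dimension
    print(logits.shape)  # expected: torch.Size([1, 2])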