# Source path: ViTGuard / target_models / run / train.py
import torch
import torch.nn as nn
from transformers import ViTModel
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from tqdm import tqdm, trange
import numpy as np
import argparse
import os

import sys
sys.path.append('../')
from TransformerModels_pretrain import ViTModel_custom, ViTForImageClassification
import TransformerConfigs_pretrain as configs
sys.path.append('../../')
from load_data import load_tiny, GetCIFAR100Training, GetCIFAR100Validation, GetCIFAR10Training, GetCIFAR10Validation

def get_aug():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace: namespace with a single attribute, ``dataset``,
        which is one of ``'CIFAR10'``, ``'CIFAR100'`` or ``'TinyImagenet'``
        (default ``'TinyImagenet'``).

    Raises:
        SystemExit: raised by argparse when ``--dataset`` is not one of the
        supported names (or on any other malformed command line).
    """
    parser = argparse.ArgumentParser()
    # Restrict the value to the datasets the script actually handles, so a
    # typo is rejected immediately instead of failing later on an unbound
    # variable.
    parser.add_argument(
        '--dataset',
        default='TinyImagenet',
        type=str,
        choices=['CIFAR10', 'CIFAR100', 'TinyImagenet'],
        help='dataset to train on (default: %(default)s)',
    )
    args = parser.parse_args()
    return args

# Parse the command line once, then build the train/test loaders and the
# number of target classes for the requested dataset.
args = get_aug()
if args.dataset == 'CIFAR10':
    train_loader = GetCIFAR10Training(imgSize=224)
    test_loader = GetCIFAR10Validation(imgSize=224)
    num_labels = 10
elif args.dataset == 'CIFAR100':
    train_loader = GetCIFAR100Training(imgSize=224)
    test_loader = GetCIFAR100Validation(imgSize=224)
    num_labels = 100
elif args.dataset == 'TinyImagenet':
    # load_tiny returns the (train, test) loaders as a pair.
    train_loader, test_loader = load_tiny(shuffle=True, is_train=True)
    num_labels = 200
else:
    # Fail here with a clear message; otherwise train_loader, test_loader and
    # num_labels stay unbound and the script dies later with a NameError.
    raise ValueError(
        "Unsupported dataset {!r}; expected 'CIFAR10', 'CIFAR100' or "
        "'TinyImagenet'".format(args.dataset)
    )

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")


# Load the configuration of target models
model_arch = 'ViT-16'
# Load the pretrained model
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
config = configs.get_b16_config()
model = ViTModel_custom(config=config)

# Copy the pretrained weights into the custom backbone, then freeze every
# backbone parameter so no gradients are computed for them.
model.load_state_dict(vit_model.state_dict())
for param in model.parameters():
    param.requires_grad = False

# Wrap the frozen backbone in the classification model for num_labels classes.
model = ViTForImageClassification(config, model, num_labels)
# Move the model to the selected device rather than hard-coding .cuda(), so
# the CPU fallback chosen above actually works on a machine without a GPU.
model = nn.DataParallel(model).to(device)
N_EPOCHS = 50
LR = 1e-4

# Training loop
# Loss and optimizer: cross-entropy, Adam over the model's parameters at LR.
criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LR)

# Per-epoch history. Only the train_* lists are filled by this loop.
train_loss_list, train_acc_list = [], []
test_loss_list, test_acc_list = [], []

for epoch in trange(N_EPOCHS, desc="Training"):
    correct, total = 0, 0
    train_loss = 0.0
    num_batches = len(train_loader)

    progress = tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False)
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass and loss for this batch.
        logits = model(images)
        batch_loss = criterion(logits, labels)

        # Running statistics: the loss is averaged over the number of batches,
        # accuracy is tracked as correct predictions over samples seen.
        train_loss += batch_loss.item() / num_batches
        predictions = logits.argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += images.shape[0]

        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    epoch_acc = correct / total
    train_loss_list.append(train_loss)
    train_acc_list.append(epoch_acc)

    print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")
    print(f"Epoch {epoch + 1}/{N_EPOCHS} acc: {correct/total:.2f}")

# Save the trained weights under ../results/<model_arch>/<dataset>/.
directory_path = "../results/{}/{}/".format(model_arch, args.dataset)
# exist_ok=True creates the directory tree when missing and is a no-op when it
# already exists, avoiding the check-then-create race of a separate exists().
os.makedirs(directory_path, exist_ok=True)
filename = os.path.join(directory_path, "weights.pth")
# The model is wrapped in nn.DataParallel; save the wrapped module's state
# dict so the stored keys do not carry the wrapper's "module." prefix.
torch.save(model.module.state_dict(), filename)