"""Fine-tune a classification head on top of a frozen, pretrained ViT backbone.

The script loads ``google/vit-base-patch16-224-in21k`` weights into the
project's ``ViTModel_custom``, freezes every backbone parameter, wraps the
backbone in ``ViTForImageClassification`` and trains it on the dataset chosen
with ``--dataset``.  Weights are written to
``../results/<model_arch>/<dataset>/weights.pth``.
"""
import argparse
import os
import sys

import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from tqdm import tqdm, trange
from transformers import ViTModel

# Project-local modules live one and two directories above this script.
sys.path.append('../')
from TransformerModels_pretrain import ViTModel_custom, ViTForImageClassification
import TransformerConfigs_pretrain as configs
sys.path.append('../../')
from load_data import load_tiny, GetCIFAR100Training, GetCIFAR100Validation, GetCIFAR10Training, GetCIFAR10Validation

# Dataset names understood by the selection logic below.
DATASET_CHOICES = ('CIFAR10', 'CIFAR100', 'TinyImagenet')


def get_aug() -> argparse.Namespace:
    """Parse the command line.

    Returns:
        Namespace with a single attribute ``dataset`` (default
        ``'TinyImagenet'``).  Values outside ``DATASET_CHOICES`` are rejected
        by argparse instead of failing later with a ``NameError``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='TinyImagenet', type=str,
                        choices=DATASET_CHOICES)
    return parser.parse_args()


args = get_aug()

# Pick the data loaders and the number of target classes for the dataset.
if args.dataset == 'CIFAR10':
    train_loader = GetCIFAR10Training(imgSize=224)
    test_loader = GetCIFAR10Validation(imgSize=224)
    num_labels = 10
elif args.dataset == 'CIFAR100':
    train_loader = GetCIFAR100Training(imgSize=224)
    test_loader = GetCIFAR100Validation(imgSize=224)
    num_labels = 100
elif args.dataset == 'TinyImagenet':
    train_loader, test_loader = load_tiny(shuffle=True, is_train=True)
    num_labels = 200
else:
    # Unreachable through the CLI (argparse enforces `choices`); kept so a
    # programmatic override of `args.dataset` fails with a clear message.
    raise ValueError(f"Unsupported dataset: {args.dataset!r}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

# Load the configuration of target models
model_arch = 'ViT-16'

# Load the pretrained model
vit_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
config = configs.get_b16_config()
model = ViTModel_custom(config=config)

# Load the weights from the pretrained model
model.load_state_dict(vit_model.state_dict())

# Freeze the backbone.  Parameters created afterwards by
# ViTForImageClassification are not touched by this loop.
for param in model.parameters():
    param.requires_grad = False

model = ViTForImageClassification(config, model, num_labels)
# Move to the selected device rather than calling .cuda() unconditionally,
# so the CPU fallback computed above actually works on machines without a GPU.
model = nn.DataParallel(model).to(device)

N_EPOCHS = 50
LR = 1e-4

# Training loop
optimizer = Adam(model.parameters(), lr=LR)
criterion = CrossEntropyLoss()

train_loss_list = []
train_acc_list = []
# NOTE(review): the test lists (and test_loader) are never used in this
# script as visible here; kept in case later code relies on them.
test_loss_list = []
test_acc_list = []

# Create the output directory up front so a bad path fails before training.
directory_path = "../results/{}/{}/".format(model_arch, args.dataset)
os.makedirs(directory_path, exist_ok=True)
filename = os.path.join(directory_path, "weights.pth")

for epoch in trange(N_EPOCHS, desc="Training"):
    train_loss = 0.0
    correct, total = 0, 0
    num_batches = len(train_loader)

    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
        x, y = batch
        x, y = x.to(device), y.to(device)

        y_hat = model(x)
        loss = criterion(y_hat, y)

        # Accumulate the mean of the per-batch losses over the epoch.
        train_loss += loss.item() / num_batches
        correct += (torch.argmax(y_hat, dim=1) == y).sum().item()
        total += len(x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Guard against an empty loader, which would otherwise divide by zero.
    train_acc = correct / total if total else 0.0
    train_loss_list.append(train_loss)
    train_acc_list.append(train_acc)

    print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")
    print(f"Epoch {epoch + 1}/{N_EPOCHS} acc: {train_acc:.2f}")

    # Overwrite the checkpoint every epoch so an interrupted run keeps the
    # latest weights.  `model.module` unwraps DataParallel so the saved keys
    # carry no "module." prefix.
    torch.save(model.module.state_dict(), filename)