import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm, trange
import sys
import argparse
import os
sys.path.append('../')
from ViTMAEModels_pretrain import ViTMAEForPreTraining_custom
import ViTMAEConfigs_pretrain as configs
sys.path.append('../../')
from load_data import load_tiny, GetCIFAR100Training, GetCIFAR100Validation, GetCIFAR10Training, GetCIFAR10Validation
import time

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='TinyImagenet', type=str)
    return parser.parse_args()
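
# Example invocation (script filename assumed, not given in the repo):
#   python pretrain.py --dataset CIFAR100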
args = get_args()
if args.dataset == 'CIFAR10':
    train_loader = GetCIFAR10Training(imgSize=224)
    test_loader = GetCIFAR10Validation(imgSize=224)
    N_EPOCHS = 250
    LR = 1.5e-4
elif args.dataset == 'CIFAR100':
    train_loader = GetCIFAR100Training(imgSize=224)
    test_loader = GetCIFAR100Validation(imgSize=224)
    N_EPOCHS = 250
    LR = 1.5e-4
elif args.dataset == 'TinyImagenet':
    # only the training split is used for pretraining here
    train_loader = load_tiny(shuffle=True, is_train=True)
    N_EPOCHS = 500
    LR = 2.5e-4
else:
    raise ValueError(f"Unsupported --dataset value: {args.dataset}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")
ratio = 0.5  # ratio passed to the custom MAE config (presumably the masking ratio)
config = configs.ViTMAEConfig(ratio=ratio)
model_custom = ViTMAEForPreTraining_custom(config=config)
# DataParallel falls back to calling the wrapped module directly when no GPU
# is visible, so .module stays accessible and the script also runs on CPU.
model_custom = nn.DataParallel(model_custom).to(device)
## training process
optimizer = Adam(model_custom.parameters(), lr=LR)
scheduler = CosineAnnealingLR(optimizer, T_max=N_EPOCHS)
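# One scheduler.step() per epoch with T_max=N_EPOCHS anneals the learning
# rate from LR down toward zero (eta_min defaults to 0) over the full run.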
train_loss_list = []
test_loss_list = []
lr_list = []
for epoch in trange(N_EPOCHS, desc="Training"):
    train_loss = 0.0
    total = 0
    lr_list.append(optimizer.param_groups[0]["lr"])
    # start_time = time.time()
    for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
        x, y = batch
        x, y = x.to(device), y.to(device)
        outputs = model_custom(x)
        target = model_custom.module.patchify(x)
        # MAE reconstruction objective: per-patch mean squared error between
        # the decoder output (outputs[1]) and the patchified pixels, averaged
        # over masked patches only (outputs[2] is the binary mask; 1 = masked,
        # following the standard MAE convention).
        loss = (outputs[1] - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
        loss = (loss * outputs[2]).sum() / outputs[2].sum()
        train_loss += loss.detach().cpu().item() / len(train_loader)
        total += len(x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # end_time = time.time()
    # print('duration', end_time - start_time)
    train_loss_list.append(train_loss)
    print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.5f}")
    scheduler.step()

weights_path = '../results/{}/'.format(args.dataset)
os.makedirs(weights_path, exist_ok=True)  # also creates missing parent dirs
torch.save(model_custom.module.state_dict(), weights_path + 'weights.pth')
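
# Sketch of how the saved weights could be reloaded later (assumed usage,
# mirroring the names in this script; not taken from an existing loader):
#   model = ViTMAEForPreTraining_custom(config=configs.ViTMAEConfig(ratio=ratio))
#   model.load_state_dict(torch.load(weights_path + 'weights.pth'))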