# ViTGuard / detection / run / train_mae.py
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm, trange
import sys
import argparse
import os
sys.path.append('../')
from ViTMAEModels_pretrain import ViTMAEForPreTraining_custom
import ViTMAEConfigs_pretrain as configs
sys.path.append('../../')
from load_data import load_tiny, GetCIFAR100Training, GetCIFAR100Validation, GetCIFAR10Training, GetCIFAR10Validation
import time


def get_aug():
    """Parse the command-line options of this training script.

    Returns:
        argparse.Namespace: has a single attribute ``dataset`` holding one of
        'CIFAR10', 'CIFAR100' or 'TinyImagenet' (default: 'TinyImagenet').

    Any other dataset name makes argparse print a usage error and exit,
    instead of letting the script continue with no data loader configured.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset',
        default='TinyImagenet',
        type=str,
        # The only names the dataset-selection code in this script handles.
        choices=['CIFAR10', 'CIFAR100', 'TinyImagenet'],
    )
    return parser.parse_args()

# Select the data loaders and the training schedule from the --dataset option.
args = get_aug()
if args.dataset == 'CIFAR10':
    train_loader = GetCIFAR10Training(imgSize=224)
    test_loader = GetCIFAR10Validation(imgSize=224)
    N_EPOCHS = 250
    LR = 1.5e-4

elif args.dataset == 'CIFAR100':
    train_loader = GetCIFAR100Training(imgSize=224)
    test_loader = GetCIFAR100Validation(imgSize=224)
    N_EPOCHS = 250
    LR = 1.5e-4

elif args.dataset == 'TinyImagenet':
    # NOTE: only a training loader is created here; `test_loader` is not
    # defined for TinyImagenet.
    train_loader = load_tiny(shuffle=True, is_train=True)
    N_EPOCHS = 500
    LR = 2.5e-4

else:
    # Fail right away with a clear message rather than with a NameError on
    # `train_loader` / `N_EPOCHS` / `LR` further down the script.
    raise ValueError(
        f"Unsupported dataset {args.dataset!r}; "
        "expected 'CIFAR10', 'CIFAR100' or 'TinyImagenet'."
    )

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

# `ratio` is forwarded to the MAE config; presumably the patch-masking
# ratio -- TODO confirm against ViTMAEConfigs_pretrain.
ratio = 0.5
config = configs.ViTMAEConfig(ratio=ratio)
model_custom = ViTMAEForPreTraining_custom(config=config)
# Move the wrapped model to the device selected above. The previous hard-coded
# .cuda() call raised on CPU-only machines even though `device` already
# handled that case.
model_custom = nn.DataParallel(model_custom).to(device)

## training process
# Adam optimizer with a cosine-annealed learning rate; the scheduler is
# stepped once at the end of every epoch.
optimizer = Adam(model_custom.parameters(), lr=LR)
scheduler = CosineAnnealingLR(optimizer, T_max=N_EPOCHS)

# Per-epoch history (test_loss_list is created but not filled in this loop).
train_loss_list, test_loss_list, lr_list = [], [], []

for epoch in trange(N_EPOCHS, desc="Training"):
    train_loss = 0.0
    total = 0
    # Record the learning rate in effect during this epoch.
    lr_list.append(optimizer.param_groups[0]["lr"])

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
        images = images.to(device)
        labels = labels.to(device)

        outputs = model_custom(images)
        # outputs[1] / outputs[2] are presumably the reconstructed patches and
        # the patch mask -- TODO confirm against ViTMAEForPreTraining_custom.
        prediction, mask = outputs[1], outputs[2]

        # Squared error against the patchified input, averaged over the last
        # dimension: [N, L], mean loss per patch.
        target = model_custom.module.patchify(images)
        per_patch = ((prediction - target) ** 2).mean(dim=-1)
        # Mask-weighted average of the per-patch losses.
        loss = (per_patch * mask).sum() / mask.sum()

        # Running mean of the batch losses over the epoch.
        batch_loss = loss.detach().cpu().item()
        train_loss += batch_loss / len(train_loader)
        total += len(images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss_list.append(train_loss)
    print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.5f}")
    scheduler.step()

# Save the trained weights under ../results/<dataset>/weights.pth.
weights_path = '../results/{}/'.format(args.dataset)
# makedirs also creates a missing '../results' parent (os.mkdir would raise
# FileNotFoundError there) and does not fail if the directory already exists.
os.makedirs(weights_path, exist_ok=True)
# Save the state dict of the underlying module, not of the DataParallel wrapper.
torch.save(model_custom.module.state_dict(), os.path.join(weights_path, 'weights.pth'))