# ViTGuard / Evaluations.py
import torch
from torch.nn import CrossEntropyLoss

# Test loop
def test_vit(model, test_loader, device, criterion=CrossEntropyLoss(), vis=False):
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in test_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)
            if vis == True:
                y_hat,_ = model(x)
            else:
                y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")
        return test_loss, correct / total