import torch
from torch.nn import CrossEntropyLoss
# Test loop
def test_vit(model, test_loader, device, criterion=None, vis=False):
    """Evaluate a classifier on ``test_loader`` and print/return loss and accuracy.

    Args:
        model: Callable mapping an input batch to logits. When ``vis`` is True
            it must instead return a 2-tuple whose first element is the logits
            (the second element is ignored). If ``model`` is an ``nn.Module``
            it is put in eval mode for the duration of the call, and every
            submodule's train/eval flag is restored afterwards.
        test_loader: Iterable yielding ``(x, y)`` batches.
        device: Device that both tensors of each batch are moved to.
        criterion: Loss called as ``criterion(y_hat, y)`` and expected to
            return a scalar tensor. Defaults to a fresh ``CrossEntropyLoss()``.
        vis: Whether ``model`` returns ``(logits, extra)`` instead of logits.

    Returns:
        ``(test_loss, accuracy)``. ``test_loss`` is the mean of the per-batch
        losses. ``accuracy`` is the fraction of samples whose argmax over
        dim 1 of the logits equals the label. Both are 0.0 when the loader
        yields no batches.
    """
    if criterion is None:
        # Built per call: a default-argument instance would be created once at
        # definition time and shared by every call.
        criterion = CrossEntropyLoss()

    # Record each submodule's own mode so that layers deliberately kept in
    # eval (e.g. frozen BatchNorm) are restored exactly, not forced to train.
    saved_modes = []
    if isinstance(model, torch.nn.Module):
        saved_modes = [(module, module.training) for module in model.modules()]
        # Disable dropout / use running BatchNorm stats while evaluating.
        model.eval()

    loss_sum = 0.0
    n_batches = 0
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                if vis:
                    # Model also returns auxiliary output (ignored here).
                    y_hat, _ = model(x)
                else:
                    y_hat = model(x)
                loss_sum += criterion(y_hat, y).item()
                n_batches += 1
                correct += (torch.argmax(y_hat, dim=1) == y).sum().item()
                total += len(x)
    finally:
        # Restore the caller's train/eval state even if evaluation raised.
        for module, was_training in saved_modes:
            module.training = was_training

    # Guard against an empty loader instead of dividing by zero.
    test_loss = loss_sum / n_batches if n_batches else 0.0
    accuracy = correct / total if total else 0.0
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {accuracy * 100:.2f}%")
    return test_loss, accuracy


# Despite its ``test_`` prefix this is not a pytest test (its parameters are
# not fixtures); keep pytest from collecting it.
test_vit.__test__ = False