import torch
from torch.nn import CrossEntropyLoss


# Test loop
def test_vit(model, test_loader, device, criterion=CrossEntropyLoss(), vis=False):
    """Evaluate ``model`` on ``test_loader`` and report mean loss and accuracy.

    Runs under ``torch.no_grad()``. The model's train/eval mode is NOT
    changed here; call ``model.eval()`` beforehand if layers such as dropout
    or batch norm should behave deterministically.

    Args:
        model: Callable applied to each input batch. Its output (or the first
            element of its output when ``vis`` is truthy) is passed to
            ``criterion`` and arg-maxed over ``dim=1`` to obtain predictions.
        test_loader: Iterable yielding ``(x, y)`` batches. It does not need to
            support ``len()``; batches are counted while iterating.
        device: Device the batches are moved to via ``Tensor.to``.
        criterion: Loss callable ``criterion(y_hat, y)``. The default instance
            is created once at definition time and shared across calls.
        vis: When truthy, ``model(x)`` is expected to return a 2-tuple whose
            first element is the prediction tensor; the second is discarded.

    Returns:
        ``(test_loss, accuracy)`` where ``test_loss`` is the mean of the
        per-batch losses and ``accuracy`` is the fraction of correct
        predictions in ``[0, 1]``. Both are ``0.0`` for an empty loader.
    """
    correct, total = 0, 0
    loss_sum = 0.0
    num_batches = 0

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)

            if vis:
                y_hat, _ = model(x)
            else:
                y_hat = model(x)

            # .item() already yields a detached Python float.
            loss_sum += criterion(y_hat, y).item()
            num_batches += 1

            correct += (torch.argmax(y_hat, dim=1) == y).sum().item()
            total += len(x)

    # Guard both divisions so an empty loader reports zeros instead of raising.
    test_loss = loss_sum / num_batches if num_batches else 0.0
    accuracy = correct / total if total else 0.0

    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {accuracy * 100:.2f}%")
    return test_loss, accuracy


# The name starts with "test_", so pytest would try to collect this helper as
# a test case (and fail on missing fixtures) whenever it is imported into a
# test module. Opt out of collection explicitly.
test_vit.__test__ = False