from abc import ABC
from typing import Dict

import torch
import torch.nn.functional as fun

from rt_search_based.models.classifier2test_1.classifier2test_1 import (
    TrafficSignClassifier,
)


class Classifier(ABC):
    """Base wrapper exposing prediction helpers over a ``torch.nn.Module``.

    Subclasses set the ``classifier`` class attribute. The helpers apply a
    softmax over ``dim=1`` of its output, so presumably it returns logits of
    shape ``(batch, num_classes)`` — TODO confirm for each subclass.
    """

    classifier: torch.nn.Module

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"

    def get_predicted_class(self, image: torch.Tensor) -> int:
        """Return the index of the most probable class for ``image``.

        Args:
            image: Input batch for ``self.classifier``. It must produce a
                single prediction, because ``.item()`` raises for batches
                with more than one element.

        Returns:
            The argmax class index as a Python ``int``.
        """
        # Inference only: don't build an autograd graph that .item() discards.
        with torch.no_grad():
            probabilities = fun.softmax(self.classifier(image), dim=1)
        return probabilities.max(dim=1)[1].item()

    def get_class_probabilities(self, image: torch.Tensor) -> Dict[int, float]:
        """Return a mapping from class index to softmax probability.

        Args:
            image: Input batch for ``self.classifier``. Only the first image
                in the batch (index 0) is reported.

        Returns:
            ``{class_index: probability}`` for every output class.
        """
        with torch.no_grad():
            probabilities = fun.softmax(self.classifier(image), dim=1)[0]
        return dict(enumerate(probabilities.tolist()))


class Classifier2Test_1(Classifier):
    """GTSRB-Classifier from https://github.com/poojahira/gtsrb-pytorch.

    It is a CNN and Spatial transformer network that uses data augmentation
    techniques to reach 99.809% accuracy

    The model is built and its weights are loaded once, when this class
    body executes (i.e. at import time).
    """

    # Constructor argument presumably = number of GTSRB classes (43).
    classifier = TrafficSignClassifier(43)
    classifier.eval()
    classifier.load_state_dict(
        # NOTE(review): the path is relative to the current working directory,
        # so importing this module only works from the repository root.
        # torch.load unpickles the file; only point it at trusted weights.
        torch.load(
            "./rt_search_based/models/classifier2test_1/model/classifier_pretrained_weights.pth",
            # The model above is never moved to a GPU, and load_state_dict
            # copies values into its existing parameters. Mapping to CUDA
            # therefore gave the same result but crashed the import on
            # machines without CUDA.
            map_location=torch.device("cpu"),
        )
    )