# traffic-sign-classifier-robustness-testing / rt_search_based / models / classifiers.py
from abc import ABC

import torch
import torch.nn.functional as fun
from typing import Dict
from rt_search_based.models.classifier2test_1.classifier2test_1 import (
    TrafficSignClassifier,
)


class Classifier(ABC):
    """Base wrapper around a PyTorch image classifier.

    Subclasses supply the network through the ``classifier`` class attribute.
    Both query methods run the network once on ``image``, apply a softmax over
    dimension 1 and return plain Python numbers. No autograd graph is kept.
    """

    # The wrapped network; its output is softmax-normalised over dim 1, so
    # presumably it returns (batch, num_classes) logits — TODO confirm.
    classifier: torch.nn.Module

    def __repr__(self) -> str:
        return f"{type(self).__name__}()"

    def _softmax_probabilities(self, image: torch.Tensor) -> torch.Tensor:
        """Return ``softmax(self.classifier(image), dim=1)`` computed without gradients."""
        # Callers only turn the result into Python scalars, so recording an
        # autograd graph would just waste time and memory.
        with torch.no_grad():
            return fun.softmax(self.classifier(image), dim=1)

    def get_predicted_class(self, image: torch.Tensor) -> int:
        """Return the index of the most probable class for ``image``.

        ``.item()`` requires the reduced result to hold exactly one element,
        so ``image`` must produce a batch of size one.

        Raises:
            RuntimeError: (from ``.item()``) if the batch holds more than one image.
        """
        return self._softmax_probabilities(image).max(dim=1).indices.item()

    def get_class_probabilities(self, image: torch.Tensor) -> Dict[int, float]:
        """Map each class index to its softmax probability.

        Only the first element of the batch is considered.
        """
        first_sample = self._softmax_probabilities(image)[0]
        return dict(enumerate(first_sample.tolist()))


class Classifier2Test_1(Classifier):
    """GTSRB-Classifier from https://github.com/poojahira/gtsrb-pytorch.
    It is a CNN and Spatial transformer network that uses data
    augmentation techniques to reach 99.809% accuracy

    The network is built and its pretrained weights are loaded once, when
    this class body executes (i.e. when this module is imported).
    """

    # 43 output classes — presumably the GTSRB label count; confirm against
    # TrafficSignClassifier.
    classifier = TrafficSignClassifier(43)
    classifier.eval()
    # Map the checkpoint to the CPU. load_state_dict (default assign=False)
    # copies the values into the module's existing parameters, wherever they
    # live. Mapping to CUDA first therefore gained nothing, and it made
    # importing this module fail on machines without a GPU.
    # NOTE(review): the weights path is relative to the current working
    # directory, so the process must be started from the repository root.
    classifier.load_state_dict(
        torch.load(
            "./rt_search_based/models/classifier2test_1/model/classifier_pretrained_weights.pth",
            map_location=torch.device("cpu"),
        )
    )