# traffic-sign-classifier-robustness-testing/rt_search_based/gui/gui.py
import logging
from collections import defaultdict
from pathlib import Path
from typing import List, Literal

from flask import Flask, redirect, render_template, request, url_for
from rt_search_based.database.database import Database

from rt_search_based.gui.gui_dataclasses import (
    ComparisonRow,
    PaginationInfo,
    StatisticsTable,
    StrategyResult,
)

# Directory containing this GUI module.
_GUI_DIR = Path(__file__).parent

# Jinja templates rendered by the GUI live next to this module.
TEMPLATE_FOLDER_DIR = _GUI_DIR / "templates"
# Images served as Flask static files: the "imgs" directory one level above this module.
STATIC_FOLDER_DIR = _GUI_DIR.parent / "imgs"


def create_app(db: Database):
    """
    Returns the GUI Flask webapp

    Routes:
        "/"                                    redirects to the comparison table
        "/sticker_comparison_table[/<page>]"   renders one page of the sticker
                                               comparison table (GET only)

    Args:
        db: Database
            The Database instance to be used by the GUI
    """

    app = Flask(
        "rt_search_based_gui",
        template_folder=str(TEMPLATE_FOLDER_DIR),
        static_folder=str(STATIC_FOLDER_DIR),
    )

    # Number of sticker rows rendered on one page of the comparison table.
    ROWS_PER_PAGE = 50

    # Can be expanded in the future for comparisons between fitness functions directly, etc.
    COMPARISON_TYPES = ["Search Strategies"]

    # Placeholder shown whenever an image or diagram cannot be resolved.
    NO_IMAGE_SRC = "https://upload.wikimedia.org/wikipedia/commons/thumb/a/ac/No_image_available.svg/1024px-No_image_available.svg.png"

    @app.route("/")
    def home():
        return redirect(url_for("sticker_comparison_table"))

    @app.route("/sticker_comparison_table", methods=["GET"])
    @app.route("/sticker_comparison_table/<int:page>", methods=["GET"])
    def sticker_comparison_table(page: int = 1):
        """
        Renders one page of the sticker comparison table.

        Query parameters:
            type:   index into COMPARISON_TYPES (default 0)
            option: index into the options of the selected comparison type;
                    for type 0 it selects the fitness function (default 0)

        Responds with HTTP 400 when either index is out of range.
        """

        comparison_type: int = request.args.get("type", 0, int)
        comparison_option: int = request.args.get("option", 0, int)

        # Only "Search Strategies" (type 0) is implemented. Reject anything else
        # up front; otherwise the names defined for type 0 below would be unbound.
        if comparison_type != 0:
            return f"Unknown comparison type: {comparison_type}", 400

        # Comparing how different search strategies perform with the same fitness function
        comparison_options: List[str] = list(db.get_saved_fitness_function_names())
        # Explicit bounds check: a negative option would silently wrap around and an
        # out-of-range one (or no saved fitness functions) would raise IndexError.
        if not 0 <= comparison_option < len(comparison_options):
            return (
                f"Unknown comparison option: {comparison_option} "
                f"({len(comparison_options)} option(s) available)",
                400,
            )
        fitness_function_name = comparison_options[comparison_option]

        # Move BruteForce strategies to the start and then sort alphabetically.
        # sorted() is stable, so BruteForce entries keep the order the database returned.
        strategy_names = sorted(
            db.get_saved_strategy_names(fitness_function_name),
            key=lambda x: x.lower() if "bruteforce" not in x.lower() else "",
        )
        logging.info(f"STRATEGY NAMES: {strategy_names=}")
        column_headers = strategy_names

        total_entries = db.count_imgs  # Number of images in our database
        # Ceiling division so a trailing partial page is reachable; keep at least
        # one page so the table still renders when there are no images.
        last_page = max(1, -(-total_entries // ROWS_PER_PAGE))
        # <int:page> also matches 0 and pages past the end; clamp so the row
        # range computed below never goes negative or past the database.
        page = min(max(page, 1), last_page)

        pagination_info = PaginationInfo(page, last_page)

        first_index = (pagination_info.current_page - 1) * ROWS_PER_PAGE
        # Exclusive end; the last page may hold fewer than ROWS_PER_PAGE rows.
        stop_index = min(first_index + ROWS_PER_PAGE, total_entries)

        def get_url(
            page: int = pagination_info.current_page,
            comparison_type: int = comparison_type,
            comparison_option: int = comparison_option,
        ) -> str:
            """returns a url based on the given parameters"""
            return f"/sticker_comparison_table/{page}?type={comparison_type}&option={comparison_option}"

        def get_image_src(
            index: int, strategy_name: str = None, fitness_function_name: str = None
        ) -> str:
            """
            Returns the static URL of an image, or NO_IMAGE_SRC if it cannot be resolved.

            image nomenclature:
                {index}_{strategy_name}_{fitness_function_name}.png
                {index}.png for unprocessed images
            """
            try:
                image_path = db.get_image_path(
                    index, strategy_name, fitness_function_name
                )
                src_path = image_path.absolute().relative_to(STATIC_FOLDER_DIR)
            except Exception:
                # Deliberately best-effort: one missing image must not break the
                # whole page, but record why so missing files can be tracked down.
                logging.debug(
                    "No image for index=%s strategy=%s fitness_function=%s",
                    index,
                    strategy_name,
                    fitness_function_name,
                    exc_info=True,
                )
                return NO_IMAGE_SRC

            return url_for("static", filename=src_path.as_posix())

        def get_statistics_table() -> StatisticsTable:
            """
            returns the statistics table comparing all strategies for the current
            fitness function (independent of the current page)

            Rows are keyed by statistic name and columns by strategy name; float
            values are rounded to two decimals and every cell is stored as a string.
            """

            # create empty statistics table
            statistics_table = StatisticsTable([], defaultdict(dict))
            statistics_table.column_headers = strategy_names

            for strategy_name in strategy_names:
                report = db.get_report(
                    fitness_function=fitness_function_name, strategy=strategy_name
                )

                for statistic_name, value in report.items():
                    # The displayed statistic is element 2 of each report entry;
                    # .item() presumably unwraps a numpy/torch scalar — TODO confirm.
                    statistic = value[2].item()

                    if isinstance(statistic, float):
                        statistic = round(statistic, 2)

                    statistics_table.rows[statistic_name][strategy_name] = str(
                        statistic
                    )

            return statistics_table

        def get_diagram_src(diagram_type: Literal["distribution", "box-plot"]) -> str:
            """
            Returns the static URL of the given diagram for the current fitness
            function, or NO_IMAGE_SRC if it cannot be resolved.
            """
            try:
                image_path = db.get_diagram_image_path(
                    fitness_function_name, diagram_type
                )
                src_path = image_path.absolute().relative_to(STATIC_FOLDER_DIR)
            except Exception:
                # Same best-effort policy as sticker images: show a placeholder
                # instead of failing the whole page.
                logging.warning(
                    "No %s diagram for fitness_function=%s",
                    diagram_type,
                    fitness_function_name,
                    exc_info=True,
                )
                return NO_IMAGE_SRC

            return url_for("static", filename=src_path.as_posix())

        def get_sticker_comparison_rows() -> List[ComparisonRow]:
            """
            retrieves the sticker data for the current page from the database and
            returns it in ComparisonRow format to be visualized in the GUI
            """

            rows: List[ComparisonRow] = []

            for index in range(first_index, stop_index):
                image = get_image_src(index)

                strategy_results: List[StrategyResult] = []

                for strategy_name in strategy_names:
                    # Field positions as consumed below: 0 fitness, 1 time (ns),
                    # 2 area, 3-6 sticker corner coords, 7-9 sticker color.
                    data = [
                        item.item()
                        for item in db[index, fitness_function_name, strategy_name]
                    ]

                    strategy_results.append(
                        StrategyResult(
                            image_src=get_image_src(
                                index, strategy_name, fitness_function_name
                            ),
                            area=data[2],
                            # convert nanoseconds to milliseconds
                            time=round(data[1] / 1_000_000, 2),
                            fitness_value=data[0],
                            sticker_coords=((data[3], data[4]), (data[5], data[6])),
                            sticker_color=(data[7], data[8], data[9]),
                        )
                    )

                rows.append(ComparisonRow(index, image, strategy_results))

            return rows

        return render_template(
            "sticker_comparison_table.html",
            comparison_rows=get_sticker_comparison_rows(),
            pagination_info=pagination_info,
            comparison_types=COMPARISON_TYPES,
            comparison_type=comparison_type,
            comparison_options=comparison_options,
            comparison_option=comparison_option,
            get_url=get_url,
            column_headers=column_headers,
            statistics_table=get_statistics_table(),
            distribution_diagram_src=get_diagram_src("distribution"),
            boxplot_diagram_src=get_diagram_src("box-plot"),
        )

    return app