SAI-training / src / tools / export.py
export.py
Raw
import os
import json
import argparse

import matplotlib.pyplot as plt


def get_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--directory",
        help="Path to folder containing ground truth prediction pair .json files",
        required=True,
    )
    parser.add_argument(
        "--csv-output", help="Where to save .csv version of data", default=None
    )
    parser.add_argument("--plot", help="Plot results", action="store_true")
    return parser.parse_args()


def load_gt_pred_pairs(directory):
    files = os.listdir(directory)
    image_detections = []
    for file in files:
        if not file[-4:] == "json":
            continue
        with open(os.path.join(directory, file), "r") as read_file:
            image_data = json.load(read_file)
            if '-predictions' in file:
                image_data["image_name"] = file[:-16]
            else:
                image_data["image_name"] = file[:-5]
            image_detections.append(image_data)
    return image_detections


def associate_ids_to_pairs(image_dicts):
    stoma_properties = dict()
    counter = 0
    for image_dict in image_dicts:
        image_name = image_dict["image_name"]
        detections = image_dict["detections"]
        for detection in detections:
            d_width, d_length, d_area, d_class, d_conf = {}, {}, {}, {}, {}
            if "gt" in detection:
                gt = detection["gt"]
                pred = detection["pred"]
                d_width["gt"] = gt["width"]
                d_length["gt"] = gt["length"]
                d_area["gt"] = gt["area"]
                d_class["gt"] = gt["category_id"]
            else:
                pred = detection
            d_width["pred"] = pred["width"]
            d_length["pred"] = pred["length"]
            d_area["pred"] = pred["area"]
            d_class["pred"] = pred["category_id"]
            d_conf["pred"] = pred["confidence"]
            d_image_name = {"pred": image_name}

            property_pairs = {
                "width": d_width,
                "length": d_length,
                "area": d_area,
                "class": d_class,
                "confidence": d_conf,
                "image_name": d_image_name,
            }
            if "gt" in detection:
                stoma_properties[gt["id"]] = property_pairs
            else:
                stoma_properties[counter] = property_pairs
                counter += 1
    return stoma_properties


def extract_property(stoma_id_dict, key):
    property_pairs = []
    for stoma_id in stoma_id_dict.keys():
        property_pair = stoma_id_dict[stoma_id]
        pred_value = property_pair[key]["pred"]
        if "gt" in property_pair[key]:
            gt_value = property_pair[key]["gt"]
            property_pairs.append([gt_value, pred_value])
        else:
            property_pairs.append([pred_value])
    return property_pairs


def write_to_csv(stoma_id_dict, filepath):
    if "gt" in stoma_id_dict[list(stoma_id_dict.keys())[0]]["class"]:
        column_names = [
            "id",
            "image_name",
            "class",
            "pred_class",
            "length",
            "pred_length",
            "width",
            "pred_width",
            "area",
            "pred_area",
            "confidence",
        ]
    else:
        column_names = [
            "id",
            "image_name",
            "pred_class",
            "pred_length",
            "pred_width",
            "pred_area",
            "confidence",
        ]

    column_keys = ["image_name", "class", "length", "width", "area", "confidence"]
    csv = ",".join(column_names) + "\n"
    for key in stoma_id_dict.keys():
        values = [key]
        detection = stoma_id_dict[key]
        for stoma_property in column_keys:
            if 'gt' in detection[stoma_property]:
                values.append(detection[stoma_property]["gt"])
            values.append(detection[stoma_property]["pred"])
        values = [str(x) for x in values]
        csv += ",".join(values) + "\n"

    with open(filepath, "w") as file:
        file.write(csv)


def plot_x_y_line(line_range, title):
    x = [i / 2 for i in range(line_range[0], line_range[1])]
    y = x
    plt.plot(x, y, label="y = x")
    plt.title(title)


def plot_scatter(data, label, colour):
    gt = [point[0] for point in data]
    pred = [point[1] for point in data]
    plt.xlabel("Observed (pixels)")
    plt.ylabel("Predicted (pixels)")
    plt.scatter(gt, pred, label=label, c=colour)
    plt.legend(loc="lower right")


if __name__ == "__main__":
    # Get Commandline arguments
    args = get_arguments()
    # Load predictions and GT from .jsons
    gt_pred_dicts = load_gt_pred_pairs(args.directory)
    # Extract pairs into stomal id lists
    stoma_id_pairs = associate_ids_to_pairs(gt_pred_dicts)
    # Extract individual property pairs
    width_pairs = extract_property(stoma_id_pairs, "width")
    length_pairs = extract_property(stoma_id_pairs, "length")
    class_pairs = extract_property(stoma_id_pairs, "class")
    area_pairs = extract_property(stoma_id_pairs, "area")
    # Covert to csv formatted file
    if not args.csv_output is None:
        write_to_csv(stoma_id_pairs, args.csv_output)

    # Display Plots
    if args.plot:
        plot_x_y_line([0, 140], "Width Predictions")
        plot_scatter(width_pairs, "RCNN v2", "#2ca02c")
        plt.show()
        plot_x_y_line([160, 400], "Length Predictions")
        plot_scatter(length_pairs, "RCNN v2", "#2ca02c")
        plt.show()
        plot_x_y_line([0, 17000], "Area Predictions")
        plot_scatter(area_pairs, "RCNN v2", "#2ca02c")
        plt.show()