import os
import json

import numpy as np
from scipy.stats import iqr

from record import is_overlapping, Predicted_Lengths

# Detections within this many pixels of an image border count as "near the edge".
CLOSE_TO_EDGE_DISTANCE = 20
# Near-edge detections are dropped when their area falls below this fraction
# of the average bounding-box area.
CLOSE_TO_EDGE_SIZE_THRESHOLD = 0.85
# Detections anywhere in the image are dropped when their area falls below
# this fraction of the average bounding-box area.
SIZE_THRESHOLD = 0.3


def filter_invalid_predictions(predictions):
    """Filter a detectron2-style Instances object in place."""
    remove_intersecting_predictions(predictions)
    remove_close_to_edge_detections(predictions)
    remove_extremely_small_detections(predictions)


def remove_intersecting_predictions(predictions):
    """From each group of overlapping boxes, keep only the highest-scoring one."""
    final_indices = []
    for i, bbox_i in enumerate(predictions.pred_boxes):
        intersecting = [
            j for j, bbox_j in enumerate(predictions.pred_boxes)
            if i != j and is_overlapping(bbox_i, bbox_j)
        ]
        # all() is trivially True when the box overlaps nothing, so isolated
        # boxes are always kept.
        outscores_all_overlaps = all(
            predictions.scores[i].item() >= predictions.scores[j].item()
            for j in intersecting
        )
        if outscores_all_overlaps:
            final_indices.append(i)
    select_predictions(predictions, final_indices)


def select_predictions(predictions, indices):
    """Keep only the predictions at the given indices (in place)."""
    predictions.pred_boxes.tensor = predictions.pred_boxes.tensor[indices]
    predictions.pred_classes = predictions.pred_classes[indices]
    predictions.pred_masks = predictions.pred_masks[indices]
    predictions.pred_keypoints = predictions.pred_keypoints[indices]


def remove_close_to_edge_detections(predictions):
    """Drop detections that are both near an image edge and smaller than average."""
    average_area = calculate_average_bbox_area(predictions)
    image_height, image_width = predictions.image_size
    final_indices = []
    for i, bbox_i in enumerate(predictions.pred_boxes):
        is_near_edge = is_bbox_near_edge(bbox_i, image_height, image_width)
        is_smaller_than_average = is_bbox_small(bbox_i, average_area)
        if not (is_near_edge and is_smaller_than_average):
            final_indices.append(i)
    select_predictions(predictions, final_indices)


def calculate_average_bbox_area(predictions):
    areas = [calculate_bbox_area(bbox) for bbox in predictions.pred_boxes]
    if not areas:
        return 0
    return sum(areas) / len(areas)


def calculate_bbox_area(bbox):
    """Area of an (x1, y1, x2, y2) box."""
    width = abs(bbox[2] - bbox[0])
    height = abs(bbox[3] - bbox[1])
    return width * height


def is_bbox_near_edge(bbox, image_height, image_width):
    """True if any side of the box lies within CLOSE_TO_EDGE_DISTANCE of the border."""
    x1, y1, x2, y2 = bbox
    return any([
        x1 < CLOSE_TO_EDGE_DISTANCE,
        y1 < CLOSE_TO_EDGE_DISTANCE,
        image_width - x2 < CLOSE_TO_EDGE_DISTANCE,
        image_height - y2 < CLOSE_TO_EDGE_DISTANCE,
    ])


def is_bbox_small(bbox, average_area):
    return calculate_bbox_area(bbox) < CLOSE_TO_EDGE_SIZE_THRESHOLD * average_area


def remove_extremely_small_detections(predictions):
    """Drop detections whose area is far below the average bounding-box area."""
    average_area = calculate_average_bbox_area(predictions)
    final_indices = [
        i for i, bbox_i in enumerate(predictions.pred_boxes)
        if not is_bbox_extremely_small(bbox_i, average_area)
    ]
    select_predictions(predictions, final_indices)


def is_bbox_extremely_small(bbox, average_area):
    return calculate_bbox_area(bbox) < average_area * SIZE_THRESHOLD


def remove_outliers_from_records(output_directory):
    """Prune length outliers from every detection record in output_directory.

    Ground-truth files (tagged "-gt") are left untouched.
    """
    for filename in os.listdir(output_directory):
        if not filename.endswith(".json") or "-gt" in filename:
            continue
        filepath = os.path.join(output_directory, filename)
        record = load_json(filepath)['detections']
        # Plain record files nest each prediction under a 'pred' key;
        # "-prediction" files store the predictions directly.
        if "-prediction" not in filename:
            predictions = unpack_predictions(record)
        else:
            predictions = record
        length_predictions = extract_lengths(predictions)
        remove_outliers(length_predictions, record)
        write_to_json(record, filepath)


def load_json(filepath):
    with open(filepath, "r") as file:
        return json.load(file)


def unpack_predictions(record):
    return [detection['pred'] for detection in record]


def extract_lengths(predictions):
    return [prediction['length'] for prediction in predictions]


def remove_outliers(lengths, record):
    to_remove = find_outlier_indices(lengths)
    # Pop from the back so earlier indices stay valid.
    for index in reversed(to_remove):
        record.pop(index)


def find_outlier_indices(lengths):
    # Only the lower whisker is applied: implausibly short detections are
    # pruned, while long ones are kept.
    lower_bound, _upper_bound = calculate_length_limits()
    return [i for i, length in enumerate(lengths) if length < lower_bound]


def calculate_length_limits():
    """Boxplot-style whiskers around the median of all predicted lengths."""
    inter_quartile_range = iqr(Predicted_Lengths, interpolation='midpoint')
    median = np.median(Predicted_Lengths)
    lower_whisker = median - 2.0 * inter_quartile_range
    higher_whisker = median + 2.0 * inter_quartile_range
    return lower_whisker, higher_whisker


def write_to_json(record, filepath):
    with open(filepath, "w") as file:
        json.dump({"detections": record}, file)
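

# Minimal usage sketch, not part of the original pipeline: exercises the
# whisker-based length limits on synthetic data. This assumes that
# Predicted_Lengths (imported from record) is a plain list that the
# inference code normally appends to; the numbers below are fabricated
# purely for illustration.
if __name__ == "__main__":
    Predicted_Lengths.extend([29.5, 30.0, 30.5, 31.0, 32.0, 5.0])
    lower, upper = calculate_length_limits()
    print(f"acceptable length range: [{lower:.1f}, {upper:.1f}]")
    # 5.0 falls below the lower whisker, so index 1 is flagged.
    print("outlier indices:", find_outlier_indices([30.0, 5.0, 31.0]))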