from typing import List

import torch
from detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple
from detectron2.modeling.poolers import ROIPooler as D2ROIPooler
from detectron2.modeling.poolers import (
    assign_boxes_to_levels,
    convert_boxes_to_pooler_format,
)


class ROIPooler(D2ROIPooler):
    """Same as detectron2's ROIPooler but supports a non-square output size.

    In the multi-level path the pooled output is allocated with
    ``self.output_size[0]`` as its height and ``self.output_size[1]`` as its
    width, rather than assuming both are equal.
    """

    def forward(self, x: List[torch.Tensor], box_lists):
        """
        Args:
            x (list[Tensor]): A list of feature maps of NCHW shape, with scales matching those
                used to construct this module.
            box_lists (list[Boxes] | list[RotatedBoxes]):
                A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch.
                The box coordinates are defined on the original image and
                will be scaled by the `scales` argument of :class:`ROIPooler`.

        Returns:
            Tensor:
                A tensor of shape (M, C, output_size[0], output_size[1]) where M is the total
                number of boxes aggregated over all N batch images and C is the number of
                channels in `x`. An empty batch (N == 0) yields a tensor with M == 0.
                With a single feature level, the output of that level's pooler is returned
                directly.
        """
        num_level_assignments = len(self.level_poolers)

        assert isinstance(x, list) and isinstance(
            box_lists, list
        ), "Arguments to pooler must be lists"
        assert (
            len(x) == num_level_assignments
        ), "unequal value, num_level_assignments={}, but x is list of {} Tensors".format(
            num_level_assignments, len(x)
        )
        assert len(box_lists) == x[0].size(
            0
        ), "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format(
            x[0].size(0), len(box_lists)
        )

        if len(box_lists) == 0:
            # Empty batch (the assert above guarantees x[0] has batch size 0 too):
            # nothing to pool. Return early because convert_boxes_to_pooler_format
            # concatenates the per-image box tensors and (in upstream detectron2)
            # fails on an empty list.
            return torch.zeros(
                (0, x[0].shape[1], self.output_size[0], self.output_size[1]),
                dtype=x[0].dtype,
                device=x[0].device,
            )

        pooler_fmt_boxes = convert_boxes_to_pooler_format(box_lists)

        if num_level_assignments == 1:
            # Single feature level: every box is pooled from x[0], no assignment needed.
            return self.level_poolers[0](x[0], pooler_fmt_boxes)

        # One entry per box: the index (into x / self.level_poolers) of the level
        # that box should be pooled from.
        level_assignments = assign_boxes_to_levels(
            box_lists,
            self.min_level,
            self.max_level,
            self.canonical_box_size,
            self.canonical_level,
        )

        num_boxes = len(pooler_fmt_boxes)
        num_channels = x[0].shape[1]
        dtype, device = x[0].dtype, x[0].device

        # Height and width are taken separately so non-square output sizes work.
        output = torch.zeros(
            (num_boxes, num_channels, self.output_size[0], self.output_size[1]),
            dtype=dtype,
            device=device,
        )

        for level, (x_level, pooler) in enumerate(zip(x, self.level_poolers)):
            # Pool the boxes assigned to this level from this level's feature map
            # and scatter the results back into their original box positions.
            inds = nonzero_tuple(level_assignments == level)[0]
            pooler_fmt_boxes_level = pooler_fmt_boxes[inds]
            output[inds] = pooler(x_level, pooler_fmt_boxes_level)

        return output