from typing import List, Optional, Sequence, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, cat, interpolate
from detectron2.structures import Instances
from detectron2.modeling import (
    ROI_HEADS_REGISTRY,
    ROI_KEYPOINT_HEAD_REGISTRY,
    BaseKeypointRCNNHead,
    StandardROIHeads,
    build_keypoint_head,
)
from detectron2.modeling.roi_heads.keypoint_head import keypoint_rcnn_inference
from detectron2.utils.events import get_event_storage

from .poolers import ROIPooler

# Running count of training batches for which no keypoint loss could be computed.
_TOTAL_SKIPPED = 0


def _keypoints_to_heatmap(
    keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: Sequence[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode keypoint locations as flattened heatmap indices; supports non-square heatmaps.

    Args:
        keypoints (Tensor): last dimension holds (x, y, visibility).
            Presumably shaped (N, K, 3) so it broadcasts against the per-ROI
            offsets/scales -- confirm against the caller.
        rois (Tensor): one box per row; columns 0..3 are read as (x0, y0, x1, y1).
        heatmap_size (Sequence[int]): (height, width) of the target heatmap.

    Returns:
        heatmaps (Tensor): int64 linear index ``y * width + x`` of every keypoint
            inside its ROI heatmap, forced to 0 where the keypoint is not valid.
        valid (Tensor): int64 mask, 1 where the keypoint is visible
            (visibility > 0) and falls inside the heatmap, 0 otherwise.
    """
    if rois.numel() == 0:
        return rois.new().long(), rois.new().long()

    heatmap_h, heatmap_w = heatmap_size[0], heatmap_size[1]

    offset_x = rois[:, 0][:, None]
    offset_y = rois[:, 1][:, None]
    scale_x = (heatmap_w / (rois[:, 2] - rois[:, 0]))[:, None]
    scale_y = (heatmap_h / (rois[:, 3] - rois[:, 1]))[:, None]

    x = keypoints[..., 0]
    y = keypoints[..., 1]

    # Keypoints lying exactly on the right/bottom ROI edge would be scaled to an
    # index one past the end of the heatmap; remember them so they can be
    # clamped to the last cell below.
    x_boundary_inds = x == rois[:, 2][:, None]
    y_boundary_inds = y == rois[:, 3][:, None]

    x = ((x - offset_x) * scale_x).floor().long()
    y = ((y - offset_y) * scale_y).floor().long()

    x[x_boundary_inds] = heatmap_w - 1
    y[y_boundary_inds] = heatmap_h - 1

    valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_w) & (y < heatmap_h)
    vis = keypoints[..., 2] > 0
    valid = (valid_loc & vis).long()

    lin_ind = y * heatmap_w + x
    heatmaps = lin_ind * valid

    return heatmaps, valid


def keypoint_rcnn_loss(
    pred_keypoint_logits: torch.Tensor,
    instances: List[Instances],
    normalizer: Optional[float] = None,
) -> torch.Tensor:
    """
    Compute the keypoint heatmap loss (spatial softmax cross-entropy).

    Arguments:
        pred_keypoint_logits (Tensor): A tensor of shape (N, K, H, W) where N is the
            total number of instances in the batch, K is the number of keypoints,
            and (H, W) is the size of the keypoint heatmap, which may be
            non-square. The values are spatial logits.
        instances (list[Instances]): A list of M Instances, where M is the batch size.
            These instances are predictions from the model
            that are in 1:1 correspondence with pred_keypoint_logits.
            Each Instances should contain a `gt_keypoints` field containing a
            `structures.Keypoint` instance, and a `proposal_boxes` field.
        normalizer (float): Normalize the loss by this amount.
            If None (the default), we normalize by the number of visible
            keypoints in the minibatch.

    Returns a scalar tensor containing the loss. When there is nothing to
    supervise, a zero that is still connected to `pred_keypoint_logits` is
    returned and the skip is recorded in the event storage.
    """
    global _TOTAL_SKIPPED

    heatmaps = []
    valid = []

    # (H, W) of the predicted heatmaps; used as the target heatmap size.
    keypoint_side_len = pred_keypoint_logits.shape[2:]
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        keypoints = instances_per_image.gt_keypoints
        heatmaps_per_image, valid_per_image = _keypoints_to_heatmap(
            keypoints.tensor,
            instances_per_image.proposal_boxes.tensor,
            keypoint_side_len,
        )
        heatmaps.append(heatmaps_per_image.view(-1))
        valid.append(valid_per_image.view(-1))

    if heatmaps:
        keypoint_targets = cat(heatmaps, dim=0)
        valid = cat(valid, dim=0).to(dtype=torch.uint8)
        # Turn the 0/1 mask into the flat indices of the valid keypoints.
        valid = torch.nonzero(valid).squeeze(1)

    # With no valid keypoints there is nothing to feed to the cross-entropy
    # below, so handle that case separately.
    if not heatmaps or valid.numel() == 0:
        _TOTAL_SKIPPED += 1
        storage = get_event_storage()
        storage.put_scalar(
            "kpts_num_skipped_batches", _TOTAL_SKIPPED, smoothing_hint=False
        )
        # Multiply by zero instead of returning a constant so the result still
        # depends on the logits.
        return pred_keypoint_logits.sum() * 0

    N, K, H, W = pred_keypoint_logits.shape
    # One row per (instance, keypoint); each row is a distribution over H * W cells.
    pred_keypoint_logits = pred_keypoint_logits.view(N * K, H * W)

    keypoint_loss = F.cross_entropy(
        pred_keypoint_logits[valid], keypoint_targets[valid], reduction="sum"
    )

    # If a normalizer isn't specified, normalize by the number of visible
    # keypoints in the minibatch
    if normalizer is None:
        normalizer = valid.numel()
    keypoint_loss /= normalizer

    return keypoint_loss


@ROI_HEADS_REGISTRY.register()
class KPROIHeads(StandardROIHeads):
    """Same as StandardROIHeads but use local ROIPooler which supports non-square output size."""

    @classmethod
    def _init_keypoint_head(cls, cfg, input_shape):
        """
        Build the constructor arguments for the keypoint branch.

        Args:
            cfg: config node; reads MODEL.KEYPOINT_ON, MODEL.ROI_HEADS.IN_FEATURES
                and MODEL.ROI_KEYPOINT_HEAD.POOLER_{RESOLUTION,SAMPLING_RATIO,TYPE}.
            input_shape: mapping from feature name to an object exposing
                `stride` and `channels`.

        Returns:
            dict: empty when keypoints are disabled, otherwise the keys
            "keypoint_in_features", "keypoint_pooler" and "keypoint_head".
        """
        if not cfg.MODEL.KEYPOINT_ON:
            return {}

        in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES
        pooler_resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION
        pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features)
        sampling_ratio = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO
        pooler_type = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE

        # Only the channel count of the first input feature is used.
        in_channels = [input_shape[f].channels for f in in_features][0]

        # A non-square resolution must not be copied verbatim into both the
        # width and height of the ShapeSpec. Any value that is not a 2-element
        # list/tuple is passed through unchanged.
        # NOTE(review): assumes (height, width) ordering, matching how
        # heatmap_size is indexed in _keypoints_to_heatmap -- confirm against
        # the local ROIPooler.
        if isinstance(pooler_resolution, (list, tuple)) and len(pooler_resolution) == 2:
            pooled_height, pooled_width = pooler_resolution
        else:
            pooled_height = pooled_width = pooler_resolution

        ret = {"keypoint_in_features": in_features}
        ret["keypoint_pooler"] = ROIPooler(
            output_size=pooler_resolution,
            scales=pooler_scales,
            sampling_ratio=sampling_ratio,
            pooler_type=pooler_type,
        )
        ret["keypoint_head"] = build_keypoint_head(
            cfg,
            ShapeSpec(channels=in_channels, width=pooled_width, height=pooled_height),
        )
        return ret


@ROI_KEYPOINT_HEAD_REGISTRY.register()
class KRCNNConvHead(BaseKeypointRCNNHead):
    """
    A standard keypoint head containing a series of 3x3 convs, followed by a
    3x3 conv producing one low-resolution score map per keypoint and a 2x
    bilinear upsampling.
    """

    @configurable
    def __init__(self, input_shape, *, num_keypoints, conv_dims, **kwargs):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape (ShapeSpec): shape of the input feature; only its
                `channels` attribute is read here.
            num_keypoints (int): number of output score maps.
            conv_dims: an iterable of output channel counts for each conv in the head
                e.g. (512, 512, 512) for three convs outputting 512 channels.
            **kwargs: forwarded to the base class constructor.
        """
        super().__init__(num_keypoints=num_keypoints, **kwargs)

        in_channels = input_shape.channels

        # Plain list used only to fix the iteration order in `layers`; each conv
        # is registered as a submodule through `add_module`.
        self.blocks = []
        for idx, layer_channels in enumerate(conv_dims, 1):
            module = Conv2d(in_channels, layer_channels, 3, stride=1, padding=1)
            self.add_module(f"conv_fcn{idx}", module)
            self.blocks.append(module)
            in_channels = layer_channels

        self.score_lowres = Conv2d(in_channels, num_keypoints, 3, stride=1, padding=1)
        self.up_scale = 2

        for name, param in self.named_parameters():
            if "bias" in name:
                nn.init.constant_(param, 0)
            elif "weight" in name:
                # Caffe2 implementation uses MSRAFill, which in fact
                # corresponds to kaiming_normal_ in PyTorch
                nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu")
        # Override the generic init above: the score conv starts close to zero.
        nn.init.normal_(self.score_lowres.weight, mean=0.0, std=0.0001)

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Extend the base-class arguments with `input_shape` and `conv_dims` from cfg."""
        ret = super().from_config(cfg, input_shape)
        ret["input_shape"] = input_shape
        ret["conv_dims"] = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS
        return ret

    def layers(self, x):
        """Apply the ReLU-activated conv stack, the score conv, then bilinear upsampling."""
        for layer in self.blocks:
            x = F.relu(layer(x))
        x = self.score_lowres(x)
        x = interpolate(
            x, scale_factor=self.up_scale, mode="bilinear", align_corners=False
        )
        return x

    def forward(self, x, instances: List[Instances]):
        """
        Use local keypoint_rcnn_loss.

        Args:
            x: input region features, passed to `layers`.
            instances (list[Instances]): per-image instances; given to the loss
                in training and to `keypoint_rcnn_inference` otherwise.

        Returns:
            In training, a dict with the single key "loss_keypoint".
            Otherwise, `instances` after `keypoint_rcnn_inference` has been
            called on them.
        """
        x = self.layers(x)
        if self.training:
            num_images = len(instances)
            # "visible" defers normalization to the loss, which then divides by
            # the number of visible keypoints.
            normalizer = (
                None
                if self.loss_normalizer == "visible"
                else num_images * self.loss_normalizer
            )
            return {
                "loss_keypoint": keypoint_rcnn_loss(x, instances, normalizer=normalizer)
                * self.loss_weight
            }
        else:
            keypoint_rcnn_inference(x, instances)
            return instances