diff --git a/lib/model/rpn/proposal_target_layer_cascade.py b/lib/model/rpn/proposal_target_layer_cascade.py index bb4d780bb..2f92f40d3 100644 --- a/lib/model/rpn/proposal_target_layer_cascade.py +++ b/lib/model/rpn/proposal_target_layer_cascade.py @@ -130,8 +130,8 @@ def _sample_rois_pytorch(self, all_rois, gt_boxes, fg_rois_per_image, rois_per_i offset = torch.arange(0, batch_size)*gt_boxes.size(1) offset = offset.view(-1, 1).type_as(gt_assignment) + gt_assignment - labels = gt_boxes[:,:,4].contiguous().view(-1).index((offset.view(-1),)).view(batch_size, -1) - + labels = gt_boxes[:,:,4].contiguous().view(-1)[(offset.view(-1),)].view(batch_size, -1) + labels_batch = labels.new(batch_size, rois_per_image).zero_() rois_batch = all_rois.new(batch_size, rois_per_image, 5).zero_() gt_rois_batch = all_rois.new(batch_size, rois_per_image, 5).zero_()