Skip to content

Commit d3990d4

Browse files
nms kernel optimization (#25)
enable vectorized nms_kernel
1 parent 0626548 commit d3990d4

File tree

6 files changed

+288
-39
lines changed

6 files changed

+288
-39
lines changed

intel_pytorch_extension_py/ops/nms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22

3-
nms = torch.ops.torch_ipex.nms
3+
def nms(dets, scores, threshold, sorted=False):
4+
return torch.ops.torch_ipex.nms(dets, scores, threshold, sorted)
45
batch_score_nms = torch.ops.torch_ipex.batch_score_nms
56
parallel_scale_back_batch = torch.ops.torch_ipex.parallel_scale_back_batch

tests/cpu/data/nms_plabel.pt

4.68 MB
Binary file not shown.

tests/cpu/data/nms_ploc.pt

237 KB
Binary file not shown.

tests/cpu/test_nms.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import intel_pytorch_extension as ipex
55
from common_utils import TestCase
66
import time, sys
7-
from intel_pytorch_extension import batch_score_nms, parallel_scale_back_batch
7+
from intel_pytorch_extension import batch_score_nms, parallel_scale_back_batch, nms
88
import torch.nn.functional as F
99
import os
1010

@@ -118,17 +118,16 @@ def decode_single(self, bboxes_in, scores_in, criteria, max_output, max_num=200)
118118
max_ids = max_ids[-max_output:]
119119
return bboxes_out[max_ids, :], labels_out[max_ids], scores_out[max_ids]
120120

121-
def test_nms_result(self):
121+
def test_batch_nms_result(self):
122122
batch_size = 1
123123
number_boxes = 15130
124124
scale_xy = 0.1
125125
scale_wh = 0.2
126126
criteria = 0.50
127127
max_output = 200
128-
predicted_loc = torch.randn((batch_size, number_boxes, 4)).contiguous().to(torch.float32)
129-
predicted_score = torch.randn((batch_size, number_boxes, 81)).contiguous().to(torch.float32)
130-
dboxes_xywh = torch.randn((1, number_boxes, 4)).contiguous().to(torch.float64)
131-
dboxes_xywh = torch.load(os.path.dirname(__file__) + "/data/nms_dboxes_xywh.pt")
128+
predicted_loc = torch.load(os.path.join(os.path.dirname(__file__), "data/nms_ploc.pt")) # sizes: [1, 15130, 4]
129+
predicted_score = torch.load(os.path.join(os.path.dirname(__file__), "data/nms_plabel.pt")) # sizes: [1, 15130, 81]
130+
dboxes_xywh = torch.load(os.path.join(os.path.dirname(__file__), "data/nms_dboxes_xywh.pt"))
132131
bboxes, probs = parallel_scale_back_batch(predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh)
133132
bboxes_clone = bboxes.clone()
134133
probs_clone = probs.clone()
@@ -147,5 +146,46 @@ def test_nms_result(self):
147146
self.assertEqual(label, label2)
148147
self.assertTrue(torch.allclose(prob, prob2, rtol=1e-4, atol=1e-4))
149148

149+
def test_nms_kernel_result(self):
150+
batch_size = 1
151+
class_number = 81
152+
scale_xy = 0.1
153+
scale_wh = 0.2
154+
criteria = 0.50
155+
max_output = 200
156+
predicted_loc = torch.load(os.path.join(os.path.dirname(__file__), "data/nms_ploc.pt")) # sizes: [1, 15130, 4]
157+
predicted_score = torch.load(os.path.join(os.path.dirname(__file__), "data/nms_plabel.pt")) # sizes: [1, 15130, 81]
158+
dboxes_xywh = torch.load(os.path.join(os.path.dirname(__file__), "data/nms_dboxes_xywh.pt"))
159+
bboxes, probs = parallel_scale_back_batch(predicted_loc, predicted_score, dboxes_xywh, scale_xy, scale_wh)
160+
161+
for bs in range(batch_size):
162+
loc = bboxes[bs].squeeze(0)
163+
for class_id in range(class_number):
164+
if class_id == 0:
165+
# Skip the background
166+
continue
167+
score = probs[bs, :, class_id]
168+
169+
score_sorted, indices = torch.sort(score, descending=True)
170+
loc_sorted = torch.index_select(loc, 0, indices)
171+
172+
result = nms(loc_sorted.clone(), score_sorted.clone(), criteria, True)
173+
result_ref = nms(loc.clone(), score.clone(), criteria, False)
174+
result_ref2 = nms(loc_sorted.clone().to(dtype=torch.float64), score_sorted.clone().to(dtype=torch.float64), criteria, True)
175+
176+
bbox_keep, _ = torch.sort(torch.index_select(loc_sorted, 0, result).squeeze(0), 0)
177+
bbox_keep_ref, _ = torch.sort(torch.index_select(loc, 0, result_ref).squeeze(0), 0)
178+
bbox_keep_ref2, _ = torch.sort(torch.index_select(loc_sorted, 0, result_ref2).squeeze(0), 0)
179+
180+
score_keep, _ = torch.sort(torch.index_select(score_sorted, 0, result).squeeze(0), 0)
181+
score_keep_ref, _ = torch.sort(torch.index_select(score, 0, result_ref).squeeze(0), 0)
182+
score_keep_ref2, _ = torch.sort(torch.index_select(score_sorted, 0, result_ref2).squeeze(0), 0)
183+
184+
self.assertEqual(result.size(0), result_ref.size(0))
185+
self.assertTrue(torch.allclose(bbox_keep, bbox_keep_ref, rtol=1e-4, atol=1e-4))
186+
self.assertTrue(torch.allclose(score_keep, score_keep_ref, rtol=1e-4, atol=1e-4))
187+
self.assertTrue(torch.allclose(bbox_keep, bbox_keep_ref2, rtol=1e-4, atol=1e-4))
188+
self.assertTrue(torch.allclose(score_keep, score_keep_ref2, rtol=1e-4, atol=1e-4))
189+
150190
if __name__ == '__main__':
151191
test = unittest.main()

torch_ipex/csrc/cpu/ExtendOPs.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,21 @@ class AtenIpexTypeExt {
2626
const int64_t height,
2727
const int64_t width,
2828
const int64_t sampling_ratio);
29-
29+
30+
/// \brief Perform non-maximum suppression.
31+
///
32+
/// \param dets: predicted loc in ltrb format for one batchsize, size [number_boxes, 4], for example: [200, 4].
33+
/// \param scores: predicted score for one batchsize and one class, size [number_boxes], for example: [200].
34+
/// \param threshold: IOU threshold(scalar) to suppress bboxs which has the IOU val larger than the threshold.
35+
/// \param sorted: The score and dets are already sorted in Descending order.
36+
///
37+
/// \return result is a Tensor of dets' indexs to be keeped.
3038
static at::Tensor nms(const at::Tensor& dets,
3139
const at::Tensor& scores,
32-
const double threshold);
40+
const double threshold,
41+
const bool sorted);
3342

34-
/// \brief Perform non-maximum suppression.
43+
/// \brief Perform batch non-maximum suppression.
3544
///
3645
/// C++ version of Encoder::decode_single.
3746
/// Refer to https://github.com/mlcommons/inference/blob/v0.7/others/cloud/single_stage_detector/pytorch/utils.py.

0 commit comments

Comments
 (0)